-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathrustgen.py
146 lines (127 loc) · 4.68 KB
/
rustgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
Rust trap class generation
"""
import functools
import typing
import inflection
from misc.codegen.lib import rust, schema
from misc.codegen.loaders import schemaloader
def _get_type(t: str) -> str:
match t:
case None: # None means a predicate
return "bool"
case "string":
return "String"
case "int":
return "usize"
case _ if t[0].isupper():
return f"trap::Label<{t}>"
case "boolean":
assert False, "boolean unsupported"
case _:
return t
def _get_table_name(cls: schema.Class, p: schema.Property) -> str:
    """Return the name of the database table that stores property ``p`` of ``cls``.

    Single-valued properties live in the class's own (tableized) table. Otherwise an
    explicit ``ql_db_table_name`` pragma wins; failing that, the name is built from
    ``<class>_<property>``, underscored for predicates and tableized for the rest.
    """
    if p.is_single:
        return inflection.tableize(cls.name)
    explicit_name = p.pragmas.get("ql_db_table_name")
    if explicit_name:
        return explicit_name
    combined = f"{cls.name}_{p.name}"
    to_table = inflection.underscore if p.is_predicate else inflection.tableize
    return to_table(combined)
def _get_field(cls: schema.Class, p: schema.Property) -> rust.Field:
    """Build the ``rust.Field`` describing property ``p`` as seen from class ``cls``.

    The field attributes are derived from the property itself. Any entries returned by
    ``rust.get_field_override`` for the property name are then applied on top,
    replacing the derived values.
    """
    field_kwargs = {
        "field_name": rust.avoid_keywords(p.name),
        "base_type": _get_type(p.type),
        "is_optional": p.is_optional,
        "is_repeated": p.is_repeated,
        "is_predicate": p.is_predicate,
        "is_unordered": p.is_unordered,
        "table_name": _get_table_name(cls, p),
    }
    field_kwargs.update(rust.get_field_override(p.name))
    return rust.Field(**field_kwargs)
def _get_properties(
    cls: schema.Class, lookup: dict[str, schema.ClassBase],
) -> typing.Iterable[tuple[schema.Class, schema.Property]]:
    """Yield ``(declaring_class, property)`` for every property visible on ``cls``.

    Bases are visited first, recursively and in ``cls.bases`` order, followed by
    ``cls``'s own properties. Imported bases (``base.imported``) are skipped, as
    ``_get_ancestors`` does, since they are not ``schema.Class`` instances with
    ``bases``/``properties``. A base reachable through several inheritance paths is
    visited once per path, so its properties may be yielded more than once; callers
    that need uniqueness must de-duplicate.
    """
    for b in cls.bases:
        base = lookup[b]
        if base.imported:
            # nothing to collect from imported classes; recursing would fail on the
            # missing `bases`/`properties` attributes
            continue
        yield from _get_properties(typing.cast(schema.Class, base), lookup)
    for p in cls.properties:
        yield cls, p
def _get_ancestors(
    cls: schema.Class, lookup: dict[str, schema.ClassBase]
) -> typing.Iterable[schema.Class]:
    """Yield the non-imported ancestors of ``cls`` in depth-first pre-order.

    Each direct base is yielded before its own ancestors. Imported bases are skipped
    together with everything above them. An ancestor reachable through several paths is
    yielded once per path.
    """
    for base_name in cls.bases:
        parent = lookup[base_name]
        if parent.imported:
            continue
        parent = typing.cast(schema.Class, parent)
        yield parent
        yield from _get_ancestors(parent, lookup)
class Processor:
    """Turn a loaded schema into ``rust.Class`` descriptions grouped by output module."""

    def __init__(self, data: schema.Schema):
        # class name -> class; may contain imported classes (see `get_classes`)
        self._classmap = data.classes

    def _get_class(self, name: str) -> rust.Class:
        """Build the ``rust.Class`` for the schema class called ``name``.

        Properties carrying the ``rust_skip`` pragma or marked ``synth`` are ignored.
        Non-derived (concrete) classes get a field for every inherited or own
        non-detached property. A ``rust_detach`` property becomes a detached field
        only in the class that declares it.
        """
        cls = typing.cast(schema.Class, self._classmap[name])
        properties = []
        seen = set()
        for c, p in _get_properties(cls, self._classmap):
            if "rust_skip" in p.pragmas or p.synth:
                continue
            # a base reachable through several inheritance paths yields its properties
            # once per path; keep only the first occurrence so no field is emitted twice
            key = (c.name, p.name)
            if key in seen:
                continue
            seen.add(key)
            properties.append((c, p))
        fields = []
        detached_fields = []
        for c, p in properties:
            if "rust_detach" in p.pragmas:
                # only generate detached fields in the actual class defining them, not the derived ones
                if c is cls:
                    # TODO lift this restriction if required (requires change in dbschemegen as well)
                    assert c.derived or not p.is_single, \
                        f"property {p.name} in concrete class marked as detached but not optional"
                    detached_fields.append(_get_field(c, p))
            elif not cls.derived:
                # for non-detached ones, only generate fields in the concrete classes
                fields.append(_get_field(c, p))
        return rust.Class(
            name=name,
            fields=fields,
            detached_fields=detached_fields,
            # remove duplicates but preserve ordering (dict keys keep insertion order)
            ancestors=list(dict.fromkeys(a.name for a in _get_ancestors(cls, self._classmap))),
            entry_table=inflection.tableize(cls.name) if not cls.derived else None,
        )

    def get_classes(self) -> dict[str, list[rust.Class]]:
        """Return the classes to render, keyed by group name.

        Imported classes are listed by name only under the empty-string group. Synth
        (non-imported) classes are omitted. Every other class is fully built and placed
        under its ``group``.
        """
        ret = {"": []}
        for cls in self._classmap.values():
            if cls.imported:
                ret[""].append(rust.Class(name=cls.name))
            elif not cls.synth:
                ret.setdefault(cls.group, []).append(self._get_class(cls.name))
        return ret
def generate(opts, renderer):
    """Render the Rust trap class sources for ``opts.schema`` into ``opts.rust_output``.

    One ``<group>.rs`` file is rendered per class group. The unnamed group, which also
    carries imported classes, becomes ``top.rs``. A ``mod.rs`` listing all the groups is
    rendered as well. ``renderer.manage`` is given the existing ``*.rs`` files under the
    output directory and ``.generated.list`` as its registry — presumably so stale
    generated files can be tracked; confirm against the renderer.
    """
    assert opts.rust_output
    processor = Processor(schemaloader.load_file(opts.schema))
    out = opts.rust_output
    groups = set()
    # bind the managed renderer to its own name instead of shadowing the parameter
    with renderer.manage(generated=out.rglob("*.rs"),
                         stubs=(),
                         registry=out / ".generated.list",
                         force=opts.force) as managed:
        for group, classes in processor.get_classes().items():
            group = group or "top"
            groups.add(group)
            managed.render(
                rust.ClassList(
                    classes,
                    opts.schema,
                ),
                out / f"{group}.rs",
            )
        # sort the module names: iteration order of a set of strings varies with hash
        # randomization, which would make the rendered mod.rs differ between runs
        managed.render(
            rust.ModuleList(
                sorted(groups),
                opts.schema,
            ),
            out / "mod.rs",
        )