-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathrusttestgen.py
83 lines (70 loc) · 2.62 KB
/
rusttestgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import dataclasses
import typing
from collections.abc import Iterable
import inflection
from misc.codegen.loaders import schemaloader
from . import qlgen
@dataclasses.dataclass
class Param:
    """A named, typed parameter record.

    `first` defaults to False; presumably it lets templates single out the
    leading parameter of a list — TODO confirm against the templates.
    NOTE(review): not referenced by any code visible in this module.
    """

    name: str
    type: str
    first: bool = dataclasses.field(default=False)
@dataclasses.dataclass()
class Function:
    """A generated test function: its identifier and its signature text."""

    name: str
    signature: str
@dataclasses.dataclass
class TestCode:
    """Data handed to the `rust_test_code` template when rendering a test file.

    `code` is the already-joined source text; `function`, when set, describes
    the test function the code is wrapped in (None means no wrapper).
    """

    # Class-level (non-field) attribute naming the template to render with.
    template: typing.ClassVar[str] = "rust_test_code"

    code: str
    function: Function | None = dataclasses.field(default=None)
def _get_code(doc: list[str]) -> list[str]:
    """Turn a docstring (as a list of lines) into Rust test source lines.

    A line that is exactly ```` ``` ```` or ```` ```rust ```` is a fence and
    toggles between prose and code (either form both opens and closes).
    Code lines are kept verbatim, and prose lines are emitted as `// ` comments.

    Returns an empty list when the docstring contains no fence at all, so
    pure-prose docstrings produce no test code.

    Raises:
        AssertionError: if the docstring ends inside an open code fence.
    """
    fences = ("```", "```rust")
    inside_fence = False
    saw_fence = False
    lines = []
    for text in doc:
        if text in fences:
            inside_fence = not inside_fence
            saw_fence = True
        elif inside_fence:
            lines.append(text)
        else:
            lines.append(f"// {text}")
    assert not inside_fence, "Unterminated code block in docstring:\n " + "\n ".join(doc)
    return lines if saw_fence else []
def _collect_code(schema, cls) -> list[str]:
    """Gather the doc-test source lines for `cls` and its properties.

    Starts from the code extracted from the class docstring. For every property
    yielded by `schema.iter_properties(cls.name)` that is not tagged
    `rust_skip_doc_test`, appends a `// # <property name>` header followed by
    the code extracted from the property description. Properties without any
    fenced code contribute nothing.
    """
    code = _get_code(cls.doc)
    for prop in schema.iter_properties(cls.name):
        if "rust_skip_doc_test" in prop.pragmas:
            continue
        property_code = _get_code(prop.description)
        if property_code:
            code.append(f"// # {prop.name}")
            code += property_code
    return code


def generate(opts, renderer):
    """Render a Rust doc-test file for each eligible class in the schema.

    Classes are skipped when they are imported, when
    `qlgen.should_skip_qltest` says so, when they carry `rust_skip_doc_test`,
    or when neither they nor their properties have fenced code in their docs.
    Each remaining class is rendered with the `TestCode` template to
    `<ql_test_output>/<group>/<name>/gen_<snake_case_class>.rs`. Here
    group/name come from the class named by the `qltest_test_with` pragma when
    present, and otherwise from the class itself.

    Args:
        opts: options object; reads `ql_test_output` (a path), `schema` and
            `force`.
        renderer: object whose `manage(...)` context manager yields the
            renderer that writes the files.
    """
    assert opts.ql_test_output
    schema = schemaloader.load_file(opts.schema)
    indent = 4 * " "  # body indentation when code is wrapped in a test fn
    # Existing gen_*.rs files are handed to the manager as previously
    # generated output (presumably so stale ones can be cleaned up).
    with renderer.manage(generated=opts.ql_test_output.rglob("gen_*.rs"),
                         stubs=(),
                         registry=opts.ql_test_output / ".generated_tests.list",
                         force=opts.force) as out:
        for cls in schema.classes.values():
            if cls.imported:
                continue
            if (qlgen.should_skip_qltest(cls, schema.classes) or
                    "rust_skip_doc_test" in cls.pragmas):
                continue
            code = _collect_code(schema, cls)
            if not code:
                continue
            test_name = inflection.underscore(cls.name)
            signature = cls.pragmas.get("rust_doc_test_signature", "() -> ()")
            # A falsy signature disables the function wrapper and leaves the
            # code at top level. Use None (not the falsy value itself) so
            # `TestCode.function` stays `Function | None`.
            fn = Function(f"test_{test_name}", signature) if signature else None
            if fn is not None:
                code = [indent + line for line in code]
            test_with_name = typing.cast("str | None", cls.pragmas.get("qltest_test_with"))
            test_with = schema.classes[test_with_name] if test_with_name else cls
            test = opts.ql_test_output / test_with.group / test_with.name / f"gen_{test_name}.rs"
            out.render(TestCode(code="\n".join(code), function=fn), test)