light_emitter.py

from typing import Any, Dict, List

from ..annotations import ELEMENT_TYPE_NAME
from .base_emitter import BaseEmitter


class LightEmitter(BaseEmitter):
    """
    Converts parsed ONNX events into code written with the light API.
    """

    def join(self, rows: List[str], single_line: bool = False) -> str:
        "Joins the rows into a single chained expression."
        if single_line:
            return ".".join(rows)
        return "".join(["(\n    ", "\n    .".join(rows), "\n)"])

    def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
        opsets = kwargs.get("opsets", {})
        # the empty domain holds the main opset, the other domains stay in opsets
        opset = opsets.get("", None)
        if opset is not None:
            del opsets[""]
        args = []
        if opset:
            args.append(f"opset={opset}")
        if opsets:
            args.append(f"opsets={opsets}")
        return [f"start({', '.join(args)})"]

    def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
        return ["to_onnx()"]

    def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]:
        # functions, graph begin and graph end produce no code in the light API
        return []

    def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
        return []

    def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
        return []

    def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
        name = kwargs["name"]
        value = kwargs["value"]
        # numpy dtype names shadowing Python builtins need a trailing underscore
        repl = {"bool": "bool_", "object": "object_", "str": "str_"}
        sdtype = repl.get(str(value.dtype), str(value.dtype))
        return [
            f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
            f"rename({name!r})",
        ]

    def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
        name = kwargs["name"]
        elem_type = kwargs.get("elem_type", None)
        shape = kwargs.get("shape", None)
        if elem_type and shape:
            return [
                f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, "
                f"shape={shape!r})"
            ]
        if elem_type:
            return [
                f"vin({name!r}, elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})"
            ]
        return [f"vin({name!r})"]

    def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
        inst = []
        if "name" in kwargs:
            name = kwargs["name"]
            inst.append(f"bring({name!r})")
        elem_type = kwargs.get("elem_type", None)
        shape = kwargs.get("shape", None)
        if elem_type and shape:
            inst.append(
                f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]}, "
                f"shape={shape!r})"
            )
        elif elem_type:
            inst.append(f"vout(elem_type=TensorProto.{ELEMENT_TYPE_NAME[elem_type]})")
        else:
            inst.append("vout()")
        return inst

    def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
        op_type = kwargs["op_type"]
        inputs = kwargs["inputs"]
        outputs = kwargs["outputs"]
        # operators from a non-default domain are prefixed with the domain name
        if kwargs.get("domain", "") != "":
            domain = kwargs["domain"]
            op_type = f"{domain}.{op_type}"
        atts = kwargs.get("atts", {})
        args = []
        for k, v in atts.items():
            before, vatt = self.render_attribute_value(v)
            if before:
                raise NotImplementedError("Graph attributes are not supported yet.")
            args.append(f"{k}={vatt}")
        str_inputs = ", ".join([f"{i!r}" for i in inputs])
        inst = [f"bring({str_inputs})", f"{op_type}({', '.join(args)})"]
        if len(outputs) == 1:
            inst.append(f"rename({outputs[0]!r})")
        else:
            str_outputs = ", ".join([f"{o!r}" for o in outputs])
            inst.append(f"rename({str_outputs})")
        return inst
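
# Usage sketch (illustrative, not part of the original file): the _emit_* methods
# are normally driven by the base class while it walks an ONNX model, but calling
# them directly shows the strings they generate. This assumes ELEMENT_TYPE_NAME
# maps the ONNX element type 1 to "FLOAT".
#
#   emitter = LightEmitter()
#   emitter._emit_input(name="X", elem_type=1, shape=(1, 3))
#   # -> ["vin('X', elem_type=TensorProto.FLOAT, shape=(1, 3))"]
#   emitter._emit_node(op_type="Add", inputs=["X", "Y"], outputs=["Z"], atts={})
#   # -> ["bring('X', 'Y')", "Add()", "rename('Z')"]
#   emitter._emit_output(name="Z", elem_type=1)
#   # -> ["bring('Z')", "vout(elem_type=TensorProto.FLOAT)"]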