Skip to content

Commit 53506d1

Browse files
authored
First draft to export to GraphBuilder (#83)
* export to builder * doc * fix unit test * fix order * fix initializer * fix ut * fix opset
1 parent a54de21 commit 53506d1

File tree

8 files changed

+354
-5
lines changed

8 files changed

+354
-5
lines changed

CHANGELOGS.rst

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
Change Logs
22
===========
33

4-
0.2.0
4+
0.3.0
55
+++++
66

7+
* :pr:`79`: first draft to export to GraphBuilder
78
* :pr:`77`: supports ConcatOfShape and Slice with the light API
9+
10+
0.2.0
11+
+++++
12+
813
* :pr:`76`, :pr:`79`: add a mode to compare models without execution
914
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
1015
* :pr:`71`: adds tools to compare two onnx graphs

_unittests/ut_translate_api/test_translate.py

-1
Original file line numberDiff line numberDiff line change
@@ -221,5 +221,4 @@ def test_aionnxml(self):
221221

222222

223223
if __name__ == "__main__":
224-
TestTranslate().test_export_if()
225224
unittest.main(verbosity=2)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import unittest
2+
from textwrap import dedent
3+
import numpy as np
4+
from onnx import ModelProto, TensorProto
5+
from onnx.checker import check_model
6+
from onnx.defs import onnx_opset_version
7+
from onnx.reference import ReferenceEvaluator
8+
from onnx_array_api.ext_test_case import ExtTestCase
9+
from onnx_array_api.light_api import start
10+
from onnx_array_api.graph_api import GraphBuilder
11+
from onnx_array_api.translate_api import translate
12+
13+
14+
OPSET_API = min(19, onnx_opset_version() - 1)
15+
16+
17+
class TestTranslateBuilder(ExtTestCase):
18+
def setUp(self):
19+
self.maxDiff = None
20+
21+
def test_exp(self):
22+
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
23+
self.assertIsInstance(onx, ModelProto)
24+
self.assertIn("Exp", str(onx))
25+
ref = ReferenceEvaluator(onx)
26+
a = np.arange(10).astype(np.float32)
27+
got = ref.run(None, {"X": a})[0]
28+
self.assertEqualArray(np.exp(a), got)
29+
30+
code = translate(onx, api="builder")
31+
expected = dedent(
32+
"""
33+
def light_api(
34+
op: "GraphBuilder",
35+
X: "FLOAT[]",
36+
):
37+
Y = op.Exp(X)
38+
op.Identity(Y, outputs=["Y"])
39+
return Y
40+
41+
g = GraphBuilder({'': 19})
42+
g.make_tensor_input("X", TensorProto.FLOAT, ())
43+
light_api(g.op, "X")
44+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
45+
model = g.to_onnx()
46+
"""
47+
).strip("\n")
48+
self.assertEqual(expected, code.strip("\n"))
49+
50+
def light_api(
51+
op: "GraphBuilder",
52+
X: "FLOAT[]", # noqa: F722
53+
):
54+
Y = op.Exp(X)
55+
op.Identity(Y, outputs=["Y"])
56+
return Y
57+
58+
g2 = GraphBuilder({"": 19})
59+
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
60+
light_api(g2.op, "X")
61+
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
62+
onx2 = g2.to_onnx()
63+
64+
ref = ReferenceEvaluator(onx2)
65+
a = np.arange(10).astype(np.float32)
66+
got = ref.run(None, {"X": a})[0]
67+
self.assertEqualArray(np.exp(a), got)
68+
69+
def test_zdoc(self):
70+
onx = (
71+
start(opset=19)
72+
.vin("X")
73+
.reshape((-1, 1))
74+
.Transpose(perm=[1, 0])
75+
.rename("Y")
76+
.vout()
77+
.to_onnx()
78+
)
79+
code = translate(onx, api="builder")
80+
expected = dedent(
81+
"""
82+
def light_api(
83+
op: "GraphBuilder",
84+
X: "FLOAT[]",
85+
):
86+
r = np.array([-1, 1], dtype=np.int64)
87+
r0_0 = op.Reshape(X, r)
88+
Y = op.Transpose(r0_0, perm=[1, 0])
89+
op.Identity(Y, outputs=["Y"])
90+
return Y
91+
92+
g = GraphBuilder({'': 19})
93+
g.make_tensor_input("X", TensorProto.FLOAT, ())
94+
light_api(g.op, "X")
95+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
96+
model = g.to_onnx()
97+
"""
98+
).strip("\n")
99+
self.maxDiff = None
100+
self.assertEqual(expected, code.strip("\n"))
101+
102+
def light_api(
103+
op: "GraphBuilder",
104+
X: "FLOAT[]", # noqa: F722
105+
):
106+
r = np.array([-1, 1], dtype=np.int64)
107+
r0_0 = op.Reshape(X, r)
108+
Y = op.Transpose(r0_0, perm=[1, 0])
109+
op.Identity(Y, outputs=["Y"])
110+
return Y
111+
112+
g = GraphBuilder({"": 21})
113+
X = g.make_tensor_input("X", TensorProto.FLOAT, ())
114+
light_api(g.op, X)
115+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
116+
model = g.to_onnx()
117+
self.assertNotEmpty(model)
118+
check_model(model)
119+
120+
121+
if __name__ == "__main__":
122+
unittest.main(verbosity=2)

onnx_array_api/graph_api/graph_builder.py

+12
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@ def __getattr__(self, name):
119119
except AttributeError as e:
120120
raise AttributeError(f"Unable to access attribute {name!r}.") from e
121121

122+
def Initializer(
123+
self, init: Union[TensorProto, np.ndarray], name: Optional[str] = None
124+
) -> str:
125+
"""
126+
Creates an initializer.
127+
128+
:param init: value
129+
:param name: name if value is not a TensorProto
130+
:return: its name
131+
"""
132+
return self.builder.make_initializer(init, name=name, exists=True)
133+
122134
def make_node(
123135
self,
124136
op_type: str,

onnx_array_api/translate_api/__init__.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from onnx import ModelProto
22
from .translate import Translater
33
from .inner_emitter import InnerEmitter
4+
from .builder_emitter import BuilderEmitter
45

56

67
def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
@@ -14,7 +15,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
1415
default is `"light"` and this is handle by class
1516
:class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
1617
another value is `"onnx"` which is the inner API implemented
17-
in onnx package.
18+
in onnx package, `"builder"` follows the syntax for the
19+
class :class:`onnx_array_api.graph_api.GraphBuilder`
1820
:return: code
1921
2022
.. runpython::
@@ -35,7 +37,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
3537
code = translate(onx)
3638
print(code)
3739
38-
The inner API from onnx packahe is also available.
40+
The inner API from onnx package is also available.
3941
4042
.. runpython::
4143
:showcode:
@@ -54,11 +56,35 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
5456
)
5557
code = translate(onx, api="onnx")
5658
print(code)
59+
60+
The :class:`GraphBuilder
61+
<onnx_array_api.graph_api.GraphBuilder>` API returns this:
62+
63+
.. runpython::
64+
:showcode:
65+
66+
from onnx_array_api.light_api import start
67+
from onnx_array_api.translate_api import translate
68+
69+
onx = (
70+
start()
71+
.vin("X")
72+
.reshape((-1, 1))
73+
.Transpose(perm=[1, 0])
74+
.rename("Y")
75+
.vout()
76+
.to_onnx()
77+
)
78+
code = translate(onx, api="builder")
79+
print(code)
5780
"""
5881
if api == "light":
5982
tr = Translater(proto)
6083
return tr.export(single_line=single_line, as_str=True)
6184
if api == "onnx":
6285
tr = Translater(proto, emitter=InnerEmitter())
6386
return tr.export(as_str=True)
87+
if api == "builder":
88+
tr = Translater(proto, emitter=BuilderEmitter())
89+
return tr.export(as_str=True)
6490
raise ValueError(f"Unexpected value {api!r} for api.")

onnx_array_api/translate_api/base_emitter.py

+28
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class EventType(IntEnum):
2121
FUNCTION_OUTPUT = 12
2222
FUNCTION_ATTRIBUTES = 13
2323
TO_ONNX_FUNCTION = 14
24+
BEGIN_SIGNATURE = 15
25+
END_SIGNATURE = 16
26+
BEGIN_RETURN = 17
27+
END_RETURN = 18
2428

2529
@classmethod
2630
def to_str(cls, self) -> str:
@@ -84,6 +88,18 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
8488
if event == EventType.FUNCTION_ATTRIBUTES:
8589
return self._emit_function_attributes(**kwargs)
8690

91+
if event == EventType.BEGIN_SIGNATURE:
92+
return self._emit_begin_signature(**kwargs)
93+
94+
if event == EventType.END_SIGNATURE:
95+
return self._emit_end_signature(**kwargs)
96+
97+
if event == EventType.BEGIN_RETURN:
98+
return self._emit_begin_return(**kwargs)
99+
100+
if event == EventType.END_RETURN:
101+
return self._emit_end_return(**kwargs)
102+
87103
raise ValueError(f"Unexpected event {EventType.to_str(event)}.")
88104

89105
def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
@@ -222,3 +238,15 @@ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
222238
raise NotImplementedError(
223239
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
224240
)
241+
242+
def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
243+
return []
244+
245+
def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
246+
return []
247+
248+
def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
249+
return []
250+
251+
def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
252+
return []

0 commit comments

Comments
 (0)