Skip to content

Commit 4f5b006

Browse files
committed
add ort_optimized_model
1 parent 464ab7d commit 4f5b006

File tree

9 files changed

+110
-16
lines changed

9 files changed

+110
-16
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ build/*
1010
*egg-info/*
1111
_doc/auto_examples/*
1212
_doc/examples/plot_*.png
13+
_doc/_static/require.js
14+
_doc/_static/viz.js

_doc/examples/plot_onnxruntime.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,21 @@ def myloss(x, y):
5656
yort = OrtTensor.from_array(y)
5757

5858

59-
def loop():
60-
for _ in range(1000):
59+
def loop_ort(n):
60+
for _ in range(n):
6161
ort_myloss(xort, yort)
6262

6363

64+
def loop_numpy(n):
65+
for _ in range(n):
66+
myloss(x, y)
67+
68+
69+
def loop(n=1000):
70+
loop_numpy(n)
71+
loop_ort(n)
72+
73+
6474
ps = profile(loop)[0]
6575
root, nodes = profile2graph(ps, clean_text=lambda x: x.split("/")[-1])
6676
text = root.to_text()
+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import unittest
2+
3+
import numpy as np
4+
from onnx.defs import onnx_opset_version
5+
from onnx_array_api.npx import absolute, jit_onnx
6+
from onnx_array_api.ext_test_case import ExtTestCase
7+
from onnx_array_api.ort.ort_optimizers import ort_optimized_model
8+
9+
10+
DEFAULT_OPSET = onnx_opset_version()
11+
12+
13+
class TestOrtOptimizer(ExtTestCase):
    """Checks :func:`ort_optimized_model` on a small jitted loss function."""

    def test_ort_optimizers(self):
        # The three functions below are traced by jit_onnx to build the
        # ONNX graph handed over to the optimizer.
        def l1_loss(x, y):
            return absolute(x - y).sum()

        def l2_loss(x, y):
            return ((x - y) ** 2).sum()

        def myloss(x, y):
            return l1_loss(x[:, 0], y[:, 0]) + l2_loss(x[:, 1], y[:, 1])

        x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
        y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)

        # The jitted function is called once before the ONNX model
        # is retrieved with get_onnx.
        jitted_myloss = jit_onnx(myloss)
        jitted_myloss(x, y)
        onx = jitted_myloss.get_onnx()

        # An unknown optimization level must be rejected.
        self.assertRaise(lambda: ort_optimized_model(onx, "NO"), ValueError)

        # Default level: look for the expected markers in the text
        # rendering of the optimized model.
        optimized = ort_optimized_model(onx)
        text = str(optimized)
        self.assertIn('op_type: "Squeeze"', text)
        self.assertIn("initializer {", text)
33+
34+
35+
if __name__ == "__main__":
36+
unittest.main(verbosity=2)

_unittests/ut_ort/test_ort_tensor.py

-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from io import StringIO
44

55
import numpy as np
6-
import scipy
76
from onnx.defs import onnx_opset_version
87
from onnx.reference import ReferenceEvaluator
98
from onnxruntime import InferenceSession
@@ -21,7 +20,6 @@
2120

2221

2322
class TestOrtTensor(ExtTestCase):
24-
@unittest.skipIf(InferenceSession is None, reason="onnxruntime is needed.")
2523
def test_eager_numpy_type_ort(self):
2624
def impl(A):
2725
self.assertIsInstance(A, EagerOrtTensor)
@@ -47,7 +45,6 @@ def impl(A):
4745
self.assertEqualArray(z, res.numpy())
4846
self.assertEqual(res.numpy().dtype, np.float64)
4947

50-
@unittest.skipIf(InferenceSession is None, reason="onnxruntime is needed.")
5148
def test_eager_numpy_type_ort_op(self):
5249
def impl(A):
5350
self.assertIsInstance(A, EagerOrtTensor)
@@ -71,7 +68,6 @@ def impl(A):
7168
self.assertEqualArray(z, res.numpy())
7269
self.assertEqual(res.numpy().dtype, np.float64)
7370

74-
@unittest.skipIf(InferenceSession is None, reason="onnxruntime is needed.")
7571
def test_eager_ort(self):
7672
def impl(A):
7773
print("A")
@@ -145,8 +141,6 @@ def impl(A):
145141
self.assertEqual(tuple(res.shape()), z.shape)
146142
self.assertStartsWith("A\nB\nC\n", text)
147143

148-
@unittest.skipIf(InferenceSession is None, reason="onnxruntime is not available")
149-
@unittest.skipIf(scipy is None, reason="scipy is not installed.")
150144
def test_cdist_com_microsoft(self):
151145
from scipy.spatial.distance import cdist as scipy_cdist
152146

onnx_array_api/cache.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from pathlib import Path
2+
3+
4+
def get_cache_file(filename: str, remove: bool = False):
    """
    Returns a name in the cache folder `~/.onnx-array-api`.
    The cache folder is created if it does not exist yet.

    :param filename: filename
    :param remove: if True, removes the file if it already exists,
        otherwise an existing file is left untouched
    :return: full filename (a :class:`pathlib.Path`)
    """
    folder = Path.home() / ".onnx-array-api"
    # exist_ok=True avoids a failure if the folder already exists
    # or is created between a check and the creation.
    folder.mkdir(exist_ok=True)
    name = folder / filename
    # The file is only deleted when the caller explicitly asks for it.
    if remove and name.exists():
        name.unlink()
    return name

onnx_array_api/ort/ort_optimizers.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from typing import Union
2+
from onnx import ModelProto, load
3+
from onnxruntime import InferenceSession, SessionOptions
4+
from onnxruntime.capi._pybind_state import GraphOptimizationLevel
5+
from ..cache import get_cache_file
6+
7+
8+
def ort_optimized_model(
    onx: Union[str, ModelProto], level: str = "ORT_ENABLE_ALL"
) -> Union[str, ModelProto]:
    """
    Returns the optimized model used by onnxruntime before
    running the inference.

    :param onx: ModelProto or the filename of a model
    :param level: optimization level, `'ORT_ENABLE_BASIC'`,
        `'ORT_ENABLE_EXTENDED'`, `'ORT_ENABLE_ALL'`
    :return: optimized model, a ModelProto if *onx* is a ModelProto,
        the name of the file holding the optimized model
        if *onx* is a filename
    :raises ValueError: if *level* is not an optimization level
    :raises RuntimeError: if onnxruntime did not write the optimized model
    """
    glevel = getattr(GraphOptimizationLevel, level, None)
    # getattr also succeeds for any other attribute of the class
    # (methods, dunder names): only a real member is a valid level.
    if not isinstance(glevel, GraphOptimizationLevel):
        raise ValueError(
            f"Unrecognized level {level!r} among {dir(GraphOptimizationLevel)}."
        )

    cache = get_cache_file("ort_optimized_model.onnx", remove=True)
    so = SessionOptions()
    so.graph_optimization_level = glevel
    so.optimized_model_filepath = str(cache)
    # The session is only created for its side effect: onnxruntime saves
    # the optimized graph into `optimized_model_filepath`.
    InferenceSession(onx if isinstance(onx, str) else onx.SerializeToString(), so)
    if not cache.exists():
        raise RuntimeError(f"The optimized model {str(cache)!r} not found.")
    if isinstance(onx, str):
        # The file is kept: its name is the returned value.
        return str(cache)
    try:
        return load(str(cache))
    finally:
        # The temporary file is removed even if the loading fails.
        cache.unlink()

onnx_array_api/ort/ort_tensors.py

-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
"""
2-
This is an example of a backend for classes :class:`JitOnnx` and :class:`JitEager`
3-
using onnxruntime as a runtime. It is provided as an example.
4-
"""
51
from typing import Any, Callable, List, Optional, Tuple, Union
62

73
import numpy as np

onnx_array_api/plotting/dot_plot.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def dot_label(text):
254254
)
255255
static_inputs.append(out)
256256

257-
if node.name.strip() == "":
257+
if node.name.strip() == "" or node.name in fill_names:
258258
name = node.op_type
259259
iname = 1
260260
while name in fill_names:
@@ -329,7 +329,7 @@ def dot_label(text):
329329
exp.append(
330330
f" {dot_name(prefix)}{dot_name(node.name)} "
331331
f'[shape=box style="filled,rounded" color=orange '
332-
f'label="{node.op_type}\\n({dot_name(node.name)}){satts}" '
332+
f'label="{node.op_type}{satts}" '
333333
f"fontsize={fontsize}];"
334334
)
335335

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
requirements = ["numpy", "scipy", "onnx"]
2525

2626
try:
27-
with open(os.path.join(here, "README.rst"), "r", encoding='utf-8') as f:
28-
long_description = "onnx-array-api:" + f.read().split('onnx-array-api:')[1]
27+
with open(os.path.join(here, "README.rst"), "r", encoding="utf-8") as f:
28+
long_description = "onnx-array-api:" + f.read().split("onnx-array-api:")[1]
2929
except FileNotFoundError:
3030
long_description = ""
3131

0 commit comments

Comments
 (0)