Skip to content

Commit 5edccab

Browse files
authored
Extends Array API to EagerOrt (#18)
* Extends Array API to EagerOrt * fix empty shape * fix shape * fix azure * refactoring * fix command line * CI * fix CI * fix CI
1 parent 37fe094 commit 5edccab

21 files changed

+444
-93
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ _cache/*
88
dist/*
99
build/*
1010
.eggs/*
11+
.hypothesis/*
1112
*egg-info/*
1213
_doc/auto_examples/*
1314
_doc/examples/_cache/*

_unittests/onnx-numpy-skips.txt

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# API failures
2+
array_api_tests/test_creation_functions.py::test_arange
3+
array_api_tests/test_creation_functions.py::test_asarray_scalars
4+
array_api_tests/test_creation_functions.py::test_asarray_arrays
5+
array_api_tests/test_creation_functions.py::test_empty
6+
array_api_tests/test_creation_functions.py::test_empty_like
7+
array_api_tests/test_creation_functions.py::test_eye
8+
array_api_tests/test_creation_functions.py::test_full
9+
array_api_tests/test_creation_functions.py::test_full_like
10+
array_api_tests/test_creation_functions.py::test_linspace
11+
array_api_tests/test_creation_functions.py::test_meshgrid
12+
array_api_tests/test_creation_functions.py::test_ones_like
13+
array_api_tests/test_creation_functions.py::test_zeros_like

_unittests/onnx-ort-skips.txt

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Not implementated by onnxruntime
2+
array_api_tests/test_creation_functions.py::test_arange
3+
array_api_tests/test_creation_functions.py::test_asarray_scalars
4+
array_api_tests/test_creation_functions.py::test_asarray_arrays
5+
array_api_tests/test_creation_functions.py::test_empty
6+
array_api_tests/test_creation_functions.py::test_empty_like
7+
array_api_tests/test_creation_functions.py::test_eye
8+
array_api_tests/test_creation_functions.py::test_full
9+
array_api_tests/test_creation_functions.py::test_full_like
10+
array_api_tests/test_creation_functions.py::test_linspace
11+
array_api_tests/test_creation_functions.py::test_meshgrid
12+
array_api_tests/test_creation_functions.py::test_ones
13+
array_api_tests/test_creation_functions.py::test_ones_like
14+
array_api_tests/test_creation_functions.py::test_zeros
15+
array_api_tests/test_creation_functions.py::test_zeros_like
+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import unittest
2+
from inspect import isfunction, ismethod
3+
import numpy as np
4+
from onnx_array_api.ext_test_case import ExtTestCase
5+
from onnx_array_api.array_api import onnx_numpy as xpn
6+
from onnx_array_api.array_api import onnx_ort as xpo
7+
8+
# from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
9+
# from onnx_array_api.ort.ort_tensors import EagerOrtTensor
10+
11+
12+
class TestArraysApis(ExtTestCase):
13+
def test_zeros_numpy_1(self):
14+
c = xpn.zeros(1)
15+
d = c.numpy()
16+
self.assertEqualArray(np.array([0], dtype=np.float32), d)
17+
18+
def test_zeros_ort_1(self):
19+
c = xpo.zeros(1)
20+
d = c.numpy()
21+
self.assertEqualArray(np.array([0], dtype=np.float32), d)
22+
23+
def test_ffinfo(self):
24+
dt = np.float32
25+
fi1 = np.finfo(dt)
26+
fi2 = xpn.finfo(dt)
27+
fi3 = xpo.finfo(dt)
28+
dt1 = fi1.dtype
29+
dt2 = fi2.dtype
30+
dt3 = fi3.dtype
31+
self.assertEqual(dt2, dt3)
32+
self.assertNotEqual(dt1.__class__, dt2.__class__)
33+
mi1 = fi1.min
34+
mi2 = fi2.min
35+
self.assertEqual(mi1, mi2)
36+
mi1 = fi1.smallest_normal
37+
mi2 = fi2.smallest_normal
38+
self.assertEqual(mi1, mi2)
39+
for n in dir(fi1):
40+
if n.startswith("__"):
41+
continue
42+
if n in {"machar"}:
43+
continue
44+
v1 = getattr(fi1, n)
45+
with self.subTest(att=n):
46+
v2 = getattr(fi2, n)
47+
v3 = getattr(fi3, n)
48+
if isfunction(v1) or ismethod(v1):
49+
try:
50+
v1 = v1()
51+
except TypeError:
52+
continue
53+
v2 = v2()
54+
v3 = v3()
55+
if v1 != v2:
56+
raise AssertionError(
57+
f"12: info disagree on name {n!r}: {v1} != {v2}, "
58+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
59+
f"ismethod={ismethod(v1)}."
60+
)
61+
if v2 != v3:
62+
raise AssertionError(
63+
f"23: info disagree on name {n!r}: {v2} != {v3}, "
64+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
65+
f"ismethod={ismethod(v1)}."
66+
)
67+
68+
def test_iiinfo(self):
69+
dt = np.int64
70+
fi1 = np.iinfo(dt)
71+
fi2 = xpn.iinfo(dt)
72+
fi3 = xpo.iinfo(dt)
73+
dt1 = fi1.dtype
74+
dt2 = fi2.dtype
75+
dt3 = fi3.dtype
76+
self.assertEqual(dt2, dt3)
77+
self.assertNotEqual(dt1.__class__, dt2.__class__)
78+
mi1 = fi1.min
79+
mi2 = fi2.min
80+
self.assertEqual(mi1, mi2)
81+
for n in dir(fi1):
82+
if n.startswith("__"):
83+
continue
84+
if n in {"machar"}:
85+
continue
86+
v1 = getattr(fi1, n)
87+
with self.subTest(att=n):
88+
v2 = getattr(fi2, n)
89+
v3 = getattr(fi3, n)
90+
if isfunction(v1) or ismethod(v1):
91+
try:
92+
v1 = v1()
93+
except TypeError:
94+
continue
95+
v2 = v2()
96+
v3 = v3()
97+
if v1 != v2:
98+
raise AssertionError(
99+
f"12: info disagree on name {n!r}: {v1} != {v2}, "
100+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
101+
f"ismethod={ismethod(v1)}."
102+
)
103+
if v2 != v3:
104+
raise AssertionError(
105+
f"23: info disagree on name {n!r}: {v2} != {v3}, "
106+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
107+
f"ismethod={ismethod(v1)}."
108+
)
109+
110+
111+
if __name__ == "__main__":
112+
unittest.main(verbosity=2)

_unittests/ut_array_api/test_onnx_numpy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import numpy as np
33
from onnx_array_api.ext_test_case import ExtTestCase
44
from onnx_array_api.array_api import onnx_numpy as xp
5-
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
5+
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor as EagerTensor
66

77

88
class TestOnnxNumpy(ExtTestCase):
99
def test_abs(self):
10-
c = EagerNumpyTensor(np.array([4, 5], dtype=np.int64))
10+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
1111
mat = xp.zeros(c, dtype=xp.int64)
1212
matnp = mat.numpy()
1313
self.assertEqual(matnp.shape, (4, 5))
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import unittest
2+
import numpy as np
3+
from onnx_array_api.ext_test_case import ExtTestCase
4+
from onnx_array_api.array_api import onnx_ort as xp
5+
from onnx_array_api.ort.ort_tensors import EagerOrtTensor as EagerTensor
6+
7+
8+
class TestOnnxOrt(ExtTestCase):
9+
def test_abs(self):
10+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
11+
mat = xp.zeros(c, dtype=xp.int64)
12+
matnp = mat.numpy()
13+
self.assertEqual(matnp.shape, (4, 5))
14+
self.assertNotEmpty(matnp[0, 0])
15+
a = xp.absolute(mat)
16+
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
17+
18+
19+
if __name__ == "__main__":
20+
unittest.main(verbosity=2)

_unittests/ut_ort/test_ort_tensor.py

+45-3
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
import unittest
22
from contextlib import redirect_stdout
33
from io import StringIO
4-
54
import numpy as np
65
from onnx.defs import onnx_opset_version
76
from onnx.reference import ReferenceEvaluator
87
from onnxruntime import InferenceSession
9-
108
from onnx_array_api.ext_test_case import ExtTestCase
119
from onnx_array_api.npx import eager_onnx, jit_onnx
1210
from onnx_array_api.npx.npx_functions import absolute as absolute_inline
1311
from onnx_array_api.npx.npx_functions import cdist as cdist_inline
1412
from onnx_array_api.npx.npx_functions_test import absolute
15-
from onnx_array_api.npx.npx_types import Float32, Float64
13+
from onnx_array_api.npx.npx_functions import copy as copy_inline
14+
from onnx_array_api.npx.npx_types import Float32, Float64, DType
1615
from onnx_array_api.npx.npx_var import Input
1716
from onnx_array_api.ort.ort_tensors import EagerOrtTensor, JitOrtTensor, OrtTensor
1817

@@ -193,6 +192,49 @@ def impl(xa, xb):
193192
if len(pieces) > 2:
194193
raise AssertionError(f"Function is not using argument:\n{onx}")
195194

195+
def test_astype(self):
196+
f = absolute_inline(copy_inline(Input("A")).astype(np.float32))
197+
onx = f.to_onnx(constraints={"A": Float64[None]})
198+
x = np.array([[-5, 6]], dtype=np.float64)
199+
z = np.abs(x.astype(np.float32))
200+
ref = InferenceSession(
201+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
202+
)
203+
got = ref.run(None, {"A": x})
204+
self.assertEqualArray(z, got[0])
205+
206+
def test_astype0(self):
207+
f = absolute_inline(copy_inline(Input("A")).astype(np.float32))
208+
onx = f.to_onnx(constraints={"A": Float64[None]})
209+
x = np.array(-5, dtype=np.float64)
210+
z = np.abs(x.astype(np.float32))
211+
ref = InferenceSession(
212+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
213+
)
214+
got = ref.run(None, {"A": x})
215+
self.assertEqualArray(z, got[0])
216+
217+
def test_eager_ort_cast(self):
218+
def impl(A):
219+
return A.astype(DType("FLOAT"))
220+
221+
e = eager_onnx(impl)
222+
self.assertEqual(len(e.versions), 0)
223+
224+
# Float64
225+
x = np.array([0, 1, -2], dtype=np.float64)
226+
z = x.astype(np.float32)
227+
res = e(x)
228+
self.assertEqualArray(z, res)
229+
self.assertEqual(res.dtype, np.float32)
230+
231+
# again
232+
x = np.array(1, dtype=np.float64)
233+
z = x.astype(np.float32)
234+
res = e(x)
235+
self.assertEqualArray(z, res)
236+
self.assertEqual(res.dtype, np.float32)
237+
196238

197239
if __name__ == "__main__":
198240
# TestNpx().test_eager_numpy()

azure-pipelines.yml

+9-10
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ jobs:
110110
displayName: 'Install tools'
111111
- script: pip install -r requirements.txt
112112
displayName: 'Install Requirements'
113+
- script: pip install onnxruntime
114+
displayName: 'Install onnxruntime'
113115
- script: python setup.py install
114116
displayName: 'Install onnx_array_api'
115117
- script: |
@@ -129,8 +131,13 @@ jobs:
129131
- script: |
130132
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
131133
cd array-api-tests
132-
python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros
133-
displayName: "test_creation_functions.py::test_zeros"
134+
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v
135+
displayName: "numpy test_creation_functions.py"
136+
- script: |
137+
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
138+
cd array-api-tests
139+
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt -v
140+
displayName: "ort test_creation_functions.py"
134141
#- script: |
135142
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
136143
# cd array-api-tests
@@ -246,16 +253,8 @@ jobs:
246253
displayName: 'export'
247254
- script: gcc --version
248255
displayName: 'gcc version'
249-
- script: brew install llvm
250-
displayName: 'install llvm'
251-
- script: brew install libomp
252-
displayName: 'Install omp'
253-
- script: brew install p7zip
254-
displayName: 'Install p7zip'
255256
- script: python -m pip install --upgrade pip setuptools wheel
256257
displayName: 'Install tools'
257-
- script: brew install pybind11
258-
displayName: 'Install pybind11'
259258
- script: pip install -r requirements.txt
260259
displayName: 'Install Requirements'
261260
- script: pip install -r requirements-dev.txt

onnx_array_api/_helpers.py

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import numpy as np
2+
from typing import Any
3+
from onnx import helper, TensorProto
4+
5+
6+
def np_dtype_to_tensor_dtype(dtype: Any):
7+
"""
8+
Improves :func:`onnx.helper.np_dtype_to_tensor_dtype`.
9+
"""
10+
try:
11+
dt = helper.np_dtype_to_tensor_dtype(dtype)
12+
except KeyError:
13+
if dtype == np.float32:
14+
dt = TensorProto.FLOAT
15+
elif dtype == np.float64:
16+
dt = TensorProto.DOUBLE
17+
elif dtype == np.int64:
18+
dt = TensorProto.INT64
19+
elif dtype == np.int32:
20+
dt = TensorProto.INT32
21+
elif dtype == np.int16:
22+
dt = TensorProto.INT16
23+
elif dtype == np.int8:
24+
dt = TensorProto.INT8
25+
elif dtype == np.uint64:
26+
dt = TensorProto.UINT64
27+
elif dtype == np.uint32:
28+
dt = TensorProto.UINT32
29+
elif dtype == np.uint16:
30+
dt = TensorProto.UINT16
31+
elif dtype == np.uint8:
32+
dt = TensorProto.UINT8
33+
elif dtype == np.float16:
34+
dt = TensorProto.FLOAT16
35+
elif dtype in (bool, np.bool_):
36+
dt = TensorProto.BOOL
37+
elif dtype in (str, np.str_):
38+
dt = TensorProto.STRING
39+
elif dtype is int:
40+
dt = TensorProto.INT64
41+
elif dtype is float:
42+
dt = TensorProto.FLOAT64
43+
else:
44+
raise KeyError(f"Unable to guess type for dtype={dtype}.")
45+
return dt

onnx_array_api/array_api/__init__.py

+36
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,42 @@
1+
import numpy as np
12
from onnx import TensorProto
3+
from .._helpers import np_dtype_to_tensor_dtype
24
from ..npx.npx_types import DType
35

46

7+
def _finfo(dtype):
8+
"""
9+
Similar to :class:`numpy.finfo`.
10+
"""
11+
dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
12+
res = np.finfo(dt)
13+
d = res.__dict__.copy()
14+
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
15+
nres = type("finfo", (res.__class__,), d)
16+
setattr(nres, "smallest_normal", res.smallest_normal)
17+
setattr(nres, "tiny", res.tiny)
18+
return nres
19+
20+
21+
def _iinfo(dtype):
22+
"""
23+
Similar to :class:`numpy.finfo`.
24+
"""
25+
dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
26+
res = np.iinfo(dt)
27+
d = res.__dict__.copy()
28+
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
29+
nres = type("finfo", (res.__class__,), d)
30+
setattr(nres, "min", res.min)
31+
setattr(nres, "max", res.max)
32+
return nres
33+
34+
535
def _finalize_array_api(module):
36+
"""
37+
Adds common attributes to Array API defined in this modules
38+
such as types.
39+
"""
640
module.float16 = DType(TensorProto.FLOAT16)
741
module.float32 = DType(TensorProto.FLOAT)
842
module.float64 = DType(TensorProto.DOUBLE)
@@ -17,3 +51,5 @@ def _finalize_array_api(module):
1751
module.bfloat16 = DType(TensorProto.BFLOAT16)
1852
setattr(module, "bool", DType(TensorProto.BOOL))
1953
setattr(module, "str", DType(TensorProto.STRING))
54+
setattr(module, "finfo", _finfo)
55+
setattr(module, "iinfo", _iinfo)

0 commit comments

Comments
 (0)