Skip to content

Commit 1be44a7

Browse files
sdpythonxadupre
andauthored
Enables linspace (#30)
* Enables test_asarray_scalars * Add support for array api linspace * lint * improves consistency for linspace * fix linspace * disable asarrays_arrays * fix strategies * aapi --------- Co-authored-by: Xavier Dupre <[email protected]>
1 parent 9b0b5d6 commit 1be44a7

12 files changed

+409
-19
lines changed

_unittests/onnx-numpy-skips.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# API failures
22
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
33
# uses __setitem__
4-
array_api_tests/test_creation_functions.py::test_asarray_arrays
4+
# array_api_tests/test_creation_functions.py::test_asarray_arrays
55
array_api_tests/test_creation_functions.py::test_empty
66
array_api_tests/test_creation_functions.py::test_empty_like
7-
array_api_tests/test_creation_functions.py::test_linspace
7+
# array_api_tests/test_creation_functions.py::test_linspace
88
array_api_tests/test_creation_functions.py::test_meshgrid

_unittests/test_array_api.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_eye || exit 1
2+
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_full_like || exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1

_unittests/ut_array_api/test_hypothesis_array_api.py

+51-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from operator import mul
77
from hypothesis import given
8-
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
99
from onnx_array_api.array_api import onnx_numpy as onxp
1010
from hypothesis import strategies
1111
from hypothesis.extra import array_api
@@ -207,11 +207,58 @@ def fctonx(n_rows, n_cols, kw):
207207
fctonx()
208208
self.assertEqual(len(args_onxp), len(args_np))
209209

210+
@ignore_warnings(UserWarning)
211+
def test_square_shared_types(self):
212+
dtypes = self.onxps.scalar_dtypes()
213+
shared_dtypes = strategies.shared(dtypes, key="dtype")
214+
215+
def shapes(**kw):
216+
kw.setdefault("min_dims", 0)
217+
kw.setdefault("min_side", 0)
218+
return self.onxps.array_shapes(**kw).filter(
219+
lambda shape: prod(i for i in shape if i) < self.MAX_ARRAY_SIZE
220+
)
221+
222+
@strategies.composite
223+
def kwargs(draw, **kw):
224+
result = {}
225+
for k, strat in kw.items():
226+
if draw(strategies.booleans()):
227+
result[k] = draw(strat)
228+
return result
229+
230+
@strategies.composite
231+
def full_like_fill_values(draw):
232+
kw = draw(
233+
strategies.shared(
234+
kwargs(dtype=strategies.none() | self.onxps.scalar_dtypes()),
235+
key="full_like_kw",
236+
)
237+
)
238+
dtype = kw.get("dtype", None) or draw(shared_dtypes)
239+
return draw(self.onxps.from_dtype(dtype))
240+
241+
args = []
242+
sh = shapes()
243+
xa = self.onxps.arrays(dtype=shared_dtypes, shape=sh)
244+
fu = full_like_fill_values()
245+
kws = strategies.shared(
246+
kwargs(dtype=strategies.none() | self.onxps.scalar_dtypes()),
247+
key="full_like_kw",
248+
)
249+
250+
@given(x=xa, fill_value=fu, kw=kws)
251+
def fctonp(x, fill_value, kw):
252+
args.append((x, fill_value, kw))
253+
254+
fctonp()
255+
self.assertEqual(len(args), 100)
256+
210257

211258
if __name__ == "__main__":
212-
# cl = TestHypothesisArraysApis()
213-
# cl.setUpClass()
214-
# cl.test_scalar_strategies()
259+
cl = TestHypothesisArraysApis()
260+
cl.setUpClass()
261+
cl.test_square_shared_types()
215262
# import logging
216263

217264
# logging.basicConfig(level=logging.DEBUG)

_unittests/ut_array_api/test_onnx_numpy.py

+94-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
import unittest
33
import numpy as np
44
from onnx import TensorProto
5-
from onnx_array_api.ext_test_case import ExtTestCase
5+
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
66
from onnx_array_api.array_api import onnx_numpy as xp
77
from onnx_array_api.npx.npx_types import DType
88
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor as EagerTensor
9+
from onnx_array_api.npx.npx_functions import linspace as linspace_inline
10+
from onnx_array_api.npx.npx_types import Float64, Int64
11+
from onnx_array_api.npx.npx_var import Input
12+
from onnx_array_api.reference import ExtendedReferenceEvaluator
913

1014

1115
class TestOnnxNumpy(ExtTestCase):
@@ -22,6 +26,7 @@ def test_zeros(self):
2226
a = xp.absolute(mat)
2327
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
2428

29+
@ignore_warnings(DeprecationWarning)
2530
def test_arange_default(self):
2631
a = EagerTensor(np.array([0], dtype=np.int64))
2732
b = EagerTensor(np.array([2], dtype=np.int64))
@@ -30,6 +35,7 @@ def test_arange_default(self):
3035
self.assertEqual(matnp.shape, (2,))
3136
self.assertEqualArray(matnp, np.arange(0, 2).astype(np.int64))
3237

38+
@ignore_warnings(DeprecationWarning)
3339
def test_arange_step(self):
3440
a = EagerTensor(np.array([4], dtype=np.int64))
3541
s = EagerTensor(np.array([2], dtype=np.int64))
@@ -78,6 +84,7 @@ def test_full_bool(self):
7884
self.assertNotEmpty(matnp[0, 0])
7985
self.assertEqualArray(matnp, np.full((4, 5), False))
8086

87+
@ignore_warnings(DeprecationWarning)
8188
def test_arange_int00a(self):
8289
a = EagerTensor(np.array([0], dtype=np.int64))
8390
b = EagerTensor(np.array([0], dtype=np.int64))
@@ -89,6 +96,7 @@ def test_arange_int00a(self):
8996
expected = expected.astype(np.int64)
9097
self.assertEqualArray(matnp, expected)
9198

99+
@ignore_warnings(DeprecationWarning)
92100
def test_arange_int00(self):
93101
mat = xp.arange(0, 0)
94102
matnp = mat.numpy()
@@ -160,10 +168,94 @@ def test_eye_k(self):
160168
got = xp.eye(nr, k=1)
161169
self.assertEqualArray(expected, got.numpy())
162170

171+
def test_linspace_int(self):
172+
a = EagerTensor(np.array([0], dtype=np.int64))
173+
b = EagerTensor(np.array([6], dtype=np.int64))
174+
c = EagerTensor(np.array(3, dtype=np.int64))
175+
mat = xp.linspace(a, b, c)
176+
matnp = mat.numpy()
177+
expected = np.linspace(a.numpy(), b.numpy(), c.numpy()).astype(np.int64)
178+
self.assertEqualArray(expected, matnp)
179+
180+
def test_linspace_int5(self):
181+
a = EagerTensor(np.array([0], dtype=np.int64))
182+
b = EagerTensor(np.array([5], dtype=np.int64))
183+
c = EagerTensor(np.array(3, dtype=np.int64))
184+
mat = xp.linspace(a, b, c)
185+
matnp = mat.numpy()
186+
expected = np.linspace(a.numpy(), b.numpy(), c.numpy()).astype(np.int64)
187+
self.assertEqualArray(expected, matnp)
188+
189+
def test_linspace_float(self):
190+
a = EagerTensor(np.array([0.5], dtype=np.float64))
191+
b = EagerTensor(np.array([5.5], dtype=np.float64))
192+
c = EagerTensor(np.array(2, dtype=np.int64))
193+
mat = xp.linspace(a, b, c)
194+
matnp = mat.numpy()
195+
expected = np.linspace(a.numpy(), b.numpy(), c.numpy())
196+
self.assertEqualArray(expected, matnp)
197+
198+
def test_linspace_float_noendpoint(self):
199+
a = EagerTensor(np.array([0.5], dtype=np.float64))
200+
b = EagerTensor(np.array([5.5], dtype=np.float64))
201+
c = EagerTensor(np.array(2, dtype=np.int64))
202+
mat = xp.linspace(a, b, c, endpoint=0)
203+
matnp = mat.numpy()
204+
expected = np.linspace(a.numpy(), b.numpy(), c.numpy(), endpoint=0)
205+
self.assertEqualArray(expected, matnp)
206+
207+
@ignore_warnings((RuntimeWarning, DeprecationWarning)) # division by zero
208+
def test_linspace_zero(self):
209+
expected = np.linspace(0.0, 0.0, 0, endpoint=False)
210+
mat = xp.linspace(0.0, 0.0, 0, endpoint=False)
211+
matnp = mat.numpy()
212+
self.assertEqualArray(expected, matnp)
213+
214+
@ignore_warnings((RuntimeWarning, DeprecationWarning)) # division by zero
215+
def test_linspace_zero_one(self):
216+
expected = np.linspace(0.0, 0.0, 1, endpoint=True)
217+
218+
f = linspace_inline(Input("start"), Input("stop"), Input("num"))
219+
onx = f.to_onnx(
220+
constraints={
221+
"start": Float64[None],
222+
"stop": Float64[None],
223+
"num": Int64[None],
224+
(0, False): Float64[None],
225+
}
226+
)
227+
ref = ExtendedReferenceEvaluator(onx)
228+
got = ref.run(
229+
None,
230+
{
231+
"start": np.array(0, dtype=np.float64),
232+
"stop": np.array(0, dtype=np.float64),
233+
"num": np.array(1, dtype=np.int64),
234+
},
235+
)
236+
self.assertEqualArray(expected, got[0])
237+
238+
mat = xp.linspace(0.0, 0.0, 1, endpoint=True)
239+
matnp = mat.numpy()
240+
241+
self.assertEqualArray(expected, matnp)
242+
243+
def test_slice_minus_one(self):
244+
g = EagerTensor(np.array([0.0]))
245+
expected = g.numpy()[:-1]
246+
got = g[:-1]
247+
self.assertEqualArray(expected, got.numpy())
248+
249+
def test_linspace_bug1(self):
250+
expected = np.linspace(16777217.0, 0.0, 1)
251+
mat = xp.linspace(16777217.0, 0.0, 1)
252+
matnp = mat.numpy()
253+
self.assertEqualArray(expected, matnp)
254+
163255

164256
if __name__ == "__main__":
165257
# import logging
166258

167259
# logging.basicConfig(level=logging.DEBUG)
168-
TestOnnxNumpy().test_eye()
260+
TestOnnxNumpy().test_linspace_float_noendpoint()
169261
unittest.main(verbosity=2)

_unittests/ut_npx/test_npx.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from onnx.reference import ReferenceEvaluator
2121
from onnx.shape_inference import infer_shapes
2222

23-
from onnx_array_api.ext_test_case import ExtTestCase
23+
from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings
24+
from onnx_array_api.reference import ExtendedReferenceEvaluator
2425
from onnx_array_api.npx import ElemType, eager_onnx, jit_onnx
2526
from onnx_array_api.npx.npx_core_api import (
2627
cst,
@@ -60,6 +61,7 @@
6061
from onnx_array_api.npx.npx_functions import hstack as hstack_inline
6162
from onnx_array_api.npx.npx_functions import identity as identity_inline
6263
from onnx_array_api.npx.npx_functions import isnan as isnan_inline
64+
from onnx_array_api.npx.npx_functions import linspace as linspace_inline
6365
from onnx_array_api.npx.npx_functions import log as log_inline
6466
from onnx_array_api.npx.npx_functions import log1p as log1p_inline
6567
from onnx_array_api.npx.npx_functions import matmul as matmul_inline
@@ -1654,6 +1656,7 @@ def test_squeeze(self):
16541656
got = ref.run(None, {"A": x})
16551657
self.assertEqualArray(z, got[0])
16561658

1659+
@ignore_warnings(DeprecationWarning)
16571660
def test_squeeze_noaxis(self):
16581661
f = squeeze_inline(copy_inline(Input("A")))
16591662
self.assertIsInstance(f, Var)
@@ -2574,6 +2577,51 @@ def test_get_item_i8(self):
25742577
i = a[0]
25752578
self.assertEqualArray(i.numpy(), a.numpy()[0])
25762579

2580+
@ignore_warnings(RuntimeWarning)
2581+
def test_linspace_big_inline(self):
2582+
# linspace(5, 0, 1) --> [5] even with endpoint=True
2583+
f = linspace_inline(Input("A"), Input("B"), Input("C"))
2584+
self.assertIsInstance(f, Var)
2585+
onx = f.to_onnx(
2586+
constraints={
2587+
0: Int64[None],
2588+
1: Int64[None],
2589+
2: Int64[None],
2590+
(0, False): Int64[None],
2591+
}
2592+
)
2593+
2594+
start = np.array(16777217.0, dtype=np.float64)
2595+
stop = np.array(0.0, dtype=np.float64)
2596+
num = np.array(1, dtype=np.int64)
2597+
y = np.linspace(start, stop, num)
2598+
ref = ExtendedReferenceEvaluator(onx)
2599+
got = ref.run(None, {"A": start, "B": stop, "C": num})
2600+
self.assertEqualArray(y, got[0])
2601+
2602+
@ignore_warnings(RuntimeWarning)
2603+
def test_linspace_inline(self):
2604+
# linspace(0, 5, 1)
2605+
f = linspace_inline(Input("A"), Input("B"), Input("C"))
2606+
self.assertIsInstance(f, Var)
2607+
onx = f.to_onnx(
2608+
constraints={
2609+
0: Int64[None],
2610+
1: Int64[None],
2611+
2: Int64[None],
2612+
(0, False): Int64[None],
2613+
}
2614+
)
2615+
2616+
start = np.array(0, dtype=np.float64)
2617+
stop = np.array(5, dtype=np.float64)
2618+
num = np.array(1, dtype=np.int64)
2619+
y = np.linspace(start, stop, num)
2620+
ref = ExtendedReferenceEvaluator(onx)
2621+
got = ref.run(None, {"A": start, "B": stop, "C": num})
2622+
self.assertEqualArray(y, got[0])
2623+
25772624

25782625
if __name__ == "__main__":
2626+
TestNpx().test_linspace_inline()
25792627
unittest.main(verbosity=2)

onnx_array_api/array_api/__init__.py

+34-7
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"isfinite",
2525
"isinf",
2626
"isnan",
27+
"linspace",
2728
"ones",
2829
"ones_like",
2930
"reshape",
@@ -40,11 +41,18 @@ def _finfo(dtype):
4041
"""
4142
dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
4243
res = np.finfo(dt)
43-
d = res.__dict__.copy()
44+
d = {}
45+
for k, v in res.__dict__.items():
46+
if k.startswith("__"):
47+
continue
48+
if isinstance(v, (np.float32, np.float64, np.float16)):
49+
d[k] = float(v)
50+
else:
51+
d[k] = v
4452
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
4553
nres = type("finfo", (res.__class__,), d)
46-
setattr(nres, "smallest_normal", res.smallest_normal)
47-
setattr(nres, "tiny", res.tiny)
54+
setattr(nres, "smallest_normal", float(res.smallest_normal))
55+
setattr(nres, "tiny", float(res.tiny))
4856
return nres
4957

5058

@@ -54,11 +62,30 @@ def _iinfo(dtype):
5462
"""
5563
dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
5664
res = np.iinfo(dt)
57-
d = res.__dict__.copy()
65+
d = {}
66+
for k, v in res.__dict__.items():
67+
if k.startswith("__"):
68+
continue
69+
if isinstance(
70+
v,
71+
(
72+
np.int16,
73+
np.int32,
74+
np.int64,
75+
np.uint16,
76+
np.uint32,
77+
np.uint64,
78+
np.int8,
79+
np.uint8,
80+
),
81+
):
82+
d[k] = int(v)
83+
else:
84+
d[k] = v
5885
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
59-
nres = type("finfo", (res.__class__,), d)
60-
setattr(nres, "min", res.min)
61-
setattr(nres, "max", res.max)
86+
nres = type("iinfo", (res.__class__,), d)
87+
setattr(nres, "min", int(res.min))
88+
setattr(nres, "max", int(res.max))
6289
return nres
6390

6491

0 commit comments

Comments
 (0)