Commit dd1876f
first step for the array api
1 parent d875d0d

File tree: 6 files changed (+221 -63 lines)
@@ -0,0 +1,30 @@
+import unittest
+import numpy as np
+from onnx.defs import onnx_opset_version
+from sklearn import config_context
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
+
+
+DEFAULT_OPSET = onnx_opset_version()
+
+
+class TestSklearnArrayAPI(ExtTestCase):
+    def test_sklearn_array_api_linear_discriminant(self):
+        X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
+        y = np.array([1, 1, 1, 2, 2, 2])
+        ana = LinearDiscriminantAnalysis()
+        ana = LinearDiscriminantAnalysis()
+        ana.fit(X, y)
+        expected = ana.predict(X)
+
+        new_x = EagerNumpyTensor(X)
+        self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x))
+        with config_context(array_api_dispatch=True):
+            got = ana.predict(new_x)
+        self.assertEqualArray(expected, got)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
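Note on what this test exercises: under config_context(array_api_dispatch=True), scikit-learn stops assuming numpy and resolves array functions through the namespace returned by the input's __array_namespace__ method, which the next file adds. A minimal, self-contained sketch of that discovery protocol (MiniTensor and consumer are illustrative names, with numpy standing in for the npx function module; none of this is part of the commit):

import numpy as np


class MiniTensor:
    # Illustrative stand-in for EagerNumpyTensor; not part of the commit.
    def __init__(self, array):
        self._array = array

    def __array_namespace__(self):
        # A real implementation returns its own function module; numpy
        # stands in here so the sketch runs on its own.
        return np


def consumer(x):
    # Library-agnostic code never imports numpy directly: it asks the input
    # for its namespace, which is what scikit-learn does under
    # config_context(array_api_dispatch=True).
    xp = x.__array_namespace__()
    return xp.asarray([1.0, 2.0, 3.0])


print(consumer(MiniTensor(np.zeros(3))))  # [1. 2. 3.]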

onnx_array_api/npx/npx_array_api.py (+8)

@@ -10,6 +10,14 @@ class ArrayApi:
     List of supported methods by a tensor.
     """

+    def __array_namespace__(self):
+        """
+        Returns the module holding all the available functions.
+        """
+        from onnx_array_api.npx import npx_functions
+
+        return npx_functions
+
     def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
         raise NotImplementedError(
             f"Method {method_name!r} must be overwritten "

onnx_array_api/npx/npx_functions.py (+45 -1)

@@ -1,11 +1,12 @@
-from typing import Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union

 import numpy as np
 from onnx import FunctionProto, ModelProto, NodeProto
 from onnx.numpy_helper import from_array

 from .npx_constants import FUNCTION_DOMAIN
 from .npx_core_api import cst, make_tuple, npxapi_inline, var
+from .npx_tensors import ArrayApi
 from .npx_types import (
     ElemType,
     OptParType,

@@ -155,6 +156,38 @@ def arctanh(
     return var(x, op="Atanh")


+def asarray(
+    a: Any,
+    dtype: Any = None,
+    order: Optional[str] = None,
+    like: Any = None,
+    copy: bool = False,
+):
+    """
+    Converts anything into an array.
+    """
+    if dtype is not None:
+        raise RuntimeError("Method 'astype' should be used to change the type.")
+    if order is not None:
+        raise NotImplementedError(f"order={order!r} not implemented.")
+    if isinstance(a, ArrayApi):
+        if copy:
+            return a.__class__(a, copy=copy)
+        return a
+    raise NotImplementedError(f"asarray not implemented for type {type(a)}.")
+
+
+@npxapi_inline
+def astype(
+    a: TensorType[ElemType.numerics, "T1"], dtype: OptParType[int] = 1
+) -> TensorType[ElemType.numerics, "T2"]:
+    """
+    Cast an array.
+    """
+    g = a.astype(dtype)
+    return g
+
+
 @npxapi_inline
 def cdist(
     xa: TensorType[ElemType.numerics, "T"],

@@ -412,6 +445,17 @@ def relu(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics,
     return var(x, op="Relu")


+@npxapi_inline
+def reshape(
+    x: TensorType[ElemType.numerics, "T"], shape: TensorType[ElemType.int64, "I"]
+) -> TensorType[ElemType.numerics, "T"]:
+    "See :func:`numpy.reshape`."
+    if isinstance(shape, int):
+        shape = cst(np.array([shape], dtype=np.int64))
+    shape_reshaped = var(shape, cst(np.array([-1], dtype=np.int64)), op="Reshape")
+    return var(x, shape_reshaped, op="Reshape")
+
+
 @npxapi_inline
 def round(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]:
     "See :func:`numpy.round`."

onnx_array_api/npx/npx_jit_eager.py (+6 -1)

@@ -338,6 +338,9 @@ def _preprocess_constants(self, *args):
         elif isinstance(n, (int, float)):
             new_args.append(self.tensor_class(np.array(n)))
             modified = True
+        elif n in (int, float):
+            # usually used to cast
+            new_args.append(n)
         elif n is None:
             new_args.append(n)
         else:

@@ -365,7 +368,9 @@ def __call__(self, *args, already_eager=False, **kwargs):
         if any(
             map(
                 lambda t: t is not None
-                and not isinstance(t, (EagerTensor, Cst, int, float, tuple, slice)),
+                and not isinstance(
+                    t, (EagerTensor, Cst, int, float, tuple, slice, type)
+                ),
                 args,
             )
        ):
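Note: both hunks support the new astype path. _preprocess_constants now lets the Python classes int and float pass through unchanged (they later denote a cast target), and __call__ whitelists type so such class arguments no longer fail the eager-argument check. An illustrative trace of the branch order (not library code):

# The Python classes int and float can now travel through the eager call
# as cast targets; scalar instances are still wrapped into tensors.
for n in (3, 3.0, int, float, None):
    if isinstance(n, (int, float)):
        kind = "scalar constant, wrapped into a tensor"
    elif n in (int, float):
        kind = "a class used as a cast target, passed through"
    elif n is None:
        kind = "left as None"
    print(n, "->", kind)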

onnx_array_api/npx/npx_numpy_tensors.py (+9)

@@ -70,6 +70,10 @@ def __init__(self, tensor: np.ndarray):
         else:
             raise TypeError(f"A numpy array is expected not {type(tensor)}.")

+    def __repr__(self) -> str:
+        "usual"
+        return f"{self.__class__.__name__}({self._tensor!r})"
+
     def numpy(self):
         "Returns the array converted into a numpy array."
         return self._tensor

@@ -107,6 +111,11 @@ def dims(self):
             return self._tensor.shape
         return (None,) + self._tensor.shape[1:]

+    @property
+    def ndim(self):
+        "Returns the number of dimensions (rank)."
+        return len(self.shape)
+
     @property
     def shape(self) -> Tuple[int, ...]:
         "Returns the shape of the tensor."

onnx_array_api/npx/npx_tensors.py (+123 -61)

@@ -1,5 +1,8 @@
 from typing import Any

+import numpy as np
+from onnx.helper import np_dtype_to_tensor_dtype
+
 from .npx_array_api import ArrayApi


@@ -59,6 +62,16 @@ def _getitem_impl_var(obj, index, method_name=None):
         meth = getattr(Var, method_name)
         return meth(obj, index)

+    @staticmethod
+    def _astype_impl(x, dtype, method_name=None):
+        # avoids circular imports.
+        from .npx_var import Var
+
+        if not isinstance(x, Var):
+            raise TypeError(f"Input 0 must be a Var not {type(x)}.")
+        meth = getattr(Var, "astype")
+        return meth(x, dtype)
+
     @staticmethod
     def _getitem_impl_tuple(obj, index=None, method_name=None):
         # avoids circular imports.

@@ -69,13 +82,114 @@ def _getitem_impl_tuple(obj, index=None, method_name=None):
         meth = getattr(Var, method_name)
         return meth(obj, index)

+    def _generic_method_getitem(self, method_name, *args: Any, **kwargs: Any) -> Any:
+        # avoids circular imports.
+        from .npx_jit_eager import eager_onnx
+
+        if len(args) != 1:
+            raise ValueError(
+                f"Unexpected number of argument {len(args)}, it should be one."
+            )
+        if isinstance(args[0], tuple):
+            eag = eager_onnx(
+                EagerTensor._getitem_impl_tuple, self.__class__, bypass_eager=True
+            )
+            res = eag(self, index=args[0], method_name=method_name, already_eager=True)
+        else:
+            eag = eager_onnx(
+                EagerTensor._getitem_impl_var, self.__class__, bypass_eager=True
+            )
+            res = eag(self, args[0], method_name=method_name, already_eager=True)
+        if isinstance(res, tuple) and len(res) == 1:
+            return res[0]
+        return res
+
+    def _generic_method_operator(self, method_name, *args: Any, **kwargs: Any) -> Any:
+        # avoids circular imports.
+        from .npx_jit_eager import eager_onnx
+
+        if len(args) not in (0, 1):
+            raise ValueError(
+                f"An operator must have zero or one argument not {len(args)}."
+            )
+        if len(kwargs) not in (0, 1):
+            raise ValueError(f"Operators do not support parameters {len(kwargs)}.")
+
+        # let's cast numpy arrays into constants.
+        new_args = []
+        for a in args:
+            if isinstance(a, np.ndarray):
+                new_args.append(self.__class__(a).astype(self.dtype))
+            else:
+                new_args.append(a)
+
+        eag = eager_onnx(EagerTensor._op_impl, self.__class__, bypass_eager=True)
+        res = eag(self, *new_args, method_name=method_name, already_eager=True)
+        if isinstance(res, tuple) and len(res) == 1:
+            return res[0]
+        return res
+
+    def _generic_method_reduce(self, method_name, *args: Any, **kwargs: Any) -> Any:
+        # avoids circular imports.
+        from .npx_jit_eager import eager_onnx
+
+        if len(args) not in (0, 1):
+            raise ValueError(
+                f"An operator must have zero or one argument not {len(args)}."
+            )
+
+        if "axis" in kwargs:
+            axes = kwargs["axis"]
+            del kwargs["axis"]
+        else:
+            axes = None
+        if axes is None:
+            eag = eager_onnx(
+                EagerTensor._reduce_impl_noaxes, self.__class__, bypass_eager=True
+            )
+            res = eag(self, method_name=method_name, already_eager=True, **kwargs)
+        else:
+            eag = eager_onnx(
+                EagerTensor._reduce_impl, self.__class__, bypass_eager=True
+            )
+            res = eag(self, axes, method_name=method_name, already_eager=True, **kwargs)
+        if isinstance(res, tuple) and len(res) == 1:
+            return res[0]
+        return res
+
+    @staticmethod
+    def _np_dtype_to_tensor_dtype(dtype):
+        if dtype == int:
+            dtype = np.dtype("int64")
+        elif dtype == float:
+            dtype = np.dtype("float64")
+        return np_dtype_to_tensor_dtype(dtype)
+
+    def _generic_method_astype(self, method_name, *args: Any, **kwargs: Any) -> Any:
+        # avoids circular imports.
+        from .npx_jit_eager import eager_onnx
+        from .npx_var import Var
+
+        if len(args) != 1:
+            raise ValueError(f"astype takes only one argument not {len(args)}.")
+
+        dtype = (
+            args[0]
+            if isinstance(args[0], (int, Var))
+            else self._np_dtype_to_tensor_dtype(args[0])
+        )
+        eag = eager_onnx(EagerTensor._astype_impl, self.__class__, bypass_eager=True)
+        res = eag(self, dtype, method_name=method_name, already_eager=True, **kwargs)
+        if isinstance(res, tuple) and len(res) == 1:
+            return res[0]
+        return res
+
     def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
         """
         The method converts the method into an ONNX graph built by the
         corresponding method in class Var.
         """
         # avoids circular imports.
-        from .npx_jit_eager import eager_onnx
         from .npx_var import Var

         if not hasattr(Var, method_name):

@@ -84,70 +198,18 @@ def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
                 f"This method cannot be converted into an ONNX graph."
             )
         if method_name == "__getitem__":
-            if len(args) != 1:
-                raise ValueError(
-                    f"Unexpected number of argument {len(args)}, it should be one."
-                )
-            if isinstance(args[0], tuple):
-                eag = eager_onnx(
-                    EagerTensor._getitem_impl_tuple, self.__class__, bypass_eager=True
-                )
-                res = eag(
-                    self, index=args[0], method_name=method_name, already_eager=True
-                )
-            else:
-                eag = eager_onnx(
-                    EagerTensor._getitem_impl_var, self.__class__, bypass_eager=True
-                )
-                res = eag(self, args[0], method_name=method_name, already_eager=True)
-            if isinstance(res, tuple) and len(res) == 1:
-                return res[0]
-            return res
+            return self._generic_method_getitem(method_name, *args, **kwargs)

         if method_name == "__setitem__":
             return ArrayApi.generic_method(self, method_name, *args, **kwargs)

-        if method_name.startswith("__") and method_name.endswith("__"):
-            # An operator.
-            if len(args) not in (0, 1):
-                raise ValueError(
-                    f"An operator must have zero or one argument not {len(args)}."
-                )
-            if len(kwargs) not in (0, 1):
-                raise ValueError(f"Operators do not support parameters {len(kwargs)}.")
-
-            eag = eager_onnx(EagerTensor._op_impl, self.__class__, bypass_eager=True)
-            res = eag(self, *args, method_name=method_name, already_eager=True)
-            if isinstance(res, tuple) and len(res) == 1:
-                return res[0]
-            return res
-
         if method_name in {"mean", "sum", "min", "max", "prod"}:
-            # ReduceFunction
-            if len(args) not in (0, 1):
-                raise ValueError(
-                    f"An operator must have zero or one argument not {len(args)}."
-                )
-
-            if "axis" in kwargs:
-                axes = kwargs["axis"]
-                del kwargs["axis"]
-            else:
-                axes = None
-            if axes is None:
-                eag = eager_onnx(
-                    EagerTensor._reduce_impl_noaxes, self.__class__, bypass_eager=True
-                )
-                res = eag(self, method_name=method_name, already_eager=True, **kwargs)
-            else:
-                eag = eager_onnx(
-                    EagerTensor._reduce_impl, self.__class__, bypass_eager=True
-                )
-                res = eag(
-                    self, axes, method_name=method_name, already_eager=True, **kwargs
-                )
-            if isinstance(res, tuple) and len(res) == 1:
-                return res[0]
-            return res
+            return self._generic_method_reduce(method_name, *args, **kwargs)
+
+        if method_name == "astype":
+            return self._generic_method_astype(method_name, *args, **kwargs)
+
+        if method_name.startswith("__") and method_name.endswith("__"):
+            return self._generic_method_operator(method_name, *args, **kwargs)

         return ArrayApi.generic_method(self, method_name, *args, **kwargs)
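Note: after this refactoring, generic_method merely dispatches to the per-category helpers above; _generic_method_operator additionally wraps raw numpy operands into tensors cast to self.dtype before calling the operator implementation. The new _np_dtype_to_tensor_dtype widens Python's built-in int and float to int64/float64 before delegating to onnx.helper; a small check of that underlying mapping, using only public onnx APIs:

import numpy as np
from onnx import TensorProto
from onnx.helper import np_dtype_to_tensor_dtype

# What the static helper normalizes toward: the Python classes int and
# float are first widened to their 64-bit numpy dtypes, then mapped to
# ONNX element types by onnx.helper.np_dtype_to_tensor_dtype.
assert np_dtype_to_tensor_dtype(np.dtype("int64")) == TensorProto.INT64
assert np_dtype_to_tensor_dtype(np.dtype("float64")) == TensorProto.DOUBLE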
