Commit 1bfd39f

still an issue with astype
1 parent 49d5641 · commit 1bfd39f

File tree

5 files changed (+32 -4 lines)


onnx_array_api/npx/npx_core_api.py (+1)

@@ -123,6 +123,7 @@ def wrapper(*inputs, **kwargs):
         for x in inputs:
             if isinstance(x, EagerTensor):
                 tensor_class = x.__class__
+                break
         if tensor_class is None:
             raise RuntimeError(
                 f"Unable to find an EagerTensor in types "

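Note on the hunk above: without the added break, tensor_class keeps being overwritten and ends up holding the class of the last eager input rather than the first one found. A minimal standalone sketch of the pattern (only the EagerTensor-style scanning comes from the diff; the function name and error text are illustrative):

# Illustrative sketch only, not the library code.
def pick_tensor_class(inputs, eager_type):
    """Return the class of the first input that is an instance of eager_type."""
    tensor_class = None
    for x in inputs:
        if isinstance(x, eager_type):
            tensor_class = x.__class__
            break  # stop at the first eager tensor, as the commit now does
    if tensor_class is None:
        raise RuntimeError(f"Unable to find an eager tensor in {[type(x) for x in inputs]}.")
    return tensor_class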
onnx_array_api/npx/npx_functions.py (+5 -2)

@@ -184,8 +184,11 @@ def astype(
     """
     Cast an array.
     """
-    g = a.astype(dtype)
-    return g
+    if isinstance(dtype, Var):
+        raise TypeError(
+            f"dtype is an attribute, it cannot be a Variable of type {type(dtype)}."
+        )
+    return var(a, op="Cast", to=dtype)


 @npxapi_inline

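Background for the astype rewrite above: in ONNX, Cast takes its target element type as a node attribute named "to", not as a second graph input, so a Var passed as dtype cannot work, which is exactly what the new TypeError guards against. A small sketch using the standard onnx.helper API (not code from this repository):

# The target type is an attribute of the Cast node, not a graph input.
from onnx import TensorProto, helper

cast_node = helper.make_node("Cast", inputs=["X"], outputs=["Y"], to=TensorProto.INT64)
print(cast_node)  # the attribute `to` is embedded in the node itself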
onnx_array_api/npx/npx_jit_eager.py (+10 -1)

@@ -92,6 +92,8 @@ def make_key(*values, **kwargs):
             res.append(v)
         elif isinstance(v, slice):
             res.append(("slice", v.start, v.stop, v.step))
+        elif isinstance(v, type):
+            res.append(("type", v.__name__))
         elif isinstance(v, tuple):
             subkey = []
             for sk in v:
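Why make_key needs the new type branch: a Python type passed as an argument (for example a numpy dtype handed to astype) must contribute to the JIT cache key, otherwise two calls differing only by dtype would reuse the same cached ONNX graph. A simplified, illustrative sketch (not the library's implementation):

import numpy as np

def make_key_sketch(*values):
    # Build a hashable key from heterogeneous arguments.
    res = []
    for v in values:
        if isinstance(v, (int, float, str)):
            res.append(v)
        elif isinstance(v, slice):
            res.append(("slice", v.start, v.stop, v.step))
        elif isinstance(v, type):
            res.append(("type", v.__name__))  # the branch added by this commit
        else:
            res.append(type(v).__name__)
    return tuple(res)

# Different dtypes now produce different keys.
assert make_key_sketch(1, np.float32) != make_key_sketch(1, np.int64)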
@@ -139,16 +141,23 @@ def to_jit(self, *values, **kwargs):
         The onnx graph built by the function defines the input
         types and the expected number of dimensions.
         """
+        annotations = self.f.__annotations__
+        annot_values = list(annotations.values())
         constraints = {
             f"x{i}": v.tensor_type_dims
             for i, v in enumerate(values)
             if isinstance(v, (EagerTensor, JitTensor))
+            and (i >= len(annot_values) or issubclass(annot_values[i], TensorType))
         }

         if self.output_types is not None:
             constraints.update(self.output_types)

-        inputs = [Input(f"x{i}") for i in range(len(values))]
+        inputs = [Input(f"x{i}") for i in range(len(values)) if f"x{i}" in constraints]
+        if len(inputs) < len(values):
+            # An attribute is not named in the numpy API
+            # but is in the ONNX definition.
+            raise NotImplementedError()
         var = self.f(*inputs, **kwargs)

         onx = var.to_onnx(

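The second hunk above filters graph inputs by annotation: only parameters annotated with a tensor type become Input placeholders, while attribute-like positional values (such as a dtype) are left out of constraints. A self-contained sketch of that filtering, where TensorType is a stand-in class rather than the library's:

# Illustrative sketch only.
class TensorType:
    pass

def astype_like(a: TensorType, dtype: type = None):
    ...

annot_values = list(astype_like.__annotations__.values())
values = ["some-tensor", int]  # second positional value is a dtype, not a tensor
kept = [
    i for i, _ in enumerate(values)
    if i >= len(annot_values) or issubclass(annot_values[i], TensorType)
]
print(kept)  # [0] -> only x0 becomes an ONNX graph input; the dtype stays an attribute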
onnx_array_api/npx/npx_numpy_tensors.py (+6)

@@ -161,6 +161,12 @@ def get_ir_version(cls, ir_version):
         """
         return ir_version

+    def const_cast(self, to: Any = None) -> "EagerTensor":
+        """
+        Casts a constant without any ONNX conversion.
+        """
+        return self.__class__(self._tensor.astype(to))
+
     # The class should support whatever Var supports.
     # This part is not yet complete.

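A short usage sketch of the new const_cast on the numpy-backed tensor; the class below is a hypothetical stand-in mirroring the committed method, which simply delegates to numpy.ndarray.astype and never builds an ONNX node:

import numpy as np

class NumpyBackedTensor:  # stand-in for the numpy-backed EagerTensor subclass
    def __init__(self, tensor):
        self._tensor = np.asarray(tensor)

    def const_cast(self, to=None):
        # Same idea as the committed method: cast eagerly in numpy,
        # no Cast node is added to any graph.
        return self.__class__(self._tensor.astype(to))

t = NumpyBackedTensor([1, 2, 3])
print(t.const_cast(np.float32)._tensor.dtype)  # float32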
onnx_array_api/npx/npx_tensors.py (+10 -1)

@@ -21,6 +21,15 @@ class EagerTensor(ArrayApi):
     :class:`ArrayApi`.
     """

+    def const_cast(self, to: Any = None) -> "EagerTensor":
+        """
+        Casts a constant without any ONNX conversion.
+        """
+        raise NotImplementedError(
+            f"Method 'const_cast' must be overwritten in class "
+            f"{self.__class__.__name__!r}."
+        )
+
     @staticmethod
     def _op_impl(*inputs, method_name=None):
         # avoids circular imports.

@@ -119,7 +128,7 @@ def _generic_method_operator(self, method_name, *args: Any, **kwargs: Any) -> Any
         new_args = []
         for a in args:
             if isinstance(a, np.ndarray):
-                new_args.append(self.__class__(a).astype(self.dtype))
+                new_args.append(self.__class__(a).const_cast(self.dtype))
             else:
                 new_args.append(a)

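Tying back to the commit message: _generic_method_operator previously aligned raw numpy operands with astype, which (after the npx_functions.py change) routes through the ONNX Cast path, whereas a plain constant only needs an eager numpy cast. The sketch below mirrors that operand-alignment loop with stand-in classes, not the library's:

import numpy as np

class EagerLike:  # stand-in, only to illustrate the call path
    def __init__(self, tensor):
        self._tensor = np.asarray(tensor)

    @property
    def dtype(self):
        return self._tensor.dtype

    def const_cast(self, to=None):
        return self.__class__(self._tensor.astype(to))

def align_numpy_operands(tensor, args):
    # Mirrors the patched loop: raw numpy operands are wrapped and cast with
    # const_cast (pure numpy), not with the array-API astype.
    return [
        tensor.__class__(a).const_cast(tensor.dtype) if isinstance(a, np.ndarray) else a
        for a in args
    ]

x = EagerLike(np.array([1.0, 2.0], dtype=np.float32))
print([a.dtype for a in align_numpy_operands(x, [np.array([3, 4])])])  # [dtype('float32')]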