@@ -1,5 +1,8 @@
 from typing import Any
 
+import numpy as np
+from onnx.helper import np_dtype_to_tensor_dtype
+
 from .npx_array_api import ArrayApi
 
 
@@ -59,6 +62,16 @@ def _getitem_impl_var(obj, index, method_name=None):
         meth = getattr(Var, method_name)
         return meth(obj, index)
 
+    @staticmethod
+    def _astype_impl(x, dtype, method_name=None):
+        # avoids circular imports.
+        from .npx_var import Var
+
+        if not isinstance(x, Var):
+            raise TypeError(f"Input 0 must be a Var not {type(x)}.")
+        meth = getattr(Var, "astype")
+        return meth(x, dtype)
+
     @staticmethod
     def _getitem_impl_tuple(obj, index=None, method_name=None):
         # avoids circular imports.
@@ -69,13 +82,114 @@ def _getitem_impl_tuple(obj, index=None, method_name=None):
         meth = getattr(Var, method_name)
         return meth(obj, index)
 
+    def _generic_method_getitem(self, method_name, *args: Any, **kwargs: Any) -> Any:
+        # avoids circular imports.
+        from .npx_jit_eager import eager_onnx
+
+        if len(args) != 1:
+            raise ValueError(
+                f"Unexpected number of argument {len(args)}, it should be one."
+            )
+        if isinstance(args[0], tuple):
+            eag = eager_onnx(
+                EagerTensor._getitem_impl_tuple, self.__class__, bypass_eager=True
+            )
+            res = eag(self, index=args[0], method_name=method_name, already_eager=True)
+        else:
+            eag = eager_onnx(
+                EagerTensor._getitem_impl_var, self.__class__, bypass_eager=True
+            )
+            res = eag(self, args[0], method_name=method_name, already_eager=True)
+        if isinstance(res, tuple) and len(res) == 1:
+            return res[0]
+        return res
+
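
A minimal standalone sketch of the index dispatch above; MiniTensor is a hypothetical stand-in for an eager tensor, not part of this change. A tuple index and a single index take two different paths, mirroring _getitem_impl_tuple and _getitem_impl_var:

class MiniTensor:
    # Hypothetical stand-in: routes tuple indices and single indices to
    # two implementations, like _getitem_impl_tuple / _getitem_impl_var.
    def __getitem__(self, index):
        if isinstance(index, tuple):
            return ("tuple-path", index)
        return ("var-path", index)

t = MiniTensor()
assert t[0] == ("var-path", 0)
assert t[0, 1] == ("tuple-path", (0, 1))
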
+    def _generic_method_operator(self, method_name, *args: Any, **kwargs: Any) -> Any:
+        # avoids circular imports.
+        from .npx_jit_eager import eager_onnx
+
+        if len(args) not in (0, 1):
+            raise ValueError(
+                f"An operator must have zero or one argument not {len(args)}."
+            )
+        if len(kwargs) not in (0, 1):
+            raise ValueError(f"Operators do not support parameters {len(kwargs)}.")
+
+        # let's cast numpy arrays into constants.
+        new_args = []
+        for a in args:
+            if isinstance(a, np.ndarray):
+                new_args.append(self.__class__(a).astype(self.dtype))
+            else:
+                new_args.append(a)
+
+        eag = eager_onnx(EagerTensor._op_impl, self.__class__, bypass_eager=True)
+        res = eag(self, *new_args, method_name=method_name, already_eager=True)
+        if isinstance(res, tuple) and len(res) == 1:
+            return res[0]
+        return res
+
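
Before building the operator, the loop above casts raw numpy operands to the tensor's own dtype. A standalone sketch of that normalisation, with normalize_args as a hypothetical helper:

import numpy as np

def normalize_args(self_dtype, args):
    # Cast plain numpy arrays to the tensor's dtype, as the loop above
    # does with self.__class__(a).astype(self.dtype); leave the rest as-is.
    return [
        a.astype(self_dtype) if isinstance(a, np.ndarray) else a
        for a in args
    ]

out = normalize_args(np.float32, [np.array([1, 2]), 3])
assert out[0].dtype == np.float32
assert out[1] == 3
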
+    def _generic_method_reduce(self, method_name, *args: Any, **kwargs: Any) -> Any:
+        # avoids circular imports.
+        from .npx_jit_eager import eager_onnx
+
+        if len(args) not in (0, 1):
+            raise ValueError(
+                f"An operator must have zero or one argument not {len(args)}."
+            )
+
+        if "axis" in kwargs:
+            axes = kwargs["axis"]
+            del kwargs["axis"]
+        else:
+            axes = None
+        if axes is None:
+            eag = eager_onnx(
+                EagerTensor._reduce_impl_noaxes, self.__class__, bypass_eager=True
+            )
+            res = eag(self, method_name=method_name, already_eager=True, **kwargs)
+        else:
+            eag = eager_onnx(
+                EagerTensor._reduce_impl, self.__class__, bypass_eager=True
+            )
+            res = eag(self, axes, method_name=method_name, already_eager=True, **kwargs)
+        if isinstance(res, tuple) and len(res) == 1:
+            return res[0]
+        return res
+
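
The axis keyword selects between the two reduce implementations. A standalone numpy sketch of the same routing, with reduce_like as a hypothetical helper:

import numpy as np

def reduce_like(x, method_name, axis=None):
    # axis=None follows the _reduce_impl_noaxes path (reduce everything),
    # an explicit axis follows the _reduce_impl path.
    fn = getattr(np, method_name)
    return fn(x) if axis is None else fn(x, axis=axis)

x = np.arange(6).reshape(2, 3)
assert reduce_like(x, "sum") == 15
assert reduce_like(x, "sum", axis=1).tolist() == [3, 12]
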
+    @staticmethod
+    def _np_dtype_to_tensor_dtype(dtype):
+        if dtype == int:
+            dtype = np.dtype("int64")
+        elif dtype == float:
+            dtype = np.dtype("float64")
+        return np_dtype_to_tensor_dtype(dtype)
+
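
The helper first normalises Python's builtin int and float to fixed 64-bit numpy dtypes, then delegates to onnx.helper.np_dtype_to_tensor_dtype, which returns the ONNX TensorProto element type. For instance:

import numpy as np
from onnx import TensorProto
from onnx.helper import np_dtype_to_tensor_dtype

# The builtin scalar types therefore map to 64-bit ONNX element types.
assert np_dtype_to_tensor_dtype(np.dtype("int64")) == TensorProto.INT64
assert np_dtype_to_tensor_dtype(np.dtype("float64")) == TensorProto.DOUBLE
assert np_dtype_to_tensor_dtype(np.dtype("float32")) == TensorProto.FLOAT
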
+    def _generic_method_astype(self, method_name, *args: Any, **kwargs: Any) -> Any:
+        # avoids circular imports.
+        from .npx_jit_eager import eager_onnx
+        from .npx_var import Var
+
+        if len(args) != 1:
+            raise ValueError(f"astype takes only one argument not {len(args)}.")
+
+        dtype = (
+            args[0]
+            if isinstance(args[0], (int, Var))
+            else self._np_dtype_to_tensor_dtype(args[0])
+        )
+        eag = eager_onnx(EagerTensor._astype_impl, self.__class__, bypass_eager=True)
+        res = eag(self, dtype, method_name=method_name, already_eager=True, **kwargs)
+        if isinstance(res, tuple) and len(res) == 1:
+            return res[0]
+        return res
+
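
An integer argument is passed through unchanged (assumed to already be an ONNX element type) and a Var is left for the graph to resolve; anything else goes through _np_dtype_to_tensor_dtype. A standalone sketch covering the non-Var cases, with normalize_dtype as a hypothetical helper:

import numpy as np
from onnx import TensorProto
from onnx.helper import np_dtype_to_tensor_dtype

def normalize_dtype(dtype):
    # Integers pass through as ONNX element types; builtin int/float and
    # numpy dtypes are converted, as in _np_dtype_to_tensor_dtype above.
    if isinstance(dtype, int):
        return dtype
    if dtype == int:
        dtype = np.dtype("int64")
    elif dtype == float:
        dtype = np.dtype("float64")
    return np_dtype_to_tensor_dtype(dtype)

assert normalize_dtype(TensorProto.FLOAT) == TensorProto.FLOAT
assert normalize_dtype(np.dtype("float32")) == TensorProto.FLOAT
assert normalize_dtype(int) == TensorProto.INT64
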
     def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
         """
         The method converts the method into an ONNX graph built by the
         corresponding method in class Var.
         """
         # avoids circular imports.
-        from .npx_jit_eager import eager_onnx
         from .npx_var import Var
 
         if not hasattr(Var, method_name):
@@ -84,70 +198,18 @@ def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any:
                 f"This method cannot be converted into an ONNX graph."
             )
         if method_name == "__getitem__":
-            if len(args) != 1:
-                raise ValueError(
-                    f"Unexpected number of argument {len(args)}, it should be one."
-                )
-            if isinstance(args[0], tuple):
-                eag = eager_onnx(
-                    EagerTensor._getitem_impl_tuple, self.__class__, bypass_eager=True
-                )
-                res = eag(
-                    self, index=args[0], method_name=method_name, already_eager=True
-                )
-            else:
-                eag = eager_onnx(
-                    EagerTensor._getitem_impl_var, self.__class__, bypass_eager=True
-                )
-                res = eag(self, args[0], method_name=method_name, already_eager=True)
-            if isinstance(res, tuple) and len(res) == 1:
-                return res[0]
-            return res
+            return self._generic_method_getitem(method_name, *args, **kwargs)
 
         if method_name == "__setitem__":
             return ArrayApi.generic_method(self, method_name, *args, **kwargs)
 
-        if method_name.startswith("__") and method_name.endswith("__"):
-            # An operator.
-            if len(args) not in (0, 1):
-                raise ValueError(
-                    f"An operator must have zero or one argument not {len(args)}."
-                )
-            if len(kwargs) not in (0, 1):
-                raise ValueError(f"Operators do not support parameters {len(kwargs)}.")
-
-            eag = eager_onnx(EagerTensor._op_impl, self.__class__, bypass_eager=True)
-            res = eag(self, *args, method_name=method_name, already_eager=True)
-            if isinstance(res, tuple) and len(res) == 1:
-                return res[0]
-            return res
-
         if method_name in {"mean", "sum", "min", "max", "prod"}:
-            # ReduceFunction
-            if len(args) not in (0, 1):
-                raise ValueError(
-                    f"An operator must have zero or one argument not {len(args)}."
-                )
-
-            if "axis" in kwargs:
-                axes = kwargs["axis"]
-                del kwargs["axis"]
-            else:
-                axes = None
-            if axes is None:
-                eag = eager_onnx(
-                    EagerTensor._reduce_impl_noaxes, self.__class__, bypass_eager=True
-                )
-                res = eag(self, method_name=method_name, already_eager=True, **kwargs)
-            else:
-                eag = eager_onnx(
-                    EagerTensor._reduce_impl, self.__class__, bypass_eager=True
-                )
-                res = eag(
-                    self, axes, method_name=method_name, already_eager=True, **kwargs
-                )
-            if isinstance(res, tuple) and len(res) == 1:
-                return res[0]
-            return res
+            return self._generic_method_reduce(method_name, *args, **kwargs)
+
+        if method_name == "astype":
+            return self._generic_method_astype(method_name, *args, **kwargs)
+
+        if method_name.startswith("__") and method_name.endswith("__"):
+            return self._generic_method_operator(method_name, *args, **kwargs)
 
         return ArrayApi.generic_method(self, method_name, *args, **kwargs)
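
After this change, generic_method is a thin dispatcher. A standalone sketch of the resulting routing order; only the method names come from the change, the rest is illustrative:

def route(method_name):
    # Mirrors the order of the checks in generic_method after this change.
    if method_name == "__getitem__":
        return "getitem"
    if method_name == "__setitem__":
        return "setitem"
    if method_name in {"mean", "sum", "min", "max", "prod"}:
        return "reduce"
    if method_name == "astype":
        return "astype"
    if method_name.startswith("__") and method_name.endswith("__"):
        return "operator"
    return "fallback"

assert route("sum") == "reduce"
assert route("astype") == "astype"
assert route("__add__") == "operator"
assert route("reshape") == "fallback"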