1
1
from inspect import signature
2
+ from logging import getLogger
2
3
from typing import Any , Callable , Dict , List , Optional , Tuple , Union
3
4
4
5
import numpy as np
7
8
from .npx_types import TensorType
8
9
from .npx_var import Cst , Input , Var
9
10
11
+ logger = getLogger ("onnx-array-api" )
12
+
10
13
11
14
class JitEager :
12
15
"""
@@ -40,6 +43,46 @@ def __init__(
40
43
self .target_opsets = tensor_class .get_opsets (target_opsets )
41
44
self .output_types = output_types
42
45
self .ir_version = tensor_class .get_ir_version (ir_version )
46
+ # parameters necessary after the function was converting to
47
+ # onnx to remember an input in fact a mandatory parameter.
48
+ self .n_inputs_ = 0
49
+ self .input_to_kwargs_ = None
50
+ self .method_name_ = None
51
+
52
+ def info (self , prefix : Optional [str ] = None , method_name : Optional [str ] = None ):
53
+ """
54
+ Logs a status.
55
+ """
56
+ if prefix is None :
57
+ logger .info ("" )
58
+ return
59
+ logger .info (
60
+ "%s [%s.%s] nx=%d ni=%d kw=%d f=%s.%s cl=%s me=%s" ,
61
+ prefix ,
62
+ self .__class__ .__name__ ,
63
+ method_name [:6 ],
64
+ len (self .onxs ),
65
+ self .n_inputs_ ,
66
+ 0 if self .input_to_kwargs_ is None else 1 ,
67
+ self .f .__module__ ,
68
+ self .f .__name__ ,
69
+ self .tensor_class .__name__ ,
70
+ self .method_name_ or "" ,
71
+ )
72
+
73
+ def status (self , me : str ) -> str :
74
+ """
75
+ Returns a short string indicating the status.
76
+ """
77
+ return (
78
+ f"[{ self .__class__ .__name__ } .{ me [:6 ]} ]"
79
+ f"nx={ len (self .onxs )} "
80
+ f"ni={ self .n_inputs_ } "
81
+ f"kw={ 0 if self .input_to_kwargs_ is None else 1 } "
82
+ f"f={ self .f .__module__ } .{ self .f .__name__ } "
83
+ f"cl={ self .tensor_class .__name__ } "
84
+ f"me={ self .method_name_ or '' } "
85
+ )
43
86
44
87
@property
45
88
def n_versions (self ):
@@ -141,8 +184,10 @@ def to_jit(self, *values, **kwargs):
141
184
The onnx graph built by the function defines the input
142
185
types and the expected number of dimensions.
143
186
"""
187
+ self .info ("+" , "to_jit" )
144
188
annotations = self .f .__annotations__
145
189
if len (annotations ) > 0 :
190
+ input_to_kwargs = {}
146
191
names = list (annotations .keys ())
147
192
annot_values = list (annotations .values ())
148
193
constraints = {}
@@ -154,6 +199,32 @@ def to_jit(self, *values, **kwargs):
154
199
constraints [iname ] = v .tensor_type_dims
155
200
else :
156
201
new_kwargs [iname ] = v
202
+ input_to_kwargs [i ] = iname
203
+ if self .input_to_kwargs_ is None :
204
+ self .n_inputs_ = len (values ) - len (input_to_kwargs )
205
+ self .input_to_kwargs_ = input_to_kwargs
206
+ elif self .input_to_kwargs_ != input_to_kwargs :
207
+ raise RuntimeError (
208
+ f"Unexpected input and argument. Previous call produced "
209
+ f"self.input_to_kwargs_={ self .input_to_kwargs_ } and "
210
+ f"input_to_kwargs={ input_to_kwargs } for function { self .f } "
211
+ f"from module { self .f .__module__ !r} ."
212
+ )
213
+ elif self .input_to_kwargs_ :
214
+ constraints = {}
215
+ new_kwargs = {}
216
+ for i , (v , iname ) in enumerate (zip (values , names )):
217
+ if (
218
+ isinstance (v , (EagerTensor , JitTensor ))
219
+ and (
220
+ i >= len (annot_values )
221
+ or issubclass (annot_values [i ], TensorType )
222
+ )
223
+ and i not in self .input_to_kwargs_
224
+ ):
225
+ constraints [iname ] = v .tensor_type_dims
226
+ else :
227
+ new_kwargs [iname ] = v
157
228
else :
158
229
names = [f"x{ i } " for i in range (len (values ))]
159
230
new_kwargs = {}
@@ -162,6 +233,8 @@ def to_jit(self, *values, **kwargs):
162
233
for i , (v , iname ) in enumerate (zip (values , names ))
163
234
if isinstance (v , (EagerTensor , JitTensor ))
164
235
}
236
+ self .n_inputs_ = len (values )
237
+ self .input_to_kwargs_ = {}
165
238
166
239
if self .output_types is not None :
167
240
constraints .update (self .output_types )
@@ -187,6 +260,7 @@ def to_jit(self, *values, **kwargs):
187
260
ir_version = self .ir_version ,
188
261
)
189
262
exe = self .tensor_class .create_function (names , onx )
263
+ self .info ("-" , "to_jit" )
190
264
return onx , exe
191
265
192
266
def cast_to_tensor_class (self , inputs : List [Any ]) -> List [EagerTensor ]:
@@ -221,6 +295,35 @@ def cast_from_tensor_class(
221
295
return tuple (r .value for r in results )
222
296
return results .value
223
297
298
+ def move_input_to_kwargs (
299
+ self , values : List [Any ], kwargs : Dict [str , Any ]
300
+ ) -> Tuple [List [Any ], Dict [str , Any ]]:
301
+ """
302
+ Mandatory parameters not usually not named. Some inputs must
303
+ be moved to the parameter list before calling ONNX.
304
+
305
+ :param values: list of inputs
306
+ :param kwargs: dictionary of arguments
307
+ :return: new values, new arguments
308
+ """
309
+ if self .input_to_kwargs_ is None :
310
+ if self .bypass_eager or self .f .__annotations__ :
311
+ return values , kwargs
312
+ raise RuntimeError (
313
+ f"self.input_to_kwargs_ is not initialized for function { self .f } "
314
+ f"from module { self .f .__module__ !r} ."
315
+ )
316
+ if len (self .input_to_kwargs_ ) == 0 :
317
+ return values , kwargs
318
+ new_values = []
319
+ new_kwargs = kwargs .copy ()
320
+ for i , v in enumerate (values ):
321
+ if i in self .input_to_kwargs_ :
322
+ new_kwargs [self .input_to_kwargs_ [i ]] = v
323
+ else :
324
+ new_values .append (values )
325
+ return new_values , new_kwargs
326
+
224
327
def jit_call (self , * values , ** kwargs ):
225
328
"""
226
329
The method builds a key which identifies the signature
@@ -230,7 +333,13 @@ def jit_call(self, *values, **kwargs):
230
333
indexed by the previous key. Finally, it executes the onnx graph
231
334
and returns the result or the results in a tuple if there are several.
232
335
"""
336
+ self .info ("+" , "jit_call" )
337
+ values , kwargs = self .move_input_to_kwargs (values , kwargs )
233
338
key = self .make_key (* values , ** kwargs )
339
+ if self .method_name_ is None and "method_name" in key :
340
+ pos = list (key ).index ("method_name" )
341
+ self .method_name_ = key [pos + 1 ]
342
+
234
343
if key in self .versions :
235
344
fct = self .versions [key ]
236
345
else :
@@ -243,8 +352,11 @@ def jit_call(self, *values, **kwargs):
243
352
raise RuntimeError (
244
353
f"Unable to run function for key={ key !r} , "
245
354
f"types={ [type (x ) for x in values ]} , "
246
- f"kwargs={ kwargs } , onnx={ self .onxs [key ]} ."
355
+ f"kwargs={ kwargs } , "
356
+ f"self.input_to_kwargs_={ self .input_to_kwargs_ } , "
357
+ f"onnx={ self .onxs [key ]} ."
247
358
) from e
359
+ self .info ("-" , "jit_call" )
248
360
return res
249
361
250
362
@@ -297,9 +409,12 @@ def __call__(self, *args, **kwargs):
297
409
The method first wraps the inputs with `self.tensor_class`
298
410
and converts them into python types just after.
299
411
"""
412
+ self .info ("+" , "__call__" )
300
413
values = self .cast_to_tensor_class (args )
301
414
res = self .jit_call (* values , ** kwargs )
302
- return self .cast_from_tensor_class (res )
415
+ res = self .cast_from_tensor_class (res )
416
+ self .info ("-" , "jit_call" )
417
+ return res
303
418
304
419
305
420
class EagerOnnx (JitEager ):
@@ -316,6 +431,9 @@ class EagerOnnx(JitEager):
316
431
the onnx graph is created and type is needed to do such,
317
432
if not specified, the class assumes there is only one output
318
433
of the same type as the input
434
+ :param bypass_eager: this parameter must be true if the function
435
+ has not annotation and is not decorated by `xapi_inline` or
436
+ `xapi_function`
319
437
:param ir_version: defines the IR version to use
320
438
"""
321
439
@@ -393,6 +511,8 @@ def __call__(self, *args, already_eager=False, **kwargs):
393
511
:param already_eager: already in eager mode, inputs must be of type
394
512
EagerTensor and the returned outputs must be the same
395
513
"""
514
+ self .info ()
515
+ self .info ("+" , "__call__" )
396
516
if already_eager :
397
517
if any (
398
518
map (
@@ -418,6 +538,7 @@ def __call__(self, *args, already_eager=False, **kwargs):
418
538
# The function was already converted into onnx
419
539
# reuse it or create a new one for different types.
420
540
res = self .jit_call (* values , ** kwargs )
541
+ self .info ("-" , "1__call__" )
421
542
else :
422
543
# tries to call the version
423
544
try :
@@ -445,6 +566,7 @@ def __call__(self, *args, already_eager=False, **kwargs):
445
566
# to be converted into onnx.
446
567
res = self .jit_call (* values , ** kwargs )
447
568
self ._eager_cache = True
569
+ self .info ("-" , "2__call__" )
448
570
if already_eager :
449
571
return tuple (res )
450
572
return self .cast_from_tensor_class (res )
0 commit comments