
Commit b758cb7 (1 parent: d166d81)

add logging

3 files changed: +130 -3 lines

_unittests/ut_npx/test_sklearn_array_api.py (+3)

@@ -27,4 +27,7 @@ def test_sklearn_array_api_linear_discriminant(self):
 
 
 if __name__ == "__main__":
+    import logging
+
+    logging.basicConfig(level=logging.DEBUG)
     unittest.main(verbosity=2)
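The test only turns on logging.basicConfig. To watch the records this commit adds outside the test suite, a minimal setup could be (standard library only; the logger name "onnx-array-api" comes from npx_jit_eager.py below):

    import logging

    logging.basicConfig(level=logging.DEBUG)
    # info() logs at INFO level, so INFO is enough for this logger
    logging.getLogger("onnx-array-api").setLevel(logging.INFO)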

onnx_array_api/npx/npx_jit_eager.py (+124 -2)
@@ -1,4 +1,5 @@
 from inspect import signature
+from logging import getLogger
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -7,6 +8,8 @@
 from .npx_types import TensorType
 from .npx_var import Cst, Input, Var
 
+logger = getLogger("onnx-array-api")
+
 
 class JitEager:
     """
@@ -40,6 +43,46 @@ def __init__(
         self.target_opsets = tensor_class.get_opsets(target_opsets)
         self.output_types = output_types
         self.ir_version = tensor_class.get_ir_version(ir_version)
+        # parameters necessary after the function was converted to
+        # onnx, to remember when an input is in fact a mandatory parameter
+        self.n_inputs_ = 0
+        self.input_to_kwargs_ = None
+        self.method_name_ = None
+
+    def info(self, prefix: Optional[str] = None, method_name: Optional[str] = None):
+        """
+        Logs a status.
+        """
+        if prefix is None:
+            logger.info("")
+            return
+        logger.info(
+            "%s [%s.%s] nx=%d ni=%d kw=%d f=%s.%s cl=%s me=%s",
+            prefix,
+            self.__class__.__name__,
+            method_name[:6],
+            len(self.onxs),
+            self.n_inputs_,
+            0 if self.input_to_kwargs_ is None else 1,
+            self.f.__module__,
+            self.f.__name__,
+            self.tensor_class.__name__,
+            self.method_name_ or "",
+        )
+
+    def status(self, me: str) -> str:
+        """
+        Returns a short string indicating the status.
+        """
+        return (
+            f"[{self.__class__.__name__}.{me[:6]}]"
+            f"nx={len(self.onxs)} "
+            f"ni={self.n_inputs_} "
+            f"kw={0 if self.input_to_kwargs_ is None else 1} "
+            f"f={self.f.__module__}.{self.f.__name__} "
+            f"cl={self.tensor_class.__name__} "
+            f"me={self.method_name_ or ''}"
+        )
 
     @property
     def n_versions(self):
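status() packs the jitted state into short key=value pairs (nx: cached graphs, ni: inputs, kw: whether inputs were moved to kwargs). A toy sketch of what it produces, with hypothetical values:

    class _Toy:
        # stand-ins for the JitEager state (hypothetical values)
        onxs = {}              # cache of compiled onnx graphs
        n_inputs_ = 2
        input_to_kwargs_ = {}

        def status(self, me: str) -> str:
            # me[:6] keeps the method name short in log lines
            return (
                f"[{self.__class__.__name__}.{me[:6]}]"
                f"nx={len(self.onxs)} "
                f"ni={self.n_inputs_} "
                f"kw={0 if self.input_to_kwargs_ is None else 1}"
            )

    print(_Toy().status("jit_call"))  # [_Toy.jit_ca]nx=0 ni=2 kw=1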
@@ -141,8 +184,10 @@ def to_jit(self, *values, **kwargs):
         The onnx graph built by the function defines the input
         types and the expected number of dimensions.
         """
+        self.info("+", "to_jit")
         annotations = self.f.__annotations__
         if len(annotations) > 0:
+            input_to_kwargs = {}
             names = list(annotations.keys())
             annot_values = list(annotations.values())
             constraints = {}
@@ -154,6 +199,32 @@ def to_jit(self, *values, **kwargs):
                     constraints[iname] = v.tensor_type_dims
                 else:
                     new_kwargs[iname] = v
+                    input_to_kwargs[i] = iname
+            if self.input_to_kwargs_ is None:
+                self.n_inputs_ = len(values) - len(input_to_kwargs)
+                self.input_to_kwargs_ = input_to_kwargs
+            elif self.input_to_kwargs_ != input_to_kwargs:
+                raise RuntimeError(
+                    f"Unexpected input and argument. Previous call produced "
+                    f"self.input_to_kwargs_={self.input_to_kwargs_} and "
+                    f"input_to_kwargs={input_to_kwargs} for function {self.f} "
+                    f"from module {self.f.__module__!r}."
+                )
+            elif self.input_to_kwargs_:
+                constraints = {}
+                new_kwargs = {}
+                for i, (v, iname) in enumerate(zip(values, names)):
+                    if (
+                        isinstance(v, (EagerTensor, JitTensor))
+                        and (
+                            i >= len(annot_values)
+                            or issubclass(annot_values[i], TensorType)
+                        )
+                        and i not in self.input_to_kwargs_
+                    ):
+                        constraints[iname] = v.tensor_type_dims
+                    else:
+                        new_kwargs[iname] = v
         else:
             names = [f"x{i}" for i in range(len(values))]
             new_kwargs = {}
@@ -162,6 +233,8 @@ def to_jit(self, *values, **kwargs):
                 for i, (v, iname) in enumerate(zip(values, names))
                 if isinstance(v, (EagerTensor, JitTensor))
             }
+            self.n_inputs_ = len(values)
+            self.input_to_kwargs_ = {}
 
         if self.output_types is not None:
             constraints.update(self.output_types)
@@ -187,6 +260,7 @@ def to_jit(self, *values, **kwargs):
             ir_version=self.ir_version,
         )
         exe = self.tensor_class.create_function(names, onx)
+        self.info("-", "to_jit")
         return onx, exe
 
     def cast_to_tensor_class(self, inputs: List[Any]) -> List[EagerTensor]:
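In to_jit, the annotations decide which positional values become real graph inputs and which are mandatory parameters to remember in input_to_kwargs_. A simplified sketch of that partition (toy Tensor class; the real test also checks TensorType subclasses):

    from typing import Dict, List, Tuple

    class Tensor:  # stands in for EagerTensor / JitTensor
        pass

    def split_inputs(
        values: List[object], annotations: Dict[str, type]
    ) -> Tuple[Dict[str, object], Dict[int, str]]:
        names = list(annotations)
        constraints = {}       # real graph inputs, by name
        input_to_kwargs = {}   # position -> name of a mandatory parameter
        for i, (v, iname) in enumerate(zip(values, names)):
            if isinstance(v, Tensor):
                constraints[iname] = v
            else:
                input_to_kwargs[i] = iname
        return constraints, input_to_kwargs

    # astype(x, dtype): x is a tensor, dtype=7 is in fact a parameter
    _, mapping = split_inputs([Tensor(), 7], {"x": Tensor, "dtype": int})
    print(mapping)  # {1: 'dtype'}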
@@ -221,6 +295,35 @@ def cast_from_tensor_class(
             return tuple(r.value for r in results)
         return results.value
 
+    def move_input_to_kwargs(
+        self, values: List[Any], kwargs: Dict[str, Any]
+    ) -> Tuple[List[Any], Dict[str, Any]]:
+        """
+        Mandatory parameters are usually not named. Some inputs must
+        be moved to the parameter list before calling ONNX.
+
+        :param values: list of inputs
+        :param kwargs: dictionary of arguments
+        :return: new values, new arguments
+        """
+        if self.input_to_kwargs_ is None:
+            if self.bypass_eager or self.f.__annotations__:
+                return values, kwargs
+            raise RuntimeError(
+                f"self.input_to_kwargs_ is not initialized for function {self.f} "
+                f"from module {self.f.__module__!r}."
+            )
+        if len(self.input_to_kwargs_) == 0:
+            return values, kwargs
+        new_values = []
+        new_kwargs = kwargs.copy()
+        for i, v in enumerate(values):
+            if i in self.input_to_kwargs_:
+                new_kwargs[self.input_to_kwargs_[i]] = v
+            else:
+                new_values.append(v)
+        return new_values, new_kwargs
+
     def jit_call(self, *values, **kwargs):
         """
         The method builds a key which identifies the signature
@@ -230,7 +333,13 @@ def jit_call(self, *values, **kwargs):
         indexed by the previous key. Finally, it executes the onnx graph
         and returns the result or the results in a tuple if there are several.
         """
+        self.info("+", "jit_call")
+        values, kwargs = self.move_input_to_kwargs(values, kwargs)
         key = self.make_key(*values, **kwargs)
+        if self.method_name_ is None and "method_name" in key:
+            pos = list(key).index("method_name")
+            self.method_name_ = key[pos + 1]
+
         if key in self.versions:
             fct = self.versions[key]
         else:
@@ -243,8 +352,11 @@ def jit_call(self, *values, **kwargs):
             raise RuntimeError(
                 f"Unable to run function for key={key!r}, "
                 f"types={[type(x) for x in values]}, "
-                f"kwargs={kwargs}, onnx={self.onxs[key]}."
+                f"kwargs={kwargs}, "
+                f"self.input_to_kwargs_={self.input_to_kwargs_}, "
+                f"onnx={self.onxs[key]}."
             ) from e
+        self.info("-", "jit_call")
         return res
 
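The new move_input_to_kwargs step runs before make_key, so a mandatory parameter passed positionally is hashed as a named argument. A minimal standalone sketch of the partition (plain function instead of the JitEager state; input_to_kwargs_ is the position-to-name mapping built by to_jit):

    from typing import Any, Dict, List, Tuple

    def move_input_to_kwargs(
        values: List[Any], kwargs: Dict[str, Any], input_to_kwargs_: Dict[int, str]
    ) -> Tuple[List[Any], Dict[str, Any]]:
        if not input_to_kwargs_:
            return values, kwargs
        new_values = []
        new_kwargs = kwargs.copy()
        for i, v in enumerate(values):
            if i in input_to_kwargs_:
                new_kwargs[input_to_kwargs_[i]] = v  # promote to named argument
            else:
                new_values.append(v)                 # keep as a graph input
        return new_values, new_kwargs

    print(move_input_to_kwargs(["X", 7], {}, {1: "dtype"}))
    # (['X'], {'dtype': 7})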

@@ -297,9 +409,12 @@ def __call__(self, *args, **kwargs):
         The method first wraps the inputs with `self.tensor_class`
         and converts them into python types just after.
         """
+        self.info("+", "__call__")
         values = self.cast_to_tensor_class(args)
         res = self.jit_call(*values, **kwargs)
-        return self.cast_from_tensor_class(res)
+        res = self.cast_from_tensor_class(res)
+        self.info("-", "__call__")
+        return res
 
 
 class EagerOnnx(JitEager):
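The commit inlines matching info("+", ...) and info("-", ...) calls around each stage. The same bracketing idea could also be written once as a decorator; a hedged sketch, not part of the commit:

    import logging
    from functools import wraps

    logger = logging.getLogger("onnx-array-api")

    def traced(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            logger.info("+ [%s]", fn.__name__)      # entering
            try:
                return fn(*args, **kwargs)
            finally:
                logger.info("- [%s]", fn.__name__)  # leaving, even on error
        return wrapper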
@@ -316,6 +431,9 @@ class EagerOnnx(JitEager):
         the onnx graph is created and type is needed to do such,
         if not specified, the class assumes there is only one output
         of the same type as the input
+    :param bypass_eager: this parameter must be true if the function
+        has no annotations and is not decorated by `xapi_inline` or
+        `xapi_function`
     :param ir_version: defines the IR version to use
     """
 
@@ -393,6 +511,8 @@ def __call__(self, *args, already_eager=False, **kwargs):
         :param already_eager: already in eager mode, inputs must be of type
             EagerTensor and the returned outputs must be the same
         """
+        self.info()
+        self.info("+", "__call__")
         if already_eager:
             if any(
                 map(
@@ -418,6 +538,7 @@ def __call__(self, *args, already_eager=False, **kwargs):
             # The function was already converted into onnx
             # reuse it or create a new one for different types.
             res = self.jit_call(*values, **kwargs)
+            self.info("-", "1__call__")
         else:
             # tries to call the version
             try:
@@ -445,6 +566,7 @@ def __call__(self, *args, already_eager=False, **kwargs):
             # to be converted into onnx.
             res = self.jit_call(*values, **kwargs)
             self._eager_cache = True
+            self.info("-", "2__call__")
         if already_eager:
             return tuple(res)
         return self.cast_from_tensor_class(res)
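For context, jit_call's caching scheme keeps one compiled graph per signature key, as the docstring describes. A minimal sketch of that idea (hypothetical make_key; the real key also encodes dtypes and shapes):

    from typing import Any, Callable, Dict, Tuple

    versions: Dict[Tuple, Callable] = {}  # key -> compiled function

    def make_key(*values: Any, **kwargs: Any) -> Tuple:
        # type names stand in for the richer signature the real key encodes
        return tuple(type(v).__name__ for v in values) + tuple(sorted(kwargs.items()))

    def jit_call(f: Callable, *values: Any, **kwargs: Any) -> Any:
        key = make_key(*values, **kwargs)
        if key not in versions:
            versions[key] = f  # stands in for to_jit building an onnx graph
        return versions[key](*values, **kwargs)

    print(jit_call(lambda x: x + x, 3))    # 6, compiled for int
    print(jit_call(lambda x: x + x, 3.0))  # 6.0, a second cached version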

onnx_array_api/npx/npx_tensors.py (+3 -1)

@@ -72,8 +72,10 @@ def _getitem_impl_var(obj, index, method_name=None):
         return meth(obj, index)
 
     @staticmethod
-    def _astype_impl(x, dtype, method_name=None):
+    def _astype_impl(x, dtype: int = None, method_name=None):
         # avoids circular imports.
+        if dtype is None:
+            raise ValueError("dtype cannot be None.")
         from .npx_var import Var
 
         if not isinstance(x, Var):
