@@ -92,6 +92,8 @@ def make_key(*values, **kwargs):
92
92
res .append (v )
93
93
elif isinstance (v , slice ):
94
94
res .append (("slice" , v .start , v .stop , v .step ))
95
+ elif isinstance (v , type ):
96
+ res .append (("type" , v .__name__ ))
95
97
elif isinstance (v , tuple ):
96
98
subkey = []
97
99
for sk in v :
@@ -139,16 +141,23 @@ def to_jit(self, *values, **kwargs):
139
141
The onnx graph built by the function defines the input
140
142
types and the expected number of dimensions.
141
143
"""
144
+ annotations = self .f .__annotations__
145
+ annot_values = list (annotations .values ())
142
146
constraints = {
143
147
f"x{ i } " : v .tensor_type_dims
144
148
for i , v in enumerate (values )
145
149
if isinstance (v , (EagerTensor , JitTensor ))
150
+ and (i >= len (annot_values ) or issubclass (annot_values [i ], TensorType ))
146
151
}
147
152
148
153
if self .output_types is not None :
149
154
constraints .update (self .output_types )
150
155
151
- inputs = [Input (f"x{ i } " ) for i in range (len (values ))]
156
+ inputs = [Input (f"x{ i } " ) for i in range (len (values )) if f"x{ i } " in constraints ]
157
+ if len (inputs ) < len (values ):
158
+ # An attribute is not named in the numpy API
159
+ # but is the ONNX definition.
160
+ raise NotImplementedError ()
152
161
var = self .f (* inputs , ** kwargs )
153
162
154
163
onx = var .to_onnx (
0 commit comments