2
2
import unittest
3
3
import numpy as np
4
4
from onnx import TensorProto
5
- from onnx_array_api .ext_test_case import ExtTestCase
5
+ from onnx_array_api .ext_test_case import ExtTestCase , ignore_warnings
6
6
from onnx_array_api .array_api import onnx_numpy as xp
7
7
from onnx_array_api .npx .npx_types import DType
8
8
from onnx_array_api .npx .npx_numpy_tensors import EagerNumpyTensor as EagerTensor
9
+ from onnx_array_api .npx .npx_functions import linspace as linspace_inline
10
+ from onnx_array_api .npx .npx_types import Float64 , Int64
11
+ from onnx_array_api .npx .npx_var import Input
12
+ from onnx_array_api .reference import ExtendedReferenceEvaluator
9
13
10
14
11
15
class TestOnnxNumpy (ExtTestCase ):
@@ -22,6 +26,7 @@ def test_zeros(self):
22
26
a = xp .absolute (mat )
23
27
self .assertEqualArray (np .absolute (mat .numpy ()), a .numpy ())
24
28
29
+ @ignore_warnings (DeprecationWarning )
25
30
def test_arange_default (self ):
26
31
a = EagerTensor (np .array ([0 ], dtype = np .int64 ))
27
32
b = EagerTensor (np .array ([2 ], dtype = np .int64 ))
@@ -30,6 +35,7 @@ def test_arange_default(self):
30
35
self .assertEqual (matnp .shape , (2 ,))
31
36
self .assertEqualArray (matnp , np .arange (0 , 2 ).astype (np .int64 ))
32
37
38
+ @ignore_warnings (DeprecationWarning )
33
39
def test_arange_step (self ):
34
40
a = EagerTensor (np .array ([4 ], dtype = np .int64 ))
35
41
s = EagerTensor (np .array ([2 ], dtype = np .int64 ))
@@ -78,6 +84,7 @@ def test_full_bool(self):
78
84
self .assertNotEmpty (matnp [0 , 0 ])
79
85
self .assertEqualArray (matnp , np .full ((4 , 5 ), False ))
80
86
87
+ @ignore_warnings (DeprecationWarning )
81
88
def test_arange_int00a (self ):
82
89
a = EagerTensor (np .array ([0 ], dtype = np .int64 ))
83
90
b = EagerTensor (np .array ([0 ], dtype = np .int64 ))
@@ -89,6 +96,7 @@ def test_arange_int00a(self):
89
96
expected = expected .astype (np .int64 )
90
97
self .assertEqualArray (matnp , expected )
91
98
99
+ @ignore_warnings (DeprecationWarning )
92
100
def test_arange_int00 (self ):
93
101
mat = xp .arange (0 , 0 )
94
102
matnp = mat .numpy ()
@@ -160,10 +168,94 @@ def test_eye_k(self):
160
168
got = xp .eye (nr , k = 1 )
161
169
self .assertEqualArray (expected , got .numpy ())
162
170
171
+ def test_linspace_int (self ):
172
+ a = EagerTensor (np .array ([0 ], dtype = np .int64 ))
173
+ b = EagerTensor (np .array ([6 ], dtype = np .int64 ))
174
+ c = EagerTensor (np .array (3 , dtype = np .int64 ))
175
+ mat = xp .linspace (a , b , c )
176
+ matnp = mat .numpy ()
177
+ expected = np .linspace (a .numpy (), b .numpy (), c .numpy ()).astype (np .int64 )
178
+ self .assertEqualArray (expected , matnp )
179
+
180
+ def test_linspace_int5 (self ):
181
+ a = EagerTensor (np .array ([0 ], dtype = np .int64 ))
182
+ b = EagerTensor (np .array ([5 ], dtype = np .int64 ))
183
+ c = EagerTensor (np .array (3 , dtype = np .int64 ))
184
+ mat = xp .linspace (a , b , c )
185
+ matnp = mat .numpy ()
186
+ expected = np .linspace (a .numpy (), b .numpy (), c .numpy ()).astype (np .int64 )
187
+ self .assertEqualArray (expected , matnp )
188
+
189
+ def test_linspace_float (self ):
190
+ a = EagerTensor (np .array ([0.5 ], dtype = np .float64 ))
191
+ b = EagerTensor (np .array ([5.5 ], dtype = np .float64 ))
192
+ c = EagerTensor (np .array (2 , dtype = np .int64 ))
193
+ mat = xp .linspace (a , b , c )
194
+ matnp = mat .numpy ()
195
+ expected = np .linspace (a .numpy (), b .numpy (), c .numpy ())
196
+ self .assertEqualArray (expected , matnp )
197
+
198
+ def test_linspace_float_noendpoint (self ):
199
+ a = EagerTensor (np .array ([0.5 ], dtype = np .float64 ))
200
+ b = EagerTensor (np .array ([5.5 ], dtype = np .float64 ))
201
+ c = EagerTensor (np .array (2 , dtype = np .int64 ))
202
+ mat = xp .linspace (a , b , c , endpoint = 0 )
203
+ matnp = mat .numpy ()
204
+ expected = np .linspace (a .numpy (), b .numpy (), c .numpy (), endpoint = 0 )
205
+ self .assertEqualArray (expected , matnp )
206
+
207
+ @ignore_warnings ((RuntimeWarning , DeprecationWarning )) # division by zero
208
+ def test_linspace_zero (self ):
209
+ expected = np .linspace (0.0 , 0.0 , 0 , endpoint = False )
210
+ mat = xp .linspace (0.0 , 0.0 , 0 , endpoint = False )
211
+ matnp = mat .numpy ()
212
+ self .assertEqualArray (expected , matnp )
213
+
214
+ @ignore_warnings ((RuntimeWarning , DeprecationWarning )) # division by zero
215
+ def test_linspace_zero_one (self ):
216
+ expected = np .linspace (0.0 , 0.0 , 1 , endpoint = True )
217
+
218
+ f = linspace_inline (Input ("start" ), Input ("stop" ), Input ("num" ))
219
+ onx = f .to_onnx (
220
+ constraints = {
221
+ "start" : Float64 [None ],
222
+ "stop" : Float64 [None ],
223
+ "num" : Int64 [None ],
224
+ (0 , False ): Float64 [None ],
225
+ }
226
+ )
227
+ ref = ExtendedReferenceEvaluator (onx )
228
+ got = ref .run (
229
+ None ,
230
+ {
231
+ "start" : np .array (0 , dtype = np .float64 ),
232
+ "stop" : np .array (0 , dtype = np .float64 ),
233
+ "num" : np .array (1 , dtype = np .int64 ),
234
+ },
235
+ )
236
+ self .assertEqualArray (expected , got [0 ])
237
+
238
+ mat = xp .linspace (0.0 , 0.0 , 1 , endpoint = True )
239
+ matnp = mat .numpy ()
240
+
241
+ self .assertEqualArray (expected , matnp )
242
+
243
+ def test_slice_minus_one (self ):
244
+ g = EagerTensor (np .array ([0.0 ]))
245
+ expected = g .numpy ()[:- 1 ]
246
+ got = g [:- 1 ]
247
+ self .assertEqualArray (expected , got .numpy ())
248
+
249
+ def test_linspace_bug1 (self ):
250
+ expected = np .linspace (16777217.0 , 0.0 , 1 )
251
+ mat = xp .linspace (16777217.0 , 0.0 , 1 )
252
+ matnp = mat .numpy ()
253
+ self .assertEqualArray (expected , matnp )
254
+
163
255
164
256
if __name__ == "__main__" :
165
257
# import logging
166
258
167
259
# logging.basicConfig(level=logging.DEBUG)
168
- TestOnnxNumpy ().test_eye ()
260
+ TestOnnxNumpy ().test_linspace_float_noendpoint ()
169
261
unittest .main (verbosity = 2 )
0 commit comments