|
1 | 1 | import unittest
|
2 | 2 | from contextlib import redirect_stdout
|
3 | 3 | from io import StringIO
|
4 |
| - |
5 | 4 | import numpy as np
|
6 | 5 | from onnx.defs import onnx_opset_version
|
7 | 6 | from onnx.reference import ReferenceEvaluator
|
8 | 7 | from onnxruntime import InferenceSession
|
9 |
| - |
10 | 8 | from onnx_array_api.ext_test_case import ExtTestCase
|
11 | 9 | from onnx_array_api.npx import eager_onnx, jit_onnx
|
12 | 10 | from onnx_array_api.npx.npx_functions import absolute as absolute_inline
|
13 | 11 | from onnx_array_api.npx.npx_functions import cdist as cdist_inline
|
14 | 12 | from onnx_array_api.npx.npx_functions_test import absolute
|
15 |
| -from onnx_array_api.npx.npx_types import Float32, Float64 |
| 13 | +from onnx_array_api.npx.npx_functions import copy as copy_inline |
| 14 | +from onnx_array_api.npx.npx_types import Float32, Float64, DType |
16 | 15 | from onnx_array_api.npx.npx_var import Input
|
17 | 16 | from onnx_array_api.ort.ort_tensors import EagerOrtTensor, JitOrtTensor, OrtTensor
|
18 | 17 |
|
@@ -193,6 +192,49 @@ def impl(xa, xb):
|
193 | 192 | if len(pieces) > 2:
|
194 | 193 | raise AssertionError(f"Function is not using argument:\n{onx}")
|
195 | 194 |
|
| 195 | + def test_astype(self): |
| 196 | + f = absolute_inline(copy_inline(Input("A")).astype(np.float32)) |
| 197 | + onx = f.to_onnx(constraints={"A": Float64[None]}) |
| 198 | + x = np.array([[-5, 6]], dtype=np.float64) |
| 199 | + z = np.abs(x.astype(np.float32)) |
| 200 | + ref = InferenceSession( |
| 201 | + onx.SerializeToString(), providers=["CPUExecutionProvider"] |
| 202 | + ) |
| 203 | + got = ref.run(None, {"A": x}) |
| 204 | + self.assertEqualArray(z, got[0]) |
| 205 | + |
| 206 | + def test_astype0(self): |
| 207 | + f = absolute_inline(copy_inline(Input("A")).astype(np.float32)) |
| 208 | + onx = f.to_onnx(constraints={"A": Float64[None]}) |
| 209 | + x = np.array(-5, dtype=np.float64) |
| 210 | + z = np.abs(x.astype(np.float32)) |
| 211 | + ref = InferenceSession( |
| 212 | + onx.SerializeToString(), providers=["CPUExecutionProvider"] |
| 213 | + ) |
| 214 | + got = ref.run(None, {"A": x}) |
| 215 | + self.assertEqualArray(z, got[0]) |
| 216 | + |
| 217 | + def test_eager_ort_cast(self): |
| 218 | + def impl(A): |
| 219 | + return A.astype(DType("FLOAT")) |
| 220 | + |
| 221 | + e = eager_onnx(impl) |
| 222 | + self.assertEqual(len(e.versions), 0) |
| 223 | + |
| 224 | + # Float64 |
| 225 | + x = np.array([0, 1, -2], dtype=np.float64) |
| 226 | + z = x.astype(np.float32) |
| 227 | + res = e(x) |
| 228 | + self.assertEqualArray(z, res) |
| 229 | + self.assertEqual(res.dtype, np.float32) |
| 230 | + |
| 231 | + # again |
| 232 | + x = np.array(1, dtype=np.float64) |
| 233 | + z = x.astype(np.float32) |
| 234 | + res = e(x) |
| 235 | + self.assertEqualArray(z, res) |
| 236 | + self.assertEqual(res.dtype, np.float32) |
| 237 | + |
196 | 238 |
|
197 | 239 | if __name__ == "__main__":
|
198 | 240 | # TestNpx().test_eager_numpy()
|
|
0 commit comments