Skip to content

Commit d875d0d

Browse files
committed
unit test documentation examples
1 parent 4f5b006 commit d875d0d

File tree

5 files changed

+79
-3
lines changed

5 files changed

+79
-3
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ _doc/auto_examples/*
1212
_doc/examples/plot_*.png
1313
_doc/_static/require.js
1414
_doc/_static/viz.js
15+
_unittests/ut__main/*.png

README.rst

+11
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,17 @@ well as to execute it.
5959
6060
print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
6161
62+
::
63+
64+
[0.042]
65+
opset: domain='' version=18
66+
input: name='x0' type=dtype('float32') shape=['', '']
67+
input: name='x1' type=dtype('float32') shape=['', '']
68+
Sub(x0, x1) -> r__0
69+
Abs(r__0) -> r__1
70+
ReduceSum(r__1, keepdims=0) -> r__2
71+
output: name='r__2' type=dtype('float32') shape=None
72+
6273
It supports eager mode as well:
6374

6475
.. code-block:: python

_doc/examples/plot_first_example.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
First examples with onnx-array-api
66
==================================
77
8-
This demonstrates easy case with :epkg:`onnx-array-api`.
8+
This demonstrates an easy case with :epkg:`onnx-array-api`.
9+
It shows how a function can be easily converted into
10+
ONNX.
911
1012
A loss function from numpy to ONNX
1113
++++++++++++++++++++++++++++++++++
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import unittest
2+
import os
3+
import sys
4+
import importlib
5+
import subprocess
6+
import time
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
9+
10+
def import_source(module_file_path, module_name):
11+
if not os.path.exists(module_file_path):
12+
raise FileNotFoundError(module_file_path)
13+
module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
14+
if module_spec is None:
15+
raise FileNotFoundError(
16+
"Unable to find '{}' in '{}'.".format(module_name, module_file_path)
17+
)
18+
module = importlib.util.module_from_spec(module_spec)
19+
return module_spec.loader.exec_module(module)
20+
21+
22+
class TestDocumentationExamples(ExtTestCase):
23+
def test_documentation_examples(self):
24+
this = os.path.abspath(os.path.dirname(__file__))
25+
fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "examples"))
26+
found = os.listdir(fold)
27+
tested = 0
28+
for name in found:
29+
if name.startswith("plot_") and name.endswith(".py"):
30+
perf = time.perf_counter()
31+
try:
32+
mod = import_source(fold, os.path.splitext(name)[0])
33+
assert mod is not None
34+
except FileNotFoundError:
35+
# try another way
36+
cmds = [sys.executable, "-u", os.path.join(fold, name)]
37+
p = subprocess.Popen(
38+
cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE
39+
)
40+
res = p.communicate()
41+
out, err = res
42+
st = err.decode("ascii", errors="ignore")
43+
if len(st) > 0 and "Traceback" in st:
44+
if '"dot" not found in path.' in st:
45+
# dot not installed, this part
46+
# is tested in onnx framework
47+
print(f"failed: {name!r} due to missing dot.")
48+
continue
49+
raise AssertionError(
50+
"Example '{}' (cmd: {} - exec_prefix='{}') "
51+
"failed due to\n{}"
52+
"".format(name, cmds, sys.exec_prefix, st)
53+
)
54+
dt = time.perf_counter() - perf
55+
print(f"{dt:.3f}: run {name!r}")
56+
tested += 1
57+
if tested == 0:
58+
raise AssertionError("No example was tested.")
59+
60+
61+
if __name__ == "__main__":
62+
unittest.main()

_unittests/ut_npx/test_npx.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2418,7 +2418,7 @@ def l2_loss(x, y):
24182418

24192419
def test_eager_cst_index(self):
24202420
def l1_loss(x, y):
2421-
return absolute(x - y).sum()
2421+
return absolute_inline(x - y).sum()
24222422

24232423
def l2_loss(x, y):
24242424
return ((x - y) ** 2).sum()
@@ -2441,5 +2441,5 @@ def myloss(x, y):
24412441

24422442

24432443
if __name__ == "__main__":
2444-
# TestNpx().test_eager_numpy()
2444+
TestNpx().test_eager_cst_index()
24452445
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)