-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_first_example.py
135 lines (94 loc) · 3.67 KB
/
plot_first_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""
.. _l-onnx-array-first-api-example:
First examples with onnx-array-api
==================================
This demonstrates an easy case with :epkg:`onnx-array-api`.
It shows how a function can be easily converted into
ONNX.
A loss function from numpy to ONNX
++++++++++++++++++++++++++++++++++
The first example takes a loss function and converts it into ONNX.
"""
import numpy as np
from onnx_array_api.npx import absolute, jit_onnx
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
################################
# The function looks like a numpy function.
def l1_loss(x, y):
    # L1 loss: sum of the absolute element-wise differences.
    diff = x - y
    return absolute(diff).sum()
################################
# The function needs to be converted into ONNX with function jit_onnx.
# jitted_l1_loss is a wrapper. It intercepts all calls to l1_loss.
# When it happens, it checks the input types and creates the
# corresponding ONNX graph.
jitted_l1_loss = jit_onnx(l1_loss)
################################
# First execution and conversion to ONNX.
# The wrapper caches the created onnx graph.
# It reuses it if the input types and the number of dimensions are the same.
# It creates a new one otherwise and keeps the old one.
x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
res = jitted_l1_loss(x, y)
print(res)
####################################
# The ONNX graph can be accessed the following way.
print(onnx_simple_text_plot(jitted_l1_loss.get_onnx()))
################################
# We can also define a more complex loss by computing L1 loss on
# the first column and L2 loss on the second one.
def l1_loss(x, y):
    # Sum of absolute differences (L1 norm of the error).
    error = x - y
    return absolute(error).sum()
def l2_loss(x, y):
    # Sum of squared differences (squared L2 norm of the error).
    squared_error = (x - y) ** 2
    return squared_error.sum()
def myloss(x, y):
    # Combine L1 on the first column with L2 on the second column.
    left = l1_loss(x[:, 0], y[:, 0])
    right = l2_loss(x[:, 1], y[:, 1])
    return left + right
# Wrap the combined loss so calls are traced into an ONNX graph.
jitted_myloss = jit_onnx(myloss)
# Same sample inputs as before; the first call triggers the conversion.
x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
res = jitted_myloss(x, y)
print(res)
# Display the ONNX graph produced for the combined loss.
print(onnx_simple_text_plot(jitted_myloss.get_onnx()))
############################
# Eager mode
# ++++++++++
import numpy as np
from onnx_array_api.npx import absolute, eager_onnx
def l1_loss(x, y):
    """
    The intermediate result inherits from
    :class:`EagerTensor <onnx_array_api.npx.npx_tensors.EagerTensor>`
    and must be converted to numpy before it can be displayed.
    """
    loss = absolute(x - y).sum()
    print(f"l1_loss={loss.numpy()}")
    return loss
def l2_loss(x, y):
    # Squared-error sum, printed eagerly as the computation runs.
    loss = ((x - y) ** 2).sum()
    print(f"l2_loss={loss.numpy()}")
    return loss
def myloss(x, y):
    # First column contributes L1, second column contributes L2.
    part_l1 = l1_loss(x[:, 0], y[:, 0])
    part_l2 = l2_loss(x[:, 1], y[:, 1])
    return part_l1 + part_l2
#################################
# Eager mode is enabled by function :func:`eager_onnx
# <onnx_array_api.npx.npx_jit_eager.eager_onnx>`.
# It intercepts all calls to `my_loss`. On the first call,
# it replaces a numpy array by a tensor corresponding to the
# selected runtime, here numpy as well through
# :class:`EagerNumpyTensor
# <onnx_array_api.npx.npx_numpy_tensors.EagerNumpyTensor>`.
eager_myloss = eager_onnx(myloss)
x = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
y = np.array([[0.11, 0.22], [0.33, 0.44]], dtype=np.float32)
#################################
# First execution and conversion to ONNX.
# The wrapper caches many ONNX graphs, one per elementary
# operator (`+`, `-`, `/`, `*`, ...), reduce function, or
# any other function from the API.
# It reuses them if the input types and the number of dimensions are the same.
# It creates new ones otherwise and keeps the old ones.
res = eager_myloss(x, y)
print(res)
################################
# There is no ONNX graph to show. Every operation
# is converted into small ONNX graphs.