
Commit 35f7e88

Documentation, add one more API (#52)
1 parent d7d4e2e commit 35f7e88

File tree

2 files changed: +84 -6 lines changed

_doc/conf.py

+1
@@ -120,6 +120,7 @@
     "inner API": "https://onnx.ai/onnx/intro/python.html",
     "JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation",
     "onnx": "https://onnx.ai/onnx/",
+    "onnx-graphsurgeon": "https://docs.nvidia.com/deeplearning/tensorrt/onnx-graphsurgeon/docs/index.html",
     "onnx.helper": "https://onnx.ai/onnx/api/helper.html",
     "ONNX": "https://onnx.ai/",
     "ONNX Operators": "https://onnx.ai/onnx/operators/",

_doc/tutorial/onnx_api.rst

+83 -6
@@ -14,7 +14,7 @@ onnx syntax. :epkg:`scikit-learn` is implemented with :epkg:`numpy` and there
 is no converter from numpy to onnx. Sometimes, it is needed to extend
 an existing onnx model or to merge models coming from different packages.
 Sometimes, they are just not available, only onnx is.
-Let's see how it looks like a very simply example.
+Let's see what it looks like with a very simple example.
 
 Euclidian distance
 ==================
@@ -263,12 +263,10 @@ A couple of examples.
 
     model = MyModel()
     kwargs = {"bias": 3.}
-    args = (torch.randn(2, 2, 2),)
+    inputs = (torch.randn(2, 2, 2),)
 
-    export_output = torch.onnx.dynamo_export(
-        model,
-        *args,
-        **kwargs).save("my_simple_model.onnx")
+    export_output = torch.onnx.dynamo_export(model, inputs, **kwargs)
+    export_output.save("my_simple_model.onnx")
 
 .. code-block:: python
 
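For context, a self-contained sketch of the calling convention used in the new lines above; the `MyModel` definition below is made up for illustration, and it assumes a PyTorch version that provides `torch.onnx.dynamo_export`.

.. code-block:: python

    import torch

    # Hypothetical toy module, only meant to illustrate the export call.
    class MyModel(torch.nn.Module):
        def forward(self, x, bias=None):
            return torch.sigmoid(x) + bias

    model = MyModel()
    x = torch.randn(2, 2, 2)

    # Positional tensors and keyword arguments are forwarded to the model;
    # the returned object is then saved to disk as an onnx file.
    export_output = torch.onnx.dynamo_export(model, x, bias=3.0)
    export_output.save("my_simple_model.onnx")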
@@ -462,6 +460,7 @@ onnxblocks
 `onnxblocks <https://onnxruntime.ai/docs/api/python/on_device_training/training_artifacts.html#prepare-for-training>`_
 was introduced in onnxruntime to define custom losses in order to train
 a model with :epkg:`onnxruntime-training`. It is mostly used for this purpose.
+The syntax is similar to pytorch's.
 
 .. code-block:: python
 
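The documented example is only partially visible in the next hunk. For orientation, here is a minimal, hedged sketch of the artifact generation described at the link above; it assumes :epkg:`onnxruntime-training` is installed, and the file name and initializer names passed to `requires_grad` are placeholders for whatever model is actually loaded.

.. code-block:: python

    import onnx
    from onnxruntime.training import artifacts

    # Hypothetical inference model whose trainable initializers
    # are named fc1.weight and fc1.bias.
    model = onnx.load("model.onnx")

    # Writes the training artifacts (training model, eval model,
    # optimizer model, checkpoint) into the given directory.
    artifacts.generate_artifacts(
        model,
        requires_grad=["fc1.weight", "fc1.bias"],
        frozen_params=[],
        loss=artifacts.LossType.CrossEntropyLoss,
        optimizer=artifacts.OptimType.AdamW,
        artifact_directory=".",
    )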
@@ -507,6 +506,84 @@ a model with :epkg:`onnxruntime-training`. It is mostly used for this purpose.
     # Successful completion of the above call will generate 4 files in the current working directory,
     # one for each of the artifacts mentioned above (training_model.onnx, eval_model.onnx, checkpoint, op)
 
+ONNX GraphSurgeon
++++++++++++++++++
+
+:epkg:`onnx-graphsurgeon` implements the main class `Graph`, which provides
+all the necessary methods to add nodes or import existing onnx files.
+The following example is taken from `onnx-graphsurgeon/examples
+<https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon/examples>`_.
+The first part generates a graph.
+
+.. code-block:: python
+
+    import onnx_graphsurgeon as gs
+    import numpy as np
+    import onnx
+
+    # Computes Y = x0 + (a * x1 + b)
+
+    shape = (1, 3, 224, 224)
+    # Inputs
+    x0 = gs.Variable(name="x0", dtype=np.float32, shape=shape)
+    x1 = gs.Variable(name="x1", dtype=np.float32, shape=shape)
+
+    # Intermediate tensors
+    a = gs.Constant("a", values=np.ones(shape=shape, dtype=np.float32))
+    b = gs.Constant("b", values=np.ones(shape=shape, dtype=np.float32))
+    mul_out = gs.Variable(name="mul_out")
+    add_out = gs.Variable(name="add_out")
+
+    # Outputs
+    Y = gs.Variable(name="Y", dtype=np.float32, shape=shape)
+
+    nodes = [
+        # mul_out = a * x1
+        gs.Node(op="Mul", inputs=[a, x1], outputs=[mul_out]),
+        # add_out = mul_out + b
+        gs.Node(op="Add", inputs=[mul_out, b], outputs=[add_out]),
+        # Y = x0 + add_out
+        gs.Node(op="Add", inputs=[x0, add_out], outputs=[Y]),
+    ]
+
+    graph = gs.Graph(nodes=nodes, inputs=[x0, x1], outputs=[Y])
+    onnx.save(gs.export_onnx(graph), "model.onnx")
+
+The second part modifies it.
+
+.. code-block:: python
+
+    import onnx_graphsurgeon as gs
+    import numpy as np
+    import onnx
+
+    graph = gs.import_onnx(onnx.load("model.onnx"))
+
+    # 1. Remove the `b` input of the add node
+    first_add = [node for node in graph.nodes if node.op == "Add"][0]
+    first_add.inputs = [inp for inp in first_add.inputs if inp.name != "b"]
+
+    # 2. Change the Add to a LeakyRelu
+    first_add.op = "LeakyRelu"
+    first_add.attrs["alpha"] = 0.02
+
+    # 3. Add an identity after the add node
+    identity_out = gs.Variable("identity_out", dtype=np.float32)
+    identity = gs.Node(op="Identity", inputs=first_add.outputs, outputs=[identity_out])
+    graph.nodes.append(identity)
+
+    # 4. Modify the graph output to be the identity output
+    graph.outputs = [identity_out]
+
+    # 5. Remove unused nodes/tensors, and topologically sort the graph
+    # ONNX requires nodes to be topologically sorted to be considered valid.
+    # Therefore, you should only need to sort the graph when you have added new nodes out-of-order.
+    # In this case, the identity node is already in the correct spot (it is the last node,
+    # and was appended to the end of the list), but to be on the safer side, we can sort anyway.
+    graph.cleanup().toposort()
+
+    onnx.save(gs.export_onnx(graph), "modified.onnx")
+
 numpy API for onnx
 ++++++++++++++++++
 
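As a side note, one possible way to sanity-check the two snippets added in this hunk, assuming both were run in the working directory, is to reload the rewritten model with onnx and inspect it; this is a suggestion, not part of the referenced example.

.. code-block:: python

    import onnx

    # Verifies that the rewritten graph is still a valid onnx model
    # and shows which operators survived graph.cleanup().
    model = onnx.load("modified.onnx")
    onnx.checker.check_model(model)
    print([node.op_type for node in model.graph.node])
    print([i.name for i in model.graph.input], "->", [o.name for o in model.graph.output])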
