inline_function.py
import torch
from torch.fx import Proxy, symbolic_trace
from torch.fx.node import map_arg
'''
How to Inline a Function Into an Existing Graph

One reason you might want to inline a function is to get around FX's
default tracing behavior. For example, unless you've defined a custom
Tracer, the out-of-the-box implementation of ``symbolic_trace`` causes
references to ``torch.nn`` module instances to appear as
``call_module`` calls rather than being traced through. Let's say this
behavior is almost what you need; the only problem is that there's a
single module call that you want to replace with an inlined trace of
the function. Creating a custom Tracer for that would be overkill.
Instead, you can accomplish this using Proxies.

The following code demonstrates how to trace a module and inline it
into an existing Graph using Proxy. We'll trace our Graph, then iterate
through its Nodes until we find the right place to swap out the
``call_module`` Node with an inlined trace. At that point, we'll create
Proxies from the Node's args and kwargs. Next, we'll call the function
we want to replace with those Proxies, which will, in essence, "trace"
that function. Finally, we'll insert the result of that call into our
Graph. (This last step will automatically inline the function.)
'''
# Sample module
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x) + 1.0
# Symbolically trace an instance of `M`. After tracing, `self.relu` is
# represented as a `call_module` Node. The full operation in the
# generated `forward` function's code will appear as `self.relu(x)`
m = symbolic_trace(M())
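# For reference, the generated `forward` looks roughly like the
# following (a sketch; the exact names and cleanup code FX emits may
# differ):
#
#     def forward(self, x):
#         relu = self.relu(x)
#         add = relu + 1.0
#         return add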
# Insert nodes from the ReLU graph in place of the original call to
# `self.relu`

# Create a graph-appending tracer pointing to the original graph
tracer = torch.fx.proxy.GraphAppendingTracer(m.graph)
for node in m.graph.nodes:
    # Find the `call_module` Node in `m` that corresponds to `self.relu`.
    # This is the Node we want to swap out for an inlined version of the
    # same call
    if (node.op, node.target) == ("call_module", "relu"):
        with m.graph.inserting_before(node):
            # Create a Proxy from each Node in the current Node's
            # args/kwargs
            proxy_args = map_arg(node.args, lambda n: Proxy(n, tracer))
            proxy_kwargs = map_arg(node.kwargs, lambda n: Proxy(n, tracer))

            # Call `m.relu` with the newly-created Proxy arguments.
            # `m.relu` is the generic version of the function; by
            # calling it with Proxies created from Nodes in `m`, we're
            # emitting Nodes that reference existing values in the IR.
            # The result of this call is another Proxy, which we can
            # hook into our existing Graph to complete the function
            # inlining.
            proxy_output = m.relu(*proxy_args, **proxy_kwargs)

            # Replace the relu `call_module` Node with the inlined
            # version of the function
            node.replace_all_uses_with(proxy_output.node)

            # Make sure that the old relu Node is erased
            m.graph.erase_node(node)
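
# As a final sanity check (a minimal sketch, not part of the original
# example above), we can lint the rewritten graph, regenerate the
# module's Python code, and confirm that the inlined module still
# computes relu(x) + 1.0.
m.graph.lint()
m.recompile()

x = torch.randn(4)
torch.testing.assert_close(m(x), torch.relu(x) + 1.0)
# In the regenerated code, the relu should now show up as a traced-through
# `call_function` rather than a `call_module`
print(m.code)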