from typing import Sequence

import torch
import torch.fx as fx
import torch_xla  # imported for its side effects: registers `torch.ops.xla` ops
from functorch.compile import aot_function, make_boxed_func  # type: ignore
from torch.utils._pytree import tree_iter

from .remat_all import remat_all_partition_fn


@torch.library.custom_op("xla::offload_name", mutates_args=())
def offload_name(t: torch.Tensor, name: str) -> torch.Tensor:
"""
`offload_name` is an identity function that associates the input
tensor with `name`. It is primarily useful in conjunction with
`remat_all_and_offload_these_inputs`, which will rematerialize
intermediate activations and also offload inputs with the specified
names to host memory, moving them back during the backward pass.
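
  Example (sketch; the surrounding model code is hypothetical):

    def forward(self, x):
      h = self.up_proj(x)
      # Tag this activation so it can later be offloaded to host by name.
      h = offload_name(h, "up_proj_out")
      return self.down_proj(h)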
"""
if t is None:
return None
return t.clone()


@offload_name.register_fake
def _(t: torch.Tensor, name: str) -> torch.Tensor:
  if t is None:
    return None
  return torch.empty_like(t)


def offload_name_backward(ctx, grad):
  return grad, None


offload_name.register_autograd(offload_name_backward)


def remat_all_and_offload_these_inputs(
    joint_module: fx.GraphModule,
    _joint_inputs,
    *,
    num_fwd_outputs,
    names_to_offload: Sequence[str],
):
"""
`remat_all_and_offload_these_inputs` will rematerialize (recompute) all
intermediate activations in `joint_module`, and offload inputs with the
specified names to host memory, moving them back during the backward pass.
It transforms the joint graph into separate forward and backward graphs.
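
  Example (sketch; assumes this function is used as the `partition_fn` of
  AOTAutograd, and that "up_proj_out" was tagged via `offload_name`):

    from functools import partial
    from functorch.compile import aot_function

    fn = aot_function(
        module,
        fw_compiler=my_compiler,  # hypothetical compiler callback
        partition_fn=partial(
            remat_all_and_offload_these_inputs,
            names_to_offload=["up_proj_out"]))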
"""
  input_device = next(iter(tree_iter(_joint_inputs))).device
  fwd, bwd = remat_all_partition_fn(
      joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
  with torch.device(input_device):
    fw_example_args = make_arguments(fwd)
    bw_example_args = make_arguments(bwd)

  fw_name_in_output_indices = get_name_in_output_indices(fwd)
  bw_name_in_input_names = get_name_in_input_names(bwd)
  for name in names_to_offload:
    assert name in fw_name_in_output_indices, (
        f"{name} is not among the outputs of the forward graph")
    assert name in bw_name_in_input_names, (
        f"{name} is not among the inputs of the backward graph")

  with torch.no_grad():

    def forward(**kwargs):
      out = fwd(**kwargs)
      indices_to_offload = {
          fw_name_in_output_indices[name] for name in names_to_offload
      }
      # Send each named tensor to host memory as the forward graph produces it.
      return tuple(
          torch.ops.xla.place_to_host(v)  # type: ignore
          if i in indices_to_offload else v for i, v in enumerate(out))

    def backward(**kwargs):
      arguments_to_move_back = {
          bw_name_in_input_names[name] for name in names_to_offload
      }
      # Move offloaded tensors back to the device before running backward.
      kwargs = {
          k: torch.ops.xla.place_to_device(v)  # type: ignore
          if k in arguments_to_move_back else v for k, v in kwargs.items()
      }
      return bwd(**kwargs)

    # Use AOTAutograd to retrace forward and backward, thus incorporating
    # the offloading ops.
    graph = [None]

    def get_graph(g, _):
      graph[0] = g
      return make_boxed_func(g)

    _ = aot_function(forward, fw_compiler=get_graph)(**fw_example_args)
    aot_forward = graph[0]
    _ = aot_function(backward, fw_compiler=get_graph)(**bw_example_args)
    aot_backward = graph[0]

  return aot_forward, aot_backward


def make_arguments(gm: fx.GraphModule):
  """
  Given a graph module, `make_arguments` returns a dictionary of example
  inputs that can be used as keyword arguments to call the graph module.
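
  Example (sketch; `fwd` stands for a graph module produced by a
  partitioner such as `remat_all_partition_fn`):

    with torch.device("cpu"):
      example_args = make_arguments(fwd)
    outputs = fwd(**example_args)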
"""
example_args = {}
for node in gm.graph.nodes:
if node.op != 'placeholder':
continue
if 'tensor_meta' in node.meta:
tensor_meta = node.meta['tensor_meta']
tensor = torch.zeros(
tensor_meta.shape,
dtype=tensor_meta.dtype,
requires_grad=tensor_meta.requires_grad)
example_args[node.name] = tensor
return example_args


def get_named_nodes(gm: torch.fx.GraphModule):
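  """Maps each tensor node tagged via `offload_name` in `gm` to its name."""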
  named_nodes = {}
  for node in gm.graph.nodes:
    if (node.op == "call_function" and hasattr(node.target, "name") and
        node.target.name() == offload_name._qualname):  # type: ignore
      named_nodes[node.args[0]] = node.args[1]
  return named_nodes


def get_name_in_output_indices(gm: torch.fx.GraphModule):
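  """Maps each offloading name to the index of the corresponding tensor
  in the outputs of `gm`."""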
  named_nodes = get_named_nodes(gm)
  name_in_output_indices = {}
  for node in gm.graph.nodes:
    if node.op == "output":
      assert len(node.args) <= 1
      if len(node.args) == 0:
        continue
      for i, arg in enumerate(next(iter(node.args))):  # type: ignore
        if arg in named_nodes:
          name_in_output_indices[named_nodes[arg]] = i
  return name_in_output_indices


def get_name_in_input_names(gm: torch.fx.GraphModule):
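  """Maps each offloading name to the placeholder (input) name of the
  corresponding tensor in `gm`."""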
  named_nodes = get_named_nodes(gm)
  name_in_input_names = {}
  for node in gm.graph.nodes:
    if node.op == "placeholder" and node in named_nodes:
      name_in_input_names[named_nodes[node]] = node.target
  return name_in_input_names