-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy pathoutput_graph.py
1844 lines (1617 loc) · 71.8 KB
/
output_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import collections
import contextlib
import copy
import functools
import itertools
import logging
import operator
import re
import sys
import traceback
import weakref
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
import sympy
import torch._guards
import torch._logging
import torch.nn
import torch.utils._pytree as pytree
from torch import fx
from torch._guards import (
Checkpointable,
GlobalContextCheckpointState,
GuardsCheckpointState,
Source,
TracingContext,
)
from torch._utils_internal import signpost_event
from torch.fx.experimental.sym_node import SymNode
from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.reference import PythonReferenceAnalysis
from torch.utils.weak import WeakTensorKeyDictionary
from . import config, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import (
create_call_function,
create_instruction,
Instruction,
unique_id,
)
from .code_context import code_context
from .codegen import PyCodegen
from .current_scope_id import enter_new_scope
from .exc import (
BackendCompilerFailed,
exceptions_allowed_to_be_fallback,
SkipFrame,
unimplemented,
unimplemented_with_warning,
)
from .guards import GuardBuilder, install_guard
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import (
AttrSource,
ConstantSource,
GlobalStateSource,
is_constant_source,
is_from_local_source,
LocalSource,
ParamBufferSource,
ShapeEnvSource,
TensorProperty,
TensorPropertySource,
)
from .utils import (
checkpoint_params,
CleanupHook,
clone_inputs,
count_calls,
counters,
dynamo_timed,
get_instruction_source_311,
get_static_address_type,
graph_break_reasons,
increment_op_count,
lazy_format_graph_code,
lazy_format_graph_tabular,
LazyString,
same,
)
from .variables.base import VariableTracker
from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
NumpyNdarrayVariable,
SymNodeVariable,
TensorVariable,
UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable
# Module-level logger, plus one artifact logger per named artifact
# ("graph", "graph_code", "graph_sizes", "trace_call") obtained from
# torch._logging for this module.
log = logging.getLogger(__name__)
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
class OutputGraphState(NamedTuple):
    """
    Immutable snapshot of the checkpointed pieces of an OutputGraph.

    Instances are built by OutputGraph.copy_graphstate() and consumed by
    OutputGraph.restore_graphstate(); the field order is part of that contract
    because the state is unpacked positionally.
    """

    input_source_to_var: Dict[Source, VariableTracker]
    tracked_fakes: List[TrackedFake]
    guard_state: GuardsCheckpointState
    nn_modules: Optional[Dict[str, torch.nn.Module]]
    register_finalizer_fns: List[Callable[[fx.GraphModule], None]]
    global_state: Optional[Dict[str, bool]]
    param_name_to_source: Optional[Dict[str, Source]]
    side_effects: SideEffects
    timestamp: int
    non_compliant_ops: Set[torch._ops.OpOverload]

    def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]:
        """
        Describe the first difference between this state and ``other``.

        Fields are visited in declaration order. ``guard_state`` and
        ``side_effects`` are compared through their own ``diff`` methods; all
        other fields are compared with ``!=``.

        Returns:
            A message describing the first mismatch, or None if nothing differs.
        """
        delegated = ("guard_state", "side_effects")
        for field_name in self._fields:
            mine = getattr(self, field_name)
            theirs = getattr(other, field_name)
            if field_name in delegated:
                # These objects know how to describe their own differences.
                nested = mine.diff(theirs)
                if nested is not None:
                    return nested
                continue
            if mine != theirs:
                return f"{prefix}{field_name} mismatch: {mine} != {theirs}"
        return None

    # Back compat .guards api
    @property
    def guards(self):
        """The dynamo guards held by the checkpointed guard state."""
        checkpointed = self.guard_state
        return checkpointed.dynamo_guards
@functools.lru_cache(maxsize=None)
def _step_logger():
    """Return the step logger built from this module's ``log``; cached after the first call."""
    step_logger = torchdynamo_logging.get_step_logger(log)
    return step_logger
@dataclass
class GraphCompileReason:
    """Records the reason an output graph was compiled, i.e. what triggered the graph break."""

    reason: str
    user_stack: List[traceback.FrameSummary]

    # True when this compile reason is the result of a graph break.
    graph_break: bool = True

    def __post_init__(self):
        # Only graph-break reasons are recorded in the global list.
        if not self.graph_break:
            return
        graph_break_reasons.append(self)
def _get_gen_rand_values_fn(random_calls):
def _gen_rand_values():
return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
return _gen_rand_values
class FakeRootModule(torch.nn.Module):
    """Stand-in root module used to satisfy the fx.GraphModule constructor."""

    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
        super().__init__()
        # Expose every registered entry as an attribute under its key.
        for attr_name in nn_modules:
            setattr(self, attr_name, nn_modules[attr_name])

    def __repr__(self):
        return "FakeRootModule(...)"
class WrapperBackend:
    """
    Wraps a compiler backend and, when ``config.verify_correctness`` is set,
    checks the backend's output against the uncompiled graph module.
    """

    def __init__(self, backend: CompilerFn):
        self.backend: CompilerFn = backend

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        """
        Compile ``gm`` with the wrapped backend.

        Returns:
            ``gm.forward`` if the backend returned None or ``gm.forward``
            itself; otherwise the backend's compiled callable.

        Raises:
            RuntimeError: if ``config.verify_correctness`` is enabled and the
                compiled callable's results differ from ``gm.forward``.
        """
        self.restore = checkpoint_params(gm)
        self.gm = gm
        # The backend gets a deep copy so self.gm stays untouched as the reference.
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, example_inputs)

        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*clone_inputs(example_inputs))
            result = self.candidate(*clone_inputs(example_inputs))

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            # Mismatch is a hard error; the previous fallback `return` after
            # this raise was unreachable and has been removed.
            raise RuntimeError(f"incorrect results of backend {self}")
        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            # Invoke the callable returned by checkpoint_params(gm) above,
            # whether verification passed or failed.
            self.restore()
# Name -> object mapping; used below as the type of a frame's local and global scopes.
Scope = Dict[str, object]
class OutputGraph(Checkpointable[OutputGraphState]):
"""
Wrapper class to hold outputs of InstructionTranslator. Mainly the
generated fx.Graph.
OutputGraph is 1:1 with a frame being processed. Each frame is associated
with some root InstructionTranslator. When user code calls a function,
we construct an InliningInstructionTranslator that continues to write into
the root InstructionTranslator's OutputGraph.
"""
def __init__(
    self,
    code_options: Dict[str, Any],
    compiler_fn: Optional[CompilerFn],
    root_tx,
    export: bool,
    export_constraints,
    frame_state,
    local_scope: Scope,
    global_scope: Scope,
    f_code,
):
    """
    Set up the output graph state for one frame.

    Args:
        code_options: code object options; copied into ``self.code_options``.
        compiler_fn: backend compiler, stored as-is (not checkpointed).
        root_tx: the root instruction translator for this frame.
        export: whether tracing is for export; forwarded to the root
            SubgraphTracer as ``export_root`` and used to allow non-fake
            inputs in the FakeTensorMode.
        export_constraints: stored on ``self.export_constraints``.
        frame_state: stored on ``self.frame_state``.
        local_scope: the frame's local scope.
        global_scope: the frame's global scope.
        f_code: code object; only ``co_name``, ``co_filename`` and
            ``co_firstlineno`` are read, into ``self.co_fields``.
    """
    super().__init__()
    self.tracers = [SubgraphTracer(self, export_root=export)]
    # Map from graph input's `Source` to its `VariableTracker` to
    # de-duplicate graph inputs by source and reuse the tracker
    self.input_source_to_var: Dict[Source, VariableTracker] = {}
    self.export = export
    self.export_constraints = export_constraints
    self.frame_state = frame_state
    self.tensor_weakref_to_sizes_strides = WeakTensorKeyDictionary()
    self.cleanup_hooks: List[Callable[[], Any]] = []
    # TODO: maybe should just pass the entire f_code in here?  Not
    # sure...
    self.co_fields = {
        "co_name": f_code.co_name,
        "co_filename": f_code.co_filename,
        "co_firstlineno": f_code.co_firstlineno,
    }

    # tracked_fakes says where any tensor that was wrapped to fake came
    # from.  It is similar to GraphArg, in that all GraphArgs will get
    # added to TrackedFakes, but TrackedFakes also contains
    # GraphArgs that got pruned, and things like Tensor attributes which
    # aren't explicit graph inputs.  Used by shape guard
    self.tracked_fakes: List[TrackedFake] = []

    # List of symbols for which we have exact bindings in the arguments
    # already
    self.bound_symbols: Set[sympy.Symbol] = set()

    shape_env = ShapeEnv(
        # Reference Cycle!
        # Share a reference to the list of TrackedFake.
        #
        # ShapeEnv needs this in order to be able to reproduce the call
        # to produce_guards at an arbitrary time point. That is because
        # TrackedFake instances may have its metadata changed throughout
        # the program execution.
        tracked_fakes=self.tracked_fakes,
        allow_scalar_outputs=config.capture_scalar_outputs,
        allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
        co_fields=self.co_fields,
    )

    # In export mode, we force the shape_env to strictly disallow any constraining
    # of the user marked dynamic dims
    fake_mode = torch._subclasses.FakeTensorMode(
        shape_env=shape_env,
        # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
        allow_non_fake_inputs=True if self.export else False,
    )
    self.tracing_context: TracingContext = TracingContext(fake_mode)
    # Requires self.tracing_context: the `guards` property reads from it.
    self.init_ambient_guards()

    # Map each tensor id to a list of sources. This is necessary because
    # tensor ids cannot be recovered from tracked fakes (in general).
    # We use this map to interpret (i.e., check for violations of) constraints,
    # specifically equality constraints, which have shared tensor ids in them.
    # This map should also be generally useful, e.g., for (de)serialization.
    self.tracked_fakes_id_to_source: Dict[
        int, List[Source]
    ] = collections.defaultdict(list)
    # Stores the full fqn of a param or buffer to the relevant source.
    self.param_name_to_source: Optional[Dict[str, Source]] = dict()
    self.side_effects = SideEffects()
    self.code_options = dict(code_options)
    self.output_instructions: List[Instruction] = []
    # used to track nodes that are added between calls of copy_graphstate
    # and restore_graphstate
    self.timestamp = 0

    # A list of register_finalizer_fns to apply to the output graph module
    self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = []

    # Not checkpointed
    self.compiler_fn: Optional[CompilerFn] = compiler_fn
    self.global_scope = global_scope
    self.local_scope = local_scope
    self.root_tx = root_tx
    # Imported here rather than at module top level.
    from torch._dynamo.symbolic_convert import InstructionTranslatorBase

    # Given a source, what are the user stacks of all locations that
    # accessed it?
    #
    # For efficiency, we only populate this:
    #   - During export, and
    #   - If the source could potentially lead to a spurious export input
    #
    # Feel free to populate this more frequently if other use-cases arise,
    # but be aware that we have to generate full stacks for each
    # recording!
    self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {}

    self._current_tx: List[InstructionTranslatorBase] = []
    self.cleanups: List[CleanupHook] = []
    self.should_exit = False
    self.random_values_var = None
    self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}
    self.torch_function_enabled = torch._C._is_torch_function_enabled()
    # Tracks if the output graph has a user defined allowed function in the
    # graph. This is used later to determine if we should fallback to eager
    # for certain exceptions. The idea is that if the user has applied
    # allow_in_graph, they would like to see the error instead of falling
    # back for backend errors.
    self.has_user_defined_allowed_in_graph = False

    # Tracks a list of called ops that were not tagged with "pt2_compliant_tag".
    # This information is useful for logging.
    self.non_compliant_ops: Set[torch._ops.OpOverload] = set({})

    # We save the global torch state here to be restored in case of graph
    # breaks. The relevant issue is seen here
    # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086
    # where inlining of a function changes the global state (because of the
    # presence of torch.no_grad) and there is a graph break.
    self.save_global_state()
# This gets its own helper function so guards DEBUG logs are more
# informative
def init_ambient_guards(self):
    """Install the guards that every frame carries regardless of what it traces."""
    # Register a SHAPE_ENV guard to make sure we setup shape guards
    # that show up in ShapeEnv
    self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

    # One global-state guard per builder, each from a fresh GlobalStateSource.
    global_state_builders = (
        GuardBuilder.DETERMINISTIC_ALGORITHMS,
        GuardBuilder.GRAD_MODE,
        GuardBuilder.DEFAULT_DEVICE,
        GuardBuilder.TORCH_FUNCTION_STATE,
        GuardBuilder.BACKEND_MATCH,
        GuardBuilder.CONFIG_HASH_MATCH,
    )
    for builder in global_state_builders:
        self.guards.add(GlobalStateSource().make_guard(builder))
def guard_has_graph_break(self):
    """Add a HAS_GRAPH_BREAK guard on the global state source."""
    graph_break_guard = GlobalStateSource().make_guard(GuardBuilder.HAS_GRAPH_BREAK)
    self.guards.add(graph_break_guard)
def add_cleanup_hook(self, fn: Callable[[], Any]):
    """Register ``fn`` to be run later by call_cleanup_hooks()."""
    hooks = self.cleanup_hooks
    hooks.append(fn)
def call_cleanup_hooks(self):
    """Run every registered cleanup hook, most recently added first, then forget them."""
    for run_hook in reversed(self.cleanup_hooks):
        run_hook()
    self.cleanup_hooks.clear()
@property
def root_tracer(self):
    """The first (outermost) tracer on the tracer stack."""
    tracer_stack = self.tracers
    return tracer_stack[0]
@property
def current_tracer(self):
    """The last (innermost) tracer on the tracer stack."""
    tracer_stack = self.tracers
    return tracer_stack[-1]
def is_root_tracer(self):
    """Return True when only the root tracer is active."""
    # Helper to tell if we are inside the higher order operator tracing.
    depth = len(self.tracers)
    return depth == 1
@property
def graph(self):
    """The graph owned by the current (innermost) tracer."""
    tracer = self.current_tracer
    return tracer.graph

# TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
@graph.setter
def graph(self, value):
    tracer = self.current_tracer
    tracer.graph = value
@property
def input_name_to_proxy(self):
    """The ``input_name_to_proxy`` attribute of the current tracer."""
    tracer = self.current_tracer
    return tracer.input_name_to_proxy
@property
def real_value_cache(self):
    """The ``real_value_cache`` attribute of the current tracer."""
    tracer = self.current_tracer
    return tracer.real_value_cache
# If you are here, and you're looking for create_graph_input,
# to avoid ambiguity, please call one of the following:
# - self.current_tracer.create_graph_input
# - self.root_tracer.create_graph_input
# See NOTE [HigherOrderOperator tracing design] for more context.
def create_proxy(self, *args, **kwargs):
    """Forward to ``create_proxy`` on the current tracer and return its result."""
    tracer = self.current_tracer
    return tracer.create_proxy(*args, **kwargs)
def create_node(self, *args, **kwargs):
    """Forward to ``create_node`` on the current tracer and return its result."""
    tracer = self.current_tracer
    return tracer.create_node(*args, **kwargs)
def remove_node(self, *args, **kwargs):
    """Forward to ``remove_node`` on the current tracer and return its result."""
    tracer = self.current_tracer
    return tracer.remove_node(*args, **kwargs)
@contextlib.contextmanager
def subtracer(self, source_target, prior_tracer):
    """
    Context manager that pushes a sub-tracer for the duration of the block.

    Args:
        source_target: forwarded to the new SubgraphTracer when one is created.
        prior_tracer: an existing tracer to reuse; if truthy its ``parent``
            must be the current tracer. If falsy, a new SubgraphTracer whose
            parent is the current tracer is created.

    Yields:
        The tracer that was pushed onto ``self.tracers``.

    Raises:
        AssertionError: if ``prior_tracer`` is given but is not a child of
            the current tracer.
    """
    new_scope_ctx = enter_new_scope()
    # Track what was actually set up, so that a failure before the scope is
    # entered or the tracer is pushed does not exit an un-entered context
    # manager or pop an unrelated tracer off the stack.
    scope_entered = False
    tracer_pushed = False
    try:
        if prior_tracer:
            # Lineage MUST stay preserved
            assert prior_tracer.parent is self.current_tracer
        new_scope_ctx.__enter__()
        scope_entered = True
        tracer = (
            prior_tracer
            if prior_tracer
            else SubgraphTracer(
                self, parent=self.current_tracer, source_target=source_target
            )
        )
        self.tracers.append(tracer)
        tracer_pushed = True
        yield tracer
    finally:
        if scope_entered:
            new_scope_ctx.__exit__(None, None, None)
        if tracer_pushed:
            self.tracers.pop()
@property
def output(self):
    """Return this object itself, so ``.output`` resolves on the OutputGraph too."""
    return self
@property
def fake_mode(self):
    """The fake tensor mode held by the tracing context."""
    ctx = self.tracing_context
    return ctx.fake_mode
@property
def shape_env(self):
    """The shape environment of the tracing context's fake mode."""
    mode = self.tracing_context.fake_mode
    return mode.shape_env
@property
def guards(self) -> torch._guards.GuardsSet:
return self.tracing_context.guards_context.dynamo_guards
@property
def nn_modules(self) -> Dict[str, Any]:
    """The ``nn_modules`` mapping stored in the tracing context's module context."""
    module_ctx = self.tracing_context.module_context
    return module_ctx.nn_modules
def save_global_state(self, out=None):
    """
    Record the current global torch state as ``name -> (setter, value)`` pairs.

    Writes into ``out`` when it is provided; otherwise writes into the
    tracing context's ``global_state`` mapping.
    """
    if out is not None:
        destination = out
    else:
        destination = self.tracing_context.global_context.global_state

    # (key, setter, current value), in the order the keys are stored.
    snapshot = (
        (
            "torch_function_enabled",
            self.set_torch_function_state,
            self.torch_function_enabled,
        ),
        ("grad_enabled", torch.set_grad_enabled, torch.is_grad_enabled()),
        (
            "autocast_enabled",
            torch.set_autocast_enabled,
            torch.is_autocast_enabled(),
        ),
        (
            "autocast_cpu_enabled",
            torch.set_autocast_cpu_enabled,
            torch.is_autocast_cpu_enabled(),
        ),
        (
            "autocast_gpu_dtype",
            torch.set_autocast_gpu_dtype,
            torch.get_autocast_gpu_dtype(),
        ),
        (
            "autocast_cpu_dtype",
            torch.set_autocast_cpu_dtype,
            torch.get_autocast_cpu_dtype(),
        ),
        (
            "autocast_cache_enabled",
            torch.set_autocast_cache_enabled,
            torch.is_autocast_cache_enabled(),
        ),
    )
    for key, setter, value in snapshot:
        destination[key] = (setter, value)
def push_tx(self, tx):
    """Push ``tx`` onto the stack of active translators."""
    tx_stack = self._current_tx
    tx_stack.append(tx)
def pop_tx(self):
    """Pop and return the most recently pushed translator."""
    tx_stack = self._current_tx
    return tx_stack.pop()
@property
def current_tx(self):
    """The innermost pushed translator, or ``root_tx`` when none has been pushed."""
    if self._current_tx:
        return self._current_tx[-1]
    return self.root_tx
def copy_graphstate(self) -> OutputGraphState:
    """
    Snapshot the current state into an OutputGraphState.

    Containers are shallow-copied, side effects are cloned, and the guards,
    module and global contexts supply their own checkpoints. The timestamp is
    incremented after the snapshot is taken.
    """
    assert self.param_name_to_source is not None
    ctx = self.tracing_context
    guards_graph_state = ctx.guards_context.copy_graphstate()
    module_state = ctx.module_context.copy_graphstate()
    global_state = ctx.global_context.copy_graphstate()
    snapshot = OutputGraphState(
        input_source_to_var=dict(self.input_source_to_var),
        tracked_fakes=list(self.tracked_fakes),
        guard_state=guards_graph_state,
        nn_modules=module_state,
        register_finalizer_fns=list(self.register_finalizer_fns),
        global_state=global_state,
        param_name_to_source=dict(self.param_name_to_source),
        side_effects=self.side_effects.clone(),
        timestamp=self.timestamp,
        non_compliant_ops=set(self.non_compliant_ops),
    )
    self.timestamp += 1
    return snapshot
def restore_graphstate(self, state: OutputGraphState):
    """
    Roll back to a checkpoint produced by self.copy_graphstate().

    Restores the checkpointed attributes and tracing-context state, then
    removes graph nodes whose creation timestamp is newer than the restored
    timestamp (placeholders are kept).
    """
    (
        self.input_source_to_var,
        self.tracked_fakes,
        guards_state,
        module_state,
        self.register_finalizer_fns,
        global_state,
        self.param_name_to_source,
        self.side_effects,
        self.timestamp,
        self.non_compliant_ops,
    ) = state
    ctx = self.tracing_context
    ctx.guards_context.restore_graphstate(guards_state)
    ctx.module_context.restore_graphstate(module_state)
    ctx.global_context.restore_graphstate(global_state)

    # FX deepcopy doesn't work for a partially created graph, so just remove new nodes
    removed_nodes = 0
    for candidate in reversed(list(self.graph.nodes)):
        if not (candidate.meta["creation_timestamp"] > self.timestamp):
            continue
        # placeholders here may have been lazily added by existing objects
        if candidate.op == "placeholder":
            continue
        # Erasing node alone does not remove the meta information
        # So, remove the help tensor explicitly
        candidate.meta.pop("example_value", None)
        self.remove_node(candidate)
        self.real_value_cache.pop(candidate, None)
        removed_nodes += 1
    log.debug("restore_graphstate: removed %s nodes", removed_nodes)
def add_symbol_bindings(self, arg: GraphArg):
    """
    Bind the size variables reachable from ``arg.fake_tensor`` as SymInt
    graph inputs on the root tracer.

    Does nothing in export mode. Each symbol is bound at most once, tracked
    via ``self.bound_symbols``.
    """
    # Insert implicit size vars as necessary.  With dynamic shapes, we
    # maintain the invariant that every sizevar gets a direct SymInt input
    # into the graph.  This means downstream graph transforms can assume
    # every size variable is explicitly bound and accessible, instead of
    # having to pull it out implicitly from tensors.

    if self.export:
        return

    assert arg.fake_tensor is not None

    def bind_symint(s, prop):
        # Only bare symbols are bound; constants and compound expressions are skipped.
        if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)):
            return
        s0 = s.node.expr
        # Already bound by an earlier argument or property.
        if s0 in self.bound_symbols:
            return
        self.bound_symbols.add(s0)
        log.debug("bind_symint %s %s", s, prop.name())
        # TODO: don't readd symint if we already have it in graph
        # (this is harmless because we do remove the unused ones later)
        proxy = self.root_tracer.create_graph_input(
            str(s0),
            torch.SymInt,
            before=True,
            source=prop,
        )
        proxy.node.meta["example_value"] = s
        proxy.node.meta["grapharg"] = GraphArg(
            prop,
            s,
            is_unspecialized=False,
            fake_tensor=None,
            is_tensor=False,
        )

    def handle_tensor(t, src):
        # Visit sizes, then strides, then the storage offset, in that order.
        for i, s in enumerate(t.size()):
            bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i))
        for i, s in enumerate(t.stride()):
            bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i))
        bind_symint(
            t.storage_offset(),
            TensorPropertySource(src, TensorProperty.STORAGE_OFFSET),
        )
        # Recurse into the inner tensors of traceable wrapper subclasses.
        if is_traceable_wrapper_subclass(t):
            attrs, ctx = t.__tensor_flatten__()
            for attr in attrs:
                inner_t = getattr(t, attr)
                handle_tensor(inner_t, AttrSource(src, attr))

    handle_tensor(arg.fake_tensor, arg.source)
def count_calls(self):
    """Return the result of the module-level ``count_calls`` helper for the current graph."""
    current_graph = self.graph
    return count_calls(current_graph)
def is_empty_graph(self):
    """Return True when the current graph contains no nodes."""
    node_list = list(self.graph.nodes)
    return not node_list
def get_submodule(self, keys):
    """
    Resolve a dotted path starting from ``self.nn_modules``.

    Each path component is looked up by key while the current object is a
    dict, and by attribute otherwise. ``keys`` must be non-empty.
    """
    assert keys
    current: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules
    for part in keys.split("."):
        current = current[part] if isinstance(current, dict) else getattr(current, part)
    return current
def new_var(self, name="tmp"):
    """
    Allocate a fresh local variable name of the form ``{name}_{i}``.

    The first index whose name is not already in ``co_varnames`` is used; the
    new name is appended to ``self.code_options["co_varnames"]`` and returned.
    """
    taken = set(self.code_options["co_varnames"])
    suffixes = itertools.count()
    while True:
        candidate = f"{name}_{next(suffixes)}"
        if candidate in taken:
            continue
        self.code_options["co_varnames"] += (candidate,)
        return candidate
def update_co_names(self, name):
    """Ensure self.code_options.co_names contains name"""
    if name in self.code_options["co_names"]:
        return
    self.code_options["co_names"] += (name,)
@staticmethod
def module_key_name(*names):
# create a new unique name
name = "_".join(map(str, names))
# Strip the guard lookup L/G access
name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name)
# e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
# e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
if not name or not name[0].isalpha():
name = "sub" + name
return name
def register_attr_or_module(
    self,
    target: Union[torch.nn.Module, torch.Tensor, Any],
    *names,
    **options,
):
    """
    Register ``target`` in ``self.nn_modules`` and return a variable for it.

    Dynamic nn modules are returned as UnspecializedNNModuleVariable without
    being registered. Otherwise ``options`` must contain a "source" that is
    not a ParamBufferSource. A ``wrap_name`` closure is chosen by target type
    (tensor, nn.Module, SymInt/SymFloat, or anything else) and called with
    the key under which the target is stored: the existing key if the same
    object is already registered, else a new unique key derived from
    ``names``.
    """
    if is_dynamic_nn_module(target):
        return variables.UnspecializedNNModuleVariable(target, **options)

    options = dict(options)
    assert "source" in options
    source = options["source"]
    assert not isinstance(source, ParamBufferSource)

    if isinstance(target, torch.Tensor):
        tracer = self.current_tracer
        if not self.is_root_tracer():
            # For higher order ops, we don't want to insert the get_attr in
            # innermost graph. Instead, we want to raise the params/buffers
            # as inputs to the higher-order graph, and register them as
            # get_attrs in the root tracer.

            # Note that Dynamo will still call lift_tracked_freevar_to_input
            # when these inputs are encountered for the inner graph. The
            # only difference is what happens at the root tracer for
            # nn.Parameters vs free inputs. The free inputs are registered
            # as placeholders in the root graph, whereas the nn.Parameters
            # are registered as get_attr nodes in the root graph.
            tracer = self.root_tracer

        if not is_constant_source(source):
            install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))

        if get_static_address_type(target) == "guarded":
            install_guard(source.make_guard(GuardBuilder.DATA_PTR_MATCH))

        def wrap_name(module_key):
            assert self.param_name_to_source is not None
            self.param_name_to_source[module_key] = source

            return wrap_fx_proxy(
                self.root_tx,
                tracer.create_proxy("get_attr", module_key, tuple(), {}),
                example_value=target,
                **options,
            )

    elif isinstance(target, torch.nn.Module):
        assert isinstance(target, torch.nn.Module)

        install_guard(source.make_guard(GuardBuilder.NN_MODULE))

        def wrap_name(module_key):
            return NNModuleVariable(type(target), module_key, **options)

    elif isinstance(target, (torch.SymInt, torch.SymFloat)):
        # HACKY CODE REGION BEGIN
        # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
        # This ultimately gets written to self.nn_modules, which is unfortunate
        # Attrs that are tensors and symints and such need to be migrated to have their
        # own storage
        # alas, this is like this for now

        def wrap_name(module_key):
            return SymNodeVariable.create(
                self,
                self.create_proxy("get_attr", module_key, tuple(), {}),
                sym_num=target,
                **options,
            )

        # HACKY CODE REGION END
    else:

        def wrap_name(module_key):
            # Anything else is installed into the global scope under the key.
            self.output.update_co_names(module_key)
            self.global_scope[module_key] = target
            return VariableBuilder(self, ConstantSource(source_name=module_key))(
                target
            )

    # Identity (not equality) lookup: reuse the key of an already-registered object.
    for k, v in self.nn_modules.items():
        if v is target:
            # it already exists
            return wrap_name(k)

    name = OutputGraph.module_key_name(*names)

    # Try `base`, then `base_0`, `base_1`, ... until an unused key is found.
    base = name
    for i in itertools.count():
        if name not in self.nn_modules:
            self.nn_modules[name] = target
            if isinstance(target, torch.nn.Module):

                def register_leaf_name(leaf_name):
                    assert self.param_name_to_source is not None
                    new_source = ParamBufferSource(source, leaf_name)
                    new_name = f"{name}.{leaf_name}"
                    self.param_name_to_source[new_name] = new_source

                # annoying, but there are cases when we do not have parameters
                # see test_nn_moduledict_contains
                if hasattr(target, "_parameters"):
                    for leaf_name, _ in target.named_parameters():
                        register_leaf_name(leaf_name)
                if hasattr(target, "_buffers"):
                    for leaf_name, _ in target.named_buffers():
                        register_leaf_name(leaf_name)

            return wrap_name(name)
        name = f"{base}_{i}"

    raise AssertionError("unreachable")
def compile_subgraph(
    self,
    tx,
    partial_convert=False,
    reason: Optional[GraphCompileReason] = None,
    compile_return_value=False,
):
    """
    Generate a subgraph to continue execution on user code.
    Automatically restore live variables.

    Args:
        tx: the instruction translator whose stack and locals are compiled.
        partial_convert: stored on ``self.partial_convert``.
        reason: why the graph is being compiled; must not be None.
        compile_return_value: when False, a HAS_GRAPH_BREAK guard is added.
    """
    assert reason is not None

    from .decorators import disable

    self.partial_convert = partial_convert
    self.compile_subgraph_reason = reason
    self.should_exit = True

    if not compile_return_value:
        # invalid graph to be cache hit for nopython
        self.guard_has_graph_break()

    log.debug("COMPILING GRAPH due to %s", reason)

    if not all(block.can_restore() for block in tx.block_stack):
        unimplemented("compile_subgraph with block_depth != 0")

    prefix_insts: List[Instruction] = []
    if sys.version_info >= (3, 11):
        # prefix instructions (Python 3.11+)
        for inst in tx.prefix_insts:
            if inst.opname == "MAKE_CELL":
                prefix_insts.append(
                    create_instruction("MAKE_CELL", argval=inst.argval)
                )
            elif inst.opname == "COPY_FREE_VARS":
                prefix_insts.append(
                    create_instruction(
                        "COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
                    )
                )
            else:
                prefix_insts.append(copy.copy(inst))

    def append_prefix_insts():
        # Emit the prefix once; clearing the list makes later calls emit nothing.
        self.add_output_instructions(prefix_insts)
        prefix_insts.clear()

    for block in reversed(tx.block_stack):
        block.exit(tx)

    self.cleanup_graph()
    tx.prune_dead_locals()
    stack_values = list(tx.stack)
    root = FakeRootModule(self.nn_modules)
    # Add all the local vars to the "stack" so restore at the end
    restore_vars = []
    val_to_names: Dict[VariableTracker, List[str]] = {}
    if stack_values:
        val_to_names[stack_values[-1]] = list()
    # NB: Typically (i.e., for graph compile from RETURN_VALUE),
    # symbolic_locals will be empty at this point, as prune_dead_locals
    # will clear out all of symbolic_locals because RETURN_VALUE is the
    # last instruction and no more locals are used.  The fanciness here
    # is only needed for partial graphs.
    for k, v in tx.symbolic_locals.items():
        # Note! this explicitly uses .local_name for matching
        # Failure to do so will cause spurious registrations in val_to_names.
        # This will in turn result in spurious variables showing up in the graph.
        # This was very tricky to debug. For an example, dump the graph at call_user_compiler
        # while running test_subgraphs.py
        if isinstance(v.source, LocalSource) and v.source.local_name == k:
            continue  # no need to restore initial state
        if v not in val_to_names:
            val_to_names[v] = list()
        val_to_names[v].append(k)
    # Push one copy of each value per local name that must be restored.
    for v in val_to_names.keys():
        restore_vars.extend(val_to_names[v])
        stack_values.extend([v] * len(val_to_names[v]))

    # to handle random calls
    if len(tx.random_calls) > 0:
        append_prefix_insts()
        random_calls_instructions = []
        self.random_values_var = self.new_var("random_values")
        rand_fn_name = unique_id("__gen_rand_values")
        rand_fn = disable(_get_gen_rand_values_fn(tx.random_calls))
        self.install_global(rand_fn_name, rand_fn)
        codegen = PyCodegen(tx, root)
        random_calls_instructions.extend(
            codegen.load_function_name(rand_fn_name, True)
        )
        random_calls_instructions.extend(create_call_function(0, False))
        random_calls_instructions.append(
            codegen.create_store(tx.output.random_values_var),
        )
        self.add_output_instructions(random_calls_instructions)

    # Fast path: every stack value is a distinct plain TensorVariable and
    # there are no side effects to replay.
    if (
        stack_values
        and all(
            not isinstance(
                v,
                (
                    UnspecializedPythonVariable,
                    NumpyNdarrayVariable,
                    TensorWithTFOverrideVariable,
                ),
            )
            for v in stack_values
        )
        and all(isinstance(x, TensorVariable) for x in stack_values)
        and len(set(stack_values)) == len(stack_values)
        and self.side_effects.is_empty()
    ):
        append_prefix_insts()
        # optimization to generate better code in a common case
        self.add_output_instructions(
            self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
            + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
        )
    else:
        graph_output_var = self.new_var("graph_out")
        # First pass records how often each value is used (pass1.uses).
        pass1 = PyCodegen(tx, root, graph_output_var)
        self.side_effects.codegen_hooks(pass1)
        self.side_effects.codegen_save_tempvars(pass1)
        pass1.foreach(stack_values)
        self.side_effects.codegen_update_mutated(pass1)

        # one more time now that we have established tempvars
        pass2 = PyCodegen(
            tx,
            root,
            graph_output_var,
            tempvars={val: None for val, count in pass1.uses.items() if count > 1},
        )
        self.side_effects.codegen_hooks(pass2)
        self.side_effects.codegen_save_tempvars(pass2)
        pass2.foreach(stack_values)
        self.side_effects.codegen_update_mutated(pass2)

        output = []
        if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
            output.extend(
                self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
            )

            if len(pass2.graph_outputs) != 0:
                output.append(pass2.create_store(graph_output_var))
            else:
                # No graph outputs are used, so discard the call's result.
                output.append(create_instruction("POP_TOP"))
        append_prefix_insts()
        self.add_output_instructions(output + pass2.get_instructions())

    # restore all the live local vars
    self.add_output_instructions(
        [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
    )
def cleanup_graph(self):
    """
    Remove "creation_timestamp" from node meta

    Remove this pattern from the graph:
        torch._C._set_grad_enabled(False)
        torch._C._set_grad_enabled(True)
    """
    assert self.should_exit
    all_nodes = list(self.graph.nodes)
    for entry in all_nodes:
        entry.meta.pop("creation_timestamp", None)

    def flips_grad(candidate, current_mode):
        # A live _set_grad_enabled call that toggles away from current_mode.
        return (
            candidate.target is torch._C._set_grad_enabled
            and tuple(candidate.args) == (not current_mode,)
            and not candidate._erased
        )

    grad_enabled = torch.is_grad_enabled()
    for first, second in zip(all_nodes, all_nodes[1:]):
        if not flips_grad(first, grad_enabled):
            continue
        grad_enabled = first.args[0]
        if not flips_grad(second, grad_enabled):
            continue
        # Adjacent toggles cancel each other out; drop both.
        grad_enabled = second.args[0]
        self.graph.erase_node(first)
        self.graph.erase_node(second)
def get_graph_sizes_log_str(self, name):
graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n"
graph_sizes_str += f"===== {name} =====\n"
for node in self.graph.nodes:
example_value = node.meta.get("example_value", None)
if isinstance(example_value, torch._subclasses.FakeTensor):
size = example_value.size()
graph_sizes_str += f"{node.name}: {tuple(size)}\n"
concrete_size = []
has_symint = False
for sz in size:
if isinstance(sz, int):
concrete_size.append(sz)
elif isinstance(sz, torch.SymInt):
has_symint = True
concrete_size.append(sz.node.hint)
else:
break