forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_trace.py
More file actions
2598 lines (2276 loc) · 101 KB
/
_trace.py
File metadata and controls
2598 lines (2276 loc) · 101 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import dataclasses
import functools
import inspect
import logging
import re
import sys
import time
import warnings
from collections.abc import Callable
from contextlib import contextmanager, ExitStack, nullcontext
from itertools import chain
from typing import Any, TYPE_CHECKING, TypeAlias
from unittest import mock
if TYPE_CHECKING:
import weakref
import torch
import torch._dynamo
import torch.fx
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.exc import UserError, UserErrorType
from torch._export.db.logging import (
exportdb_error_message,
get_class_if_classified_error,
)
from torch._export.non_strict_utils import (
_fakify_module_inputs,
_fakify_script_objects,
_gather_constant_attrs,
_NonStrictTorchFunctionHandler,
_override_builtin_ops,
make_constraints,
make_fake_inputs,
produce_guards_and_solve_constraints,
)
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
from torch._export.passes.lift_constants_pass import (
_materialize_and_lift_constants,
ConstantAttrMap,
)
from torch._export.utils import (
_collect_param_buffer_metadata,
_compiling_state_context,
_fakify_params_buffers,
_populate_param_buffer_metadata_to_new_gm,
_update_gm_meta_if_possible,
apply_runtime_assertion_pass,
placeholder_naming_pass,
placeholder_prefixes,
)
from torch._export.verifier import SpecViolationError
from torch._export.wrappers import _wrap_submodules
from torch._functorch._aot_autograd.graph_capture_wrappers import create_functional_call
from torch._functorch._aot_autograd.input_output_analysis import (
_graph_input_names,
_graph_output_names,
)
from torch._functorch._aot_autograd.schemas import GraphSignature
from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container
from torch._functorch._aot_autograd.utils import (
create_tree_flattened_fn,
register_buffer_assignment_hook,
)
from torch._functorch.aot_autograd import (
_detect_attribute_assignment,
aot_export_joint_with_descriptors,
)
from torch._guards import detect_fake_mode, tracing, TracingContext
from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj
from torch._library.opaque_object import is_opaque_type
from torch._logging import dtrace_structured
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._utils_internal import compile_time_strobelight_meta, log_export_usage
from torch.export._leakage_detection_utils import find_legit_leaks_from_referrers
from torch.export._unlift import _check_input_constraints_pre_hook
from torch.export.dynamic_shapes import (
_check_dynamic_shapes,
_combine_args,
_DimHintType,
_IntWrapper,
_process_dynamic_shapes,
)
from torch.export.exported_program import OutputKind
from torch.fx._symbolic_trace import _ConstantAttributeType
from torch.fx.experimental.proxy_tensor import (
get_proxy_slot,
make_fx,
PreDispatchTorchFunctionMode,
track_tensor_tree,
)
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
free_unbacked_symbols,
GuardOnDataDependentSymNode,
ShapeEnv,
)
from torch.fx.graph import _PyTreeInfo
from torch.utils._pytree import TreeSpec
from torch.utils._sympy.value_ranges import ValueRangeError
from .exported_program import (
_disable_prexisiting_fake_mode,
ExportedProgram,
InputKind,
ModuleCallEntry,
ModuleCallSignature,
)
from .graph_signature import _convert_to_export_graph_signature, ExportGraphSignature
# Module-level logger for this file.
log = logging.getLogger(__name__)
# Type alias for a ``dynamic_shapes`` specification: either a dict with string
# keys (presumably argument names -- verify against _process_dynamic_shapes) or
# a tuple/list (presumably one entry per positional argument).
_DynamicShapesSpec: TypeAlias = dict[str, Any] | tuple[Any, ...] | list[Any]
@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.

    Field names mirror ``torch._dynamo.config`` options. Export converts an
    instance with ``dataclasses.asdict`` and hands the result to
    ``torch._dynamo.config.patch`` while tracing. Each field therefore overrides
    the Dynamo option of the same name for the duration of the trace.
    """

    allow_rnn: bool = True
    # Callables that Dynamo may treat as reorderable logging calls. Defaults to
    # an empty set; the module-level default config instance fills it in.
    reorderable_logging_functions: set[Callable] = dataclasses.field(
        default_factory=set
    )
    # Emit runtime asserts after AOTAutograd instead.
    # This isn't really necessary, and isn't much more efficient since the runtime asserts pass does CSE,
    # but if we want to reason more about what guards/runtime asserts to emit,
    # this makes it a bit cleaner to do from the export side. Also no real point in running this twice.
    do_not_emit_runtime_asserts: bool = True
    specialize_int: bool = True
    specialize_float: bool = True
    assume_static_by_default: bool = False
    automatic_dynamic_shapes: bool = False
    capture_dynamic_output_shape_ops: bool = True
    capture_scalar_outputs: bool = True
    prefer_deferred_runtime_asserts_over_guards: bool = False
    replay_side_effects: bool = False
    side_effect_replay_policy: str = "warn"
@dataclasses.dataclass
class ATenExportArtifact:
    """
    Result of lowering a program to an ATen-level graph.

    Built by ``_produce_aten_artifact`` from:
    - the rewritten graph module;
    - its export graph signature;
    - the lifted constants;
    - the output TreeSpec of the incoming graph signature.
    """

    # The ATen-level graph module.
    gm: torch.fx.GraphModule
    # Input/output signature of ``gm``.
    sig: ExportGraphSignature
    # Lifted constants (constant tensors and script objects), keyed by name.
    constants: dict[str, _ConstantAttributeType]
    # Output TreeSpec carried over from the graph signature used to build ``gm``.
    inferred_out_spec: TreeSpec
@dataclasses.dataclass(frozen=True)
class ExportArtifact:
    """
    Immutable bundle of an ATen export artifact together with the pytree specs,
    fake mode and module call specs that accompany it.
    """

    aten: ATenExportArtifact
    # Pytree specs presumably describing the user-facing inputs and outputs
    # (the construction site is not visible here -- verify).
    in_spec: TreeSpec
    out_spec: TreeSpec
    # Fake tensor mode associated with the traced program.
    fake_mode: FakeTensorMode
    # Per-module call specs; outer keys presumably module FQNs -- verify.
    module_call_specs: dict[str, dict[str, pytree.TreeSpec]]
# Default Dynamo configuration used by export. The reorderable logging
# functions are the module-level helpers of the stdlib ``logging`` package,
# plus ``print`` and ``warnings.warn``.
DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig(
    reorderable_logging_functions={
        print,
        warnings.warn,
        logging.critical,
        logging.debug,
        logging.error,
        logging.exception,
        logging.info,
        logging.log,
        logging.warning,
    }
)
@contextmanager
def _ignore_backend_decomps():
orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
orig_cudnn_flag = torch.backends.cudnn.set_flags(False)
try:
yield
finally:
torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
torch.backends.nnpack.set_flags(*orig_nnpack_flag)
torch.backends.cudnn.set_flags(*orig_cudnn_flag)
@contextmanager
def _disable_custom_triton_op_functional_decomposition():
old = torch._functorch.config.decompose_custom_triton_ops
try:
# pyrefly: ignore [bad-assignment]
torch._functorch.config.decompose_custom_triton_ops = False
yield torch._functorch.config.decompose_custom_triton_ops
finally:
torch._functorch.config.decompose_custom_triton_ops = old
def custom_triton_ops_decomposition_disabled():
return not torch._functorch.config.decompose_custom_triton_ops
def _fixup_key(x):
    """Strip the ``_export_root`` prefix from ``x`` (via ``_strip_root``) and prepend ``"L__self__"``."""
    relative_key = _strip_root(x)
    return "L__self__" + relative_key
def _strip_root(x):
if isinstance(x, str) and x.startswith("_export_root"):
stripped = x[len("_export_root") :]
return stripped.removeprefix(".")
return x
def _is_bogus_const_name(name: str):
splitted_names = name.split(".")
if len(splitted_names) < 1:
return True
return splitted_names[-1].startswith("lifted_tensor")
def _rewrite_tracepoint_node(gm: torch.fx.GraphModule):
"""
In-place modify input graph module by replacing the export tracepoint with a new node
that has the same target and args, but with the _export_root stripped from path.
"""
for node in gm.graph.nodes:
if node.target is torch.ops.higher_order._export_tracepoint:
if "path" in node.kwargs:
path = _strip_root(node.kwargs["path"])
with gm.graph.inserting_before(node):
new_node = gm.graph.create_node(
"call_function",
torch.ops.higher_order._export_tracepoint,
args=node.args,
kwargs={
"path": path,
"kind": node.kwargs["kind"],
},
)
new_node.meta = node.meta
node.replace_all_uses_with(new_node)
gm.graph.erase_node(node)
def detect_shape_env(inputs: Any = None):
    """
    Return the single ShapeEnv shared by every ``torch.SymInt`` in ``inputs``.

    Args:
        inputs: Iterable of arbitrary values (tensors, SymInts, constants, ...).
            ``None`` (the default) is treated as an empty iterable instead of
            failing inside ``enumerate``.

    Returns:
        The ShapeEnv attached to the SymInt entries, or ``None`` if there are
        no SymInt entries.

    Raises:
        AssertionError: if two SymInt entries carry different ShapeEnv objects.
    """
    if inputs is None:
        return None
    shape_envs = [
        (flat_input.node.shape_env, "symint input", i)
        for i, flat_input in enumerate(inputs)
        if isinstance(flat_input, torch.SymInt)
    ]
    if not shape_envs:
        return None
    shape_env, desc1, i1 = shape_envs[0]
    for m, desc2, i2 in shape_envs[1:]:
        if shape_env is not m:
            raise AssertionError(
                f"shape env ({shape_env}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
                f"shape env from {desc1} {i1} allocated at:\n{shape_env.stack}\n"
                f"shape env from {desc2} {i2} allocated at:\n{m.stack}"
            )
    return shape_env
def _extract_fake_inputs(gm, args, kwargs):
    """
    Given a graph module, extract fakified input tensors from the metadata of
    its placeholders, and map them to the structure of given args and kwargs.
    Also return the fake mode used to fakify those inputs.

    Args:
        gm: Graph module whose placeholder nodes carry fake values in
            ``meta["val"]``; other nodes may carry ``meta["example_value"]``.
        args: Example positional inputs. Concatenated with
            ``tuple(kwargs.values())`` below, so presumably always a tuple.
        kwargs: Example keyword inputs.

    Returns:
        ``(fake_args, fake_kwargs, fake_mode)``:
        - ``fake_args`` / ``fake_kwargs``: ``args`` / ``kwargs`` with every
          int/Tensor leaf replaced by the corresponding entry of the fake
          inputs;
        - ``fake_mode``: the detected ``FakeTensorMode``, or a new one built
          on the detected (or a fresh) ShapeEnv.

    Raises:
        AssertionError: if the detected ShapeEnv is not the detected fake mode's
            ShapeEnv, or if ``detect_shape_env`` finds inconsistent ShapeEnvs.
    """
    fake_inps: list[Any] = []
    fake_vals: list[Any] = []
    # Placeholders provide the fake inputs; every other node contributes its
    # example value so that fake-mode / shape-env detection below sees it too.
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            fake_inps.append(node.meta.get("val"))
        else:
            fake_vals.append(node.meta.get("example_value"))
    if dynamo_bytecode_flatten := getattr(gm, "_dynamo_bytecode_flatten", None):
        # In _extract_fake_inputs, the goal is to make real inputs into
        # fake (and symbolic) inputs. The way currently it's implemented
        # is by looking at the node.meta["val"] of the placeholder nodes.
        # This doesn't work when the graph is Dynamo flattened, because now
        # placeholder nodes don't follow the same ordering as the pytree inputs.
        # Instead, we need to look at how the inputs are shuffled, and map
        # the inputs to their actual fake inputs and symbolic inputs.
        # Since inputs can also contain symints, we cannot simply use the
        # FakeTensorMode memo to look up tensors only there.
        fake_inps = []
        positions = {}
        idx = 0

        def mark_inputs(x):
            # x can be a tensor or symbolic integer or a normal constant.
            # Tensors pass through as-is; anything else is replaced by a fresh
            # sentinel object so it can be recognized by id() after the flatten
            # function has shuffled the inputs.
            nonlocal idx
            fake_inps.append(x)
            if isinstance(x, torch.Tensor):
                ret = x
            else:
                ret = object()
            # Remember only the first position at which each object appears.
            if id(ret) not in positions:
                positions[id(ret)] = idx
            idx += 1
            return ret

        dummy_args = pytree.tree_map(mark_inputs, args + tuple(kwargs.values()))
        shuffled_args = dynamo_bytecode_flatten(*dummy_args)
        # Pair each placeholder with the (shuffled) input it receives and store
        # the placeholder's fake value at that input's original position.
        for node, shuffled_arg in zip(
            gm.graph.find_nodes(op="placeholder"), shuffled_args
        ):
            if id(shuffled_arg) in positions:
                fake_inps[positions[id(shuffled_arg)]] = node.meta.get("val")
    # We get both because now we might have a combination of symint and tensor
    # inputs, and we want to check that the shape env is consistent between
    # both. A ShapeEnv does not tell us which fake mode owns it, so instead we
    # check that the detected shape env is the detected fake mode's shape env.
    detected_fake_mode = detect_fake_mode(fake_inps + fake_vals)
    detected_shape_env = detect_shape_env(fake_inps + fake_vals)
    if detected_fake_mode:
        if detected_shape_env:
            if detected_shape_env is not detected_fake_mode.shape_env:
                raise AssertionError(
                    "Detected shape env does not match fake mode's shape env"
                )
        fake_mode = detected_fake_mode
    elif detected_shape_env:
        fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True)
    else:
        fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True)
    count = 0

    def lookup_fake(x):
        # Every leaf advances ``count``, but only int/Tensor leaves (bools too,
        # since bool subclasses int) are swapped for ``fake_inps[count]``.
        # NOTE(review): this assumes fake_inps has one entry per leaf of args
        # followed by kwargs -- confirm for the non-flattened path, where
        # fake_inps comes from placeholder order.
        nonlocal count
        val = fake_inps[count] if isinstance(x, (int, torch.Tensor)) else x
        count += 1
        return val

    fake_args = pytree.tree_map(lookup_fake, args)
    fake_kwargs = pytree.tree_map(lookup_fake, kwargs)
    return fake_args, fake_kwargs, fake_mode
def _replace_param_buffer_names(param_buffer_table, sig):
for spec in sig.input_specs:
if spec.kind in (
InputKind.PARAMETER,
InputKind.BUFFER,
):
spec.target = param_buffer_table[spec.target]
for spec in sig.output_specs:
if spec.kind in (
OutputKind.BUFFER_MUTATION,
OutputKind.GRADIENT_TO_PARAMETER,
):
spec.target = param_buffer_table[spec.target]
def _convert_to_positional_args(orig_arg_names, args, kwargs):
if len(orig_arg_names) != len(args) + len(kwargs):
raise AssertionError(
f"Total number of arg names is expected to be {len(orig_arg_names)} "
f"but got {len(args)} positional args, {len(kwargs)} kwargs."
)
reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]]
return (
*args,
*reordered_kwargs,
)
def _normalize_nn_module_stack(gm_torch_level, root_cls):
    """
    Ensure every node's ``nn_module_stack`` starts with the root module.

    The pass visits every non-placeholder/non-output node of every GraphModule
    inside ``gm_torch_level``. If the first stack entry is not already
    ``("L['self']", root_cls)``, it does two things:
    - prepends a root entry keyed by ``L__self__`` whose type is the
      ``"module.qualname"`` string of ``root_cls``;
    - rewrites every path in the stack, dropping the ``L['self']`` /
      ``L['self'].`` prefix.

    Nodes whose stack already begins with the root are left unchanged. Node
    meta is modified in place.

    Args:
        gm_torch_level: Graph module (possibly containing sub-GraphModules).
        root_cls: Class of the exported root module.
    """
    # Append a root module to every nn_module_stack.
    root = "L['self']"
    # Root key with every non-alphanumeric character replaced by "_".
    root_key = re.sub(r"[^a-zA-Z0-9]", "_", root)
    for gm in gm_torch_level.modules():
        if not isinstance(gm, torch.fx.GraphModule):
            continue
        for node in gm.graph.nodes:
            if node.op in ["placeholder", "output"]:
                continue
            add_root = True
            if nn_module_stack := node.meta.get("nn_module_stack", {}):
                # Only the first (outermost) entry is inspected.
                path, ty = next(iter(nn_module_stack.values()))
                # After deserializing the class `ty` might not exist anymore so
                # it could be a string
                if inspect.isclass(ty) and issubclass(ty, torch.nn.Module):
                    # TODO Figure out why sometimes we have root sometimes we don't.
                    if path == root and ty is root_cls:
                        add_root = False
                else:
                    if not isinstance(ty, str):
                        raise AssertionError(f"expected ty to be str, got {type(ty)}")
            if add_root:

                def normalize_path(path):
                    # Make paths relative to the root: "L['self']" -> "",
                    # "L['self'].x.y" -> "x.y"; anything else is kept.
                    if path == "L['self']":
                        return ""
                    if path.startswith("L['self']."):
                        return path[len("L['self'].") :]
                    return path

                # ``nn_module_stack`` may be the empty dict from the walrus
                # above, in which case only the root entry remains.
                nn_module_stack = {
                    root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
                    # pyrefly: ignore [unbound-name]
                    **nn_module_stack,
                }
                node.meta["nn_module_stack"] = {
                    key: (normalize_path(path), ty)
                    for key, (path, ty) in nn_module_stack.items()
                }
def _get_param_buffer_mapping(
original_module: torch.nn.Module,
traced_module: torch.nn.Module,
) -> dict[str, str]:
"""
Returns a mapping of parameter/buffer names from the new module to the
original model. This is to help with restoring the FQN for parameter/buffers
of a traced module to what the original module contains.
"""
param_lookup: dict[int, str] = {}
buffer_lookup: dict[int, str] = {}
for name, param in original_module.named_parameters(remove_duplicate=False):
if param_lookup.get(id(param)) is None:
# we only want to keep the first occurrence of a parameter to guarantee parity of original and traced module.
param_lookup[id(param)] = name
for name, buffer in original_module.named_buffers(remove_duplicate=False):
buffer_lookup[id(buffer)] = name
param_buffer_table: dict[str, str] = {}
for dynamo_name, dynamo_param in traced_module.named_parameters(
remove_duplicate=False
):
if dynamo_name in param_buffer_table:
raise AssertionError(
f"dynamo_name {dynamo_name!r} already exists in param_buffer_table"
)
if id(dynamo_param) in param_lookup:
param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)]
for dynamo_name, dynamo_buffer in traced_module.named_buffers(
remove_duplicate=False
):
if dynamo_name in param_buffer_table:
raise AssertionError(
f"dynamo_name {dynamo_name!r} already exists in param_buffer_table for buffer"
)
if id(dynamo_buffer) in buffer_lookup:
param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)]
return param_buffer_table
def _preserve_requires_grad_pass(
gm: torch.fx.GraphModule,
sig: ExportGraphSignature,
fake_params_buffers: dict[str, torch.Tensor],
constants: dict[str, _ConstantAttributeType],
flat_fake_args: list[Any],
):
placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
if len(sig.input_specs) != len(placeholders):
raise AssertionError(
f"input_specs length {len(sig.input_specs)} does not match placeholders length {len(placeholders)}"
)
i = 0
for node, spec in zip(placeholders, sig.input_specs):
if spec.kind in (
InputKind.PARAMETER,
InputKind.BUFFER,
):
if spec.target is None:
raise AssertionError(
f"spec.target must not be None for kind {spec.kind}"
)
node.meta["val"].requires_grad = fake_params_buffers[
spec.target
].requires_grad
elif spec.kind == InputKind.USER_INPUT:
fake_arg = flat_fake_args[i]
if isinstance(fake_arg, torch.Tensor):
node.meta["val"].requires_grad = fake_arg.requires_grad
i += 1
elif spec.kind == InputKind.CONSTANT_TENSOR:
if spec.target is None:
raise AssertionError(
"spec.target must not be None for CONSTANT_TENSOR kind"
)
constant = constants[spec.target]
if isinstance(constant, torch.Tensor):
# If the tensor is not leaf, it should already have a correct requires grad field
if node.meta["val"].is_leaf:
node.meta["val"].requires_grad = constant.requires_grad
else:
if node.meta["val"].requires_grad != constant.requires_grad:
raise AssertionError(
f"node requires_grad {node.meta['val'].requires_grad} does not match "
f"constant requires_grad {constant.requires_grad}"
)
elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN):
continue
else:
raise AssertionError(spec.kind)
def _remap_constants(
orig_constant_attrs: ConstantAttrMap,
graph_signature: ExportGraphSignature,
constants: dict[str, _ConstantAttributeType],
) -> None:
"""Rewrite the graph signature and constants table to use the FQN from the original module."""
remap_table: dict[str, list[str]] = {}
for name, value in constants.items():
if value in orig_constant_attrs:
remap_table[name] = orig_constant_attrs[value]
for spec in graph_signature.input_specs:
if spec.kind in (
InputKind.CONSTANT_TENSOR,
InputKind.CUSTOM_OBJ,
):
orig_target = spec.target
if orig_target is None:
raise AssertionError(
f"spec.target must not be None for kind {spec.kind}"
)
targets = remap_table.get(orig_target, [orig_target])
spec.target = targets[0]
constant = constants[orig_target]
del constants[orig_target]
for target in targets:
constants[target] = constant
def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None:
"""
When we run an interpreter-based pass over a GraphModule, execution of data-dependent operators
will produce example values with new unbacked symbols. To track that the new/old symbols are equivalent,
we used to rely on the unbacked_renamings mapping. This led to problematic metadata where the unbacked_bindings
keys mapped new symbols (u2) to paths containing old symbols (u0) in the example values, or worse, backed symbols
or constants (e.g. if the original unbacked was replaced/specialized). Additionally this created problems with
de/serialized programs, since we didn't comprehensively serialize ShapeEnv/unbacked renamings/node bindings.
This pass attempts a simpler way of handling these for export, by throwing away the previously computed bindings, and re-running
the pattern match used in compute_unbacked_bindings. This ensures we keep the original symbols contained in the example values,
or delete bindings if they've been replaced/specialized.
"""
from torch._export.utils import _get_shape_env_from_gm
from torch.fx.experimental.symbolic_shapes import _free_unbacked_symbols_with_path
from torch.utils._sympy.symbol import symbol_is_type, SymT
if (shape_env := _get_shape_env_from_gm(gm)) is None:
return
base_unbacked_symbols = {
symbol
for symbol in shape_env.var_to_range
if symbol_is_type(symbol, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT))
and symbol not in shape_env.unbacked_renamings
}
for node in gm.graph.nodes:
node.meta.pop("unbacked_bindings", None)
if (val := node.meta.get("val")) is not None and (
unbacked_bindings := _free_unbacked_symbols_with_path(
val,
(),
shape_env=shape_env,
pending=base_unbacked_symbols,
simplify=True,
)
):
node.meta["unbacked_bindings"] = unbacked_bindings
def _produce_aten_artifact(
    *,
    gm: torch.fx.GraphModule,
    mod,
    constant_attrs,
    graph_signature,
    pre_dispatch,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    _prettify_placeholder_names=True,
) -> ATenExportArtifact:
    """
    This is a helper function that is shared between export_to_aten_ir and export_to_aten_ir_make_fx
    to produce the aten artifact. (export compatible graph module + signature)

    It does:
    1. Applies runtime assertion pass
    2. Recompute unbacked_bindings pass
    3. Populate meta val when missing
    4. Lift constants as placeholders
    5. Replace raw autograd and autocast ops with HOPs (only when ``pre_dispatch``)
    6. Prettify names for placeholders (only when ``_prettify_placeholder_names``)
    7. Preserve requires_grad value on node meta val

    Args:
        gm: Traced graph module to post-process.
        mod: The module being exported. Used to look up non-persistent buffers
            and by the placeholder naming pass.
        constant_attrs: Passed to ``_materialize_and_lift_constants``
            (presumably the original module's constant attribute map -- verify).
        graph_signature: Graph signature of ``gm`` (presumably the AOTAutograd
            ``GraphSignature``). Its ``parameters``, ``buffers``,
            ``input_tokens`` and ``out_spec`` are read here.
        pre_dispatch: If true, run the set-grad and autocast HOP replacement passes.
        fake_args, fake_kwargs: Fake user inputs. Their pytree leaves are used to
            fill missing meta vals and to copy ``requires_grad`` onto user-input
            placeholders.
        fake_params_buffers: Fake parameters/buffers, keyed by the target names
            used in the export graph signature.
        _prettify_placeholder_names: Whether to run ``placeholder_naming_pass``.

    Returns:
        An ``ATenExportArtifact`` with the rewritten graph module, the export
        graph signature, the lifted constants and ``graph_signature.out_spec``.
    """
    # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
    # Overwrite output specs afterwards.
    flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs))
    gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature)
    # Simplify unbacked_bindings by recomputing them.
    # Useful for any pass that's interpreter-based and might call rebind_unbacked(),
    # e.g. AOTAutograd in this case.
    _replace_unbacked_bindings(gm)
    # Number of leading placeholders that are not user inputs
    # (parameters, buffers and input tokens).
    total_non_user_inputs = (
        len(graph_signature.parameters)
        + len(graph_signature.buffers)
        + len(graph_signature.input_tokens)
    )
    set_missing_meta_vals(gm, flat_fake_args, total_non_user_inputs)
    export_graph_signature: ExportGraphSignature | None
    export_graph_signature = _convert_to_export_graph_signature(
        graph_signature, gm, _get_non_persistent_buffers(mod)
    )
    # Script objects are always stored in constants, regardless of whether they
    # are initial inputs or were lifted during AOT tracing (this happens before
    # rewrite_script_object_meta).
    constants = _materialize_and_lift_constants(
        gm, export_graph_signature, constant_attrs
    )
    if pre_dispatch:
        from torch._export.passes.replace_autocast_with_hop_pass import (
            replace_autocast_with_hop_pass,
        )
        from torch._export.passes.replace_set_grad_with_hop_pass import (
            replace_set_grad_with_hop_pass,
        )

        # Note: replace_set_grad_with_hop_pass need to be after lift_constant_pass because
        # a getattr of a constant tensor doesn't have meta["val"] until after lift_constant_pass.
        # If replace_set_grad_with_hop_pass is before lift_constant_pass,
        # and the constant_tensor is passed as input of the set grad hop, the placeholder's
        # meta["val"] will be None and fails our verifier for placeholder.
        gm, export_graph_signature = replace_set_grad_with_hop_pass(
            gm, export_graph_signature
        )
        gm, export_graph_signature = replace_autocast_with_hop_pass(
            gm, export_graph_signature
        )
    # Remove nn_module_stack and stack_trace metadata from all placeholder and
    # output nodes, in every GraphModule (including submodules).
    for _mod in gm.modules():
        if not isinstance(_mod, torch.fx.GraphModule):
            continue
        for node in _mod.graph.nodes:
            if node.op in ["placeholder", "output"]:
                node.meta.pop("nn_module_stack", None)
                node.meta.pop("stack_trace", None)
    # Prettify names for placeholder nodes.
    if export_graph_signature is None:
        raise AssertionError("export_graph_signature must not be None")
    if _prettify_placeholder_names:
        placeholder_naming_pass(
            gm,
            export_graph_signature,
            mod,
            fake_args,
            fake_kwargs,
            fake_params_buffers,
            constants,
        )
    _preserve_requires_grad_pass(
        gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args
    )
    return ATenExportArtifact(
        gm,
        export_graph_signature,
        constants,
        inferred_out_spec=graph_signature.out_spec,
    )
def _rename_constants_nodes(
    gm: torch.fx.GraphModule,
    graph_signature: ExportGraphSignature,
) -> None:
    """
    For strict mode, rename constants nodes that were previously annotated as buffers.

    Every CONSTANT_TENSOR input spec whose arg name lacks the constant
    placeholder prefix is renamed:
    - a name with the buffer prefix has that prefix swapped for the constant
      prefix;
    - any other name simply gets the constant prefix prepended;
    - a ``_<n>`` suffix is added to avoid clashing with existing node names.

    Matching output specs and nodes (name and target) in every GraphModule
    inside ``gm`` are renamed accordingly, and each GraphModule is recompiled.
    ``graph_signature`` and ``gm`` are modified in place.
    """
    # handle name collisions with existing constants
    # NOTE(review): only the top-level graph's node names are collected here,
    # although renames are applied to every submodule below -- confirm intended.
    node_names = {node.name for node in gm.graph.nodes}

    def rename_constant(name):
        # Return ``name``, or the first ``name_<n>`` (n = 1, 2, ...) not yet
        # taken, and reserve the result in ``node_names``.
        if name in node_names:
            n = 1
            while (dup_name := f"{name}_{n}") in node_names:
                n += 1
            # pyrefly: ignore [unbound-name]
            name = dup_name
        node_names.add(name)
        return name

    # use input specs to map names from buffers to constants
    buffer_prefix = placeholder_prefixes[InputKind.BUFFER]
    const_prefix = placeholder_prefixes[InputKind.CONSTANT_TENSOR]
    # old placeholder name -> new constant placeholder name
    buffer_to_constant = {}
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
            const_prefix
        ):
            if spec.arg.name.startswith(buffer_prefix):  # map from buffer to constants
                c_name = rename_constant(
                    const_prefix + spec.arg.name[len(buffer_prefix) :]
                )
            else:  # lifted constant
                c_name = rename_constant(const_prefix + spec.arg.name)
            buffer_to_constant[spec.arg.name] = c_name
            spec.arg.name = c_name
    for spec in graph_signature.output_specs:
        if spec.arg.name in buffer_to_constant:
            spec.arg.name = buffer_to_constant[spec.arg.name]
    # Rename constants nodes for all modules
    for mod in gm.modules():
        if not isinstance(mod, torch.fx.GraphModule):
            continue
        for node in mod.graph.nodes:
            if node.name in buffer_to_constant:
                node.name = node.target = buffer_to_constant[node.name]
        mod.recompile()
def _restore_state_dict(
    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
) -> None:
    """
    Restores the state dict of the traced module to that of the original module.

    Each parameter/buffer of ``traced_module`` that is the same tensor object as
    one in ``original_module`` is moved from its traced name to the original
    FQN. ``get_attr`` nodes are then retargeted to the new names, and
    ``traced_module`` is recompiled. ``traced_module`` is modified in place.

    Names that already match are left untouched. Previously, such names were
    assigned and then immediately deleted, which removed the attribute.

    Args:
        original_module: Module whose parameter/buffer FQNs should be restored.
        traced_module: Traced graph module to rename in place.
    """
    # traced name -> original FQN
    param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
    # Don't want to change the convention of previous call.
    # original FQN -> traced name
    param_buffer_table_reverse = {v: k for k, v in param_buffer_table.items()}
    # Replace state dict attr names with the fqn
    for name, _ in list(
        chain(
            original_module.named_parameters(remove_duplicate=False),
            # pyrefly: ignore [bad-argument-type]
            original_module.named_buffers(remove_duplicate=False),
        )
    ):
        dynamo_name = param_buffer_table_reverse.get(name)
        # Skip unmapped names and names that are already correct: when
        # ``dynamo_name == name``, assigning to ``name`` and then deleting
        # ``dynamo_name`` would delete the very attribute just assigned.
        if dynamo_name is None or dynamo_name == name:
            continue
        param = torch.fx.graph_module._get_attr(traced_module, dynamo_name)
        torch.fx.graph_module._assign_attr(param, traced_module, name)
        torch.fx.graph_module._del_attr(traced_module, dynamo_name)
    # Replace graph getattr nodes with the correct name
    for node in traced_module.graph.nodes:
        if node.op == "get_attr" and node.target in param_buffer_table:
            node.target = param_buffer_table[node.target]
    traced_module.recompile()
def _get_module_hierarchy(mod: torch.nn.Module) -> dict[str, str]:
return {
name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False)
}
def _make_module_call_graph(
    in_spec: TreeSpec,
    out_spec: TreeSpec,
    module_call_signatures: dict[str, ModuleCallSignature],
    forward_arg_names: list[str] | None = None,
) -> list[ModuleCallEntry]:
    """
    Build the list of module call entries for an exported program.

    1. One entry is created per FQN of the module-level
       ``_EXPORT_MODULE_HIERARCHY``, in its iteration order. Each entry gets the
       signature recorded in ``module_call_signatures`` for that FQN, or None.
    2. The first entry must be the root (``fqn == ""``). Its signature is
       replaced by one built from ``in_spec``, ``out_spec`` and
       ``forward_arg_names``.
    3. Signatures whose FQN is not in the hierarchy are appended at the end.

    Raises:
        AssertionError: if the first hierarchy FQN is not the empty string.
    """
    entries = [
        ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn))
        for fqn in _EXPORT_MODULE_HIERARCHY  # type: ignore[union-attr]
    ]
    root_entry = entries[0]
    if root_entry.fqn != "":
        raise AssertionError(
            f"expected first fqn to be empty string, got {root_entry.fqn!r}"
        )
    root_entry.signature = ModuleCallSignature(
        inputs=[],
        outputs=[],
        in_spec=in_spec,
        out_spec=out_spec,
        forward_arg_names=forward_arg_names,
    )
    for fqn, signature in module_call_signatures.items():
        if fqn not in _EXPORT_MODULE_HIERARCHY:  # type: ignore[operator]
            entries.append(ModuleCallEntry(fqn=fqn, signature=signature))
    return entries
class _ExportModuleSpecTrackerDict(dict):
pass
def _export_to_torch_ir(
    f: Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any] | None = None,
    dynamic_shapes: _DynamicShapesSpec | None = None,
    *,
    preserve_module_call_signature: tuple[str, ...] = (),
    disable_constraint_solver: bool = False,
    prefer_deferred_runtime_asserts_over_guards: bool = False,
    restore_fqn: bool = True,
    _log_export_usage: bool = True,
    same_signature: bool = True,
) -> torch.fx.GraphModule:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside and produce a torch.fx.GraphModule in torch IR.

    Args:
        f: Module or callable to trace. A ``torch.fx.GraphModule`` is treated as
            a retrace: submodules are not wrapped, and the legacy capture path
            is preferred under the new tracer.
        args: Example positional inputs; must be a tuple.
        kwargs: Example keyword inputs (``None`` means no kwargs).
        dynamic_shapes: Dynamic shape specification, checked and converted into
            constraints.
        preserve_module_call_signature: Submodule FQNs passed to
            ``_wrap_submodules`` so their call specs are recorded.
        disable_constraint_solver: Forwarded to ``torch._dynamo.export``
            (legacy path only).
        prefer_deferred_runtime_asserts_over_guards: Set on the Dynamo config
            and forwarded to ``torch._dynamo.export``.
        restore_fqn: When ``f`` is an ``nn.Module``, rename parameters/buffers
            of the result back to ``f``'s FQNs via ``_restore_state_dict``.
        _log_export_usage: Whether to log export usage events.
        same_signature: Forwarded to ``torch._dynamo.export`` (legacy path only).

    Returns:
        The traced torch-IR graph module. Its ``meta["module_call_specs"]``
        holds the collected module call specs.

    Raises:
        UserError: ``INVALID_INPUT`` if ``args`` is not a tuple;
            ``CONSTRAINT_VIOLATION`` on ``ConstraintViolationError`` /
            ``ValueRangeError`` during tracing; ``ANTI_PATTERN`` on
            ``GuardOnDataDependentSymNode``.
    """
    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})
    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
        )
    kwargs = kwargs or {}
    # Map ints to a wrapper structure to help us mark it as dynamic, if it is
    # dynamic. We will unwrap ints in fakify later.
    args, kwargs = pytree.tree_map_only(int, _IntWrapper, (args, kwargs))
    combined_args = _combine_args(f, args, kwargs)
    _check_dynamic_shapes(combined_args, dynamic_shapes)
    constraints = _process_dynamic_shapes(combined_args, dynamic_shapes)
    # Unwrap static ints -- in the case where we have an empty graph
    # containing just integer computation, dynamo will run its generated
    # bytecode with these args/kwargs, which will error because we cannot
    # directly apply int operations on IntWrapper. So we will just unwrap
    # them here.
    args, kwargs = pytree.tree_map_only(
        _IntWrapper,
        lambda a: a.val
        if a.dynamism is None or a.dynamism.type == _DimHintType.STATIC
        else a,
        (args, kwargs),
    )
    dynamo_cfg = dataclasses.replace(
        DEFAULT_EXPORT_DYNAMO_CONFIG,
        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
    )

    def use_legacy_dynamo_graph_capture() -> bool:
        # Under the new tracer, fall back to the legacy capture entry point
        # whenever any of these features is requested.
        return bool(
            constraints  # dynamic shape
            or dynamic_shapes  # dynamic shape
            or isinstance(f, torch.fx.GraphModule)  # retracing
            or preserve_module_call_signature  # unflatten
            or torch._functorch.config.fake_tensor_propagate_real_tensors  # draft
            or torch._export.config.use_legacy_dynamo_graph_capture
        )

    with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)):
        try:
            # Filled in by _wrap_submodules while tracing.
            module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = (
                _ExportModuleSpecTrackerDict()
            )
            ctx = nullcontext()
            if not isinstance(f, torch.fx.GraphModule):
                ctx = _wrap_submodules(  # type: ignore[assignment]
                    f, preserve_module_call_signature, module_call_specs
                )
            with ctx, _ignore_backend_decomps():
                if torch._export.config.use_new_tracer_experimental:
                    from torch._dynamo.functional_export import (
                        _dynamo_graph_capture_for_export,
                        dynamo_graph_capture_for_export,
                    )

                    if use_legacy_dynamo_graph_capture():
                        dynamo_graph_capture = _dynamo_graph_capture_for_export(
                            f, constraints=constraints, dynamic_shapes=dynamic_shapes
                        )
                    else:
                        dynamo_graph_capture = torch._dynamo.config.patch(
                            replay_side_effects=False
                        )(dynamo_graph_capture_for_export(f))
                    gm_torch_level = dynamo_graph_capture(*args, **kwargs)
                    # We can't serialize entire fake mode yet, so this is to make sure
                    # things like copy.deepcopy(ep.graph_module) not crash.
                    # see test_export.py::test_custom_tag_metadata_re_export
                    # Once we delete the old strict export, we can use this fake mode in the
                    # subsequent logic when lowering to aten IR.
                    del gm_torch_level.meta["fake_mode"]
                else:
                    gm_torch_level, _ = torch._dynamo.export(
                        f,
                        dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
                        constraints=constraints,  # type: ignore[arg-type]
                        assume_static_by_default=True,
                        tracing_mode="symbolic",
                        disable_constraint_solver=disable_constraint_solver,
                        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
                        _log_export_usage=_log_export_usage,
                        same_signature=same_signature,
                    )(
                        *args,
                        **kwargs,
                    )
            gm_torch_level.meta["module_call_specs"] = module_call_specs
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: B904
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._check*(). {str(e)}",
                case_name="constrain_as_size_example",
            )
    if isinstance(f, torch.nn.Module) and restore_fqn:
        _restore_state_dict(f, gm_torch_level)
    return gm_torch_level
def _aot_export_joint_with_descriptors(
stack,
mod,
args,
*,
kwargs,
decompositions,
fake_params_buffers,
_record_nn_module_stack=True,
):
from torch._functorch._aot_autograd.graph_compile import aot_stage2_export
from torch._functorch._aot_autograd.input_output_analysis import (
create_graph_signature,
)
joint_with_descriptors = aot_export_joint_with_descriptors(
stack,
mod,
args,
kwargs=kwargs,
decompositions=decompositions,
_record_nn_module_stack=_record_nn_module_stack,
)
# Convert JointWithDescriptors to graph module and ViewAndMutationMeta
gm, fw_metadata = aot_stage2_export(
joint_with_descriptors._aot_state,
joint_with_descriptors._aot_graph_capture,
)
if not isinstance(gm, torch.fx.GraphModule):
raise AssertionError(f"expected gm to be torch.fx.GraphModule, got {type(gm)}")
# Create GraphSignature from the metadata
graph_signature = create_graph_signature(
gm,
fw_metadata,
joint_with_descriptors.in_spec,
joint_with_descriptors.out_spec,
user_args_flat=pytree.tree_leaves((args, kwargs)),
params_and_buffers_flat=list(fake_params_buffers.values()),
param_names=joint_with_descriptors.params_spec,
buffer_names=joint_with_descriptors.buffers_spec,
trace_joint=False,
num_user_fw_outs=None,
loss_index=None,