Skip to content

Commit 2ed31f2

Browse files
lucylqfacebook-github-bot
authored andcommitted
Remove deepcopy (#2502)
Summary: Add fake program to remove deepcopy. See: D54826270 Test Plan: Imported from GitHub, without a `Test Plan:` line. Peak memory usage: 38.7 GiB, down from 45GiB. https://lookaside.facebook.com/intern/diff/file/data/?number=1470390332&download=1 Differential Revision: D55047794 Pulled By: lucylq
1 parent 98f679f commit 2ed31f2

File tree

4 files changed

+65
-2
lines changed

4 files changed

+65
-2
lines changed

exir/backend/backend_api.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
LoweredBackendModule,
2929
)
3030
from executorch.exir.pass_base import ExportPass
31+
from executorch.exir.program._fake_program import get_fake_program, get_real_program
3132
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
3233
from torch.export import ExportedProgram
3334

@@ -343,8 +344,9 @@ def to_backend(
343344
Returns:
344345
ExportedProgram: The input program, with some portions targeted for delegation.
345346
"""
346-
copied_edge_program = copy.deepcopy(edge_program)
347-
partitioner_result = partitioner_instance(copied_edge_program)
347+
# Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
348+
fake_edge_program = get_fake_program(edge_program)
349+
partitioner_result = partitioner_instance(fake_edge_program)
348350
tagged_exported_program = partitioner_result.tagged_exported_program
349351

350352
# Check that the partitioner did not modify the original graph
@@ -360,6 +362,7 @@ def to_backend(
360362
partitioner_result.partition_tags is not None
361363
), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"
362364

365+
get_real_program(tagged_exported_program, edge_program)
363366
tagged_graph_module = _partition_and_lower(
364367
tagged_exported_program.graph_module, partitioner_result, edge_program
365368
)

exir/program/TARGETS

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ python_library(
88
"__init__.py",
99
],
1010
deps = [
11+
":fake_program",
1112
":program",
1213
],
1314
)
@@ -38,3 +39,13 @@ python_library(
3839
"//executorch/exir/verification:verifier",
3940
],
4041
)
42+
43+
python_library(
44+
name = "fake_program",
45+
srcs = [
46+
"_fake_program.py",
47+
],
48+
deps = [
49+
"//caffe2:torch",
50+
],
51+
)

exir/program/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
from executorch.exir.program._fake_program import get_fake_program
910
from executorch.exir.program._program import (
1011
_to_edge,
1112
edge_to_executorch_passes,
@@ -24,4 +25,6 @@
2425
"edge_to_executorch_passes",
2526
"EdgeProgramManager",
2627
"ExecutorchProgramManager",
28+
"get_fake_program",
29+
"get_real_program",
2730
]

exir/program/_fake_program.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import copy
8+
from typing import Dict, Union
9+
10+
import torch
11+
12+
from torch._guards import detect_fake_mode
13+
from torch.export import ExportedProgram
14+
15+
16+
def get_fake_program(real_exported_program: ExportedProgram) -> ExportedProgram:
17+
fake_mode = detect_fake_mode(
18+
tuple(
19+
node.meta["val"]
20+
for node in real_exported_program.graph.nodes
21+
if node.op == "placeholder"
22+
)
23+
)
24+
new_state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] = {}
25+
for key, tensor in real_exported_program.state_dict.items():
26+
fake = fake_mode.from_tensor(tensor, static_shapes=True)
27+
new_state_dict[key] = fake
28+
29+
gm = copy.deepcopy(real_exported_program.graph_module)
30+
fake_exported_program = ExportedProgram(
31+
root=gm,
32+
graph=gm.graph,
33+
graph_signature=copy.deepcopy(real_exported_program.graph_signature),
34+
state_dict=new_state_dict,
35+
range_constraints=copy.deepcopy(real_exported_program.range_constraints),
36+
module_call_graph=copy.deepcopy(real_exported_program.module_call_graph),
37+
verifier=copy.deepcopy(real_exported_program.verifier),
38+
constants=real_exported_program.constants,
39+
)
40+
return fake_exported_program
41+
42+
43+
def get_real_program(
44+
fake_exported_program: ExportedProgram, real_exported_program: ExportedProgram
45+
) -> None:
46+
fake_exported_program._state_dict = real_exported_program.state_dict

0 commit comments

Comments
 (0)