Commit f5d75ff

lucylq authored and facebook-github-bot committed
Remove deepcopy (#2502)
Summary:
Add fake program to remove deepcopy. See: D54826270

Peak memory usage: 38.7 GiB, down from 45 GiB.
https://lookaside.facebook.com/intern/diff/file/data/?number=1470390332&download=1

Pull Request resolved: #2502

Test Plan: Imported from GitHub, without a `Test Plan:` line.

```
memray run -m examples.models.llama2.export_llama -c ../llama-models/llama7b/consolidated.00.pth -p ../llama-models/llama7b/config.json -kv --use_sdpa_with_kv_cache -d fp32 --pt2e_quantize "xnnpack_dynamic"
```

```
buck2 run fbcode//executorch/exir/program/test:test_program
```

Reviewed By: cccclai

Differential Revision: D55047794

Pulled By: lucylq

fbshipit-source-id: b2834c05681f4b9d0ea30b92376678e9f3d36736
1 parent dd99909 commit f5d75ff
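Why fake tensors help here: a FakeTensor carries only metadata (shape, dtype, device) and no real storage, so building a "copy" of a multi-gigabyte state dict out of fake tensors is nearly free. The snippet below is a minimal, standalone sketch of that behavior; it is illustrative only, not part of this commit, and assumes a PyTorch build that exposes `torch._subclasses.fake_tensor.FakeTensorMode`.

```python
# Illustrative sketch only -- not from this commit.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

real = torch.randn(4096, 4096)  # ~64 MiB of float32 data lives in this tensor
fake_mode = FakeTensorMode()

# from_tensor() produces a FakeTensor with matching metadata but no real storage.
fake = fake_mode.from_tensor(real, static_shapes=True)

print(type(fake).__name__)     # FakeTensor
print(fake.shape, fake.dtype)  # torch.Size([4096, 4096]) torch.float32
```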

File tree

6 files changed: +168 −2 lines changed

exir/backend/backend_api.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -28,6 +28,10 @@
     LoweredBackendModule,
 )
 from executorch.exir.pass_base import ExportPass
+from executorch.exir.program._fake_program import (
+    get_fake_program,
+    update_to_real_program,
+)
 from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.export import ExportedProgram
 
@@ -343,8 +347,14 @@ def to_backend(
     Returns:
         ExportedProgram: The input program, with some portions targeted for delegation.
     """
-    copied_edge_program = copy.deepcopy(edge_program)
-    partitioner_result = partitioner_instance(copied_edge_program)
+    # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
+    # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
+    try:
+        fake_edge_program = get_fake_program(edge_program)
+    except AssertionError as e:
+        logging.warning(f"No fake mode found for {edge_program.graph_module}: {e}")
+        fake_edge_program = copy.deepcopy(edge_program)
+    partitioner_result = partitioner_instance(fake_edge_program)
     tagged_exported_program = partitioner_result.tagged_exported_program
 
     # Check that the partitioner did not modify the original graph
@@ -360,6 +370,7 @@ def to_backend(
         partitioner_result.partition_tags is not None
     ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"
 
+    update_to_real_program(tagged_exported_program, edge_program)
     tagged_graph_module = _partition_and_lower(
         tagged_exported_program.graph_module, partitioner_result, edge_program
     )
```

exir/program/TARGETS

Lines changed: 11 additions & 0 deletions
```diff
@@ -8,6 +8,7 @@ python_library(
         "__init__.py",
     ],
     deps = [
+        ":fake_program",
         ":program",
     ],
 )
@@ -38,3 +39,13 @@ python_library(
         "//executorch/exir/verification:verifier",
     ],
 )
+
+python_library(
+    name = "fake_program",
+    srcs = [
+        "_fake_program.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
```

exir/program/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+from executorch.exir.program._fake_program import get_fake_program
 from executorch.exir.program._program import (
     _to_edge,
     edge_to_executorch_passes,
@@ -24,4 +25,6 @@
     "edge_to_executorch_passes",
     "EdgeProgramManager",
     "ExecutorchProgramManager",
+    "get_fake_program",
+    "get_real_program",
 ]
```

exir/program/_fake_program.py

Lines changed: 64 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Dict, Union

import torch

from torch._guards import detect_fake_mode
from torch.export import ExportedProgram


def get_fake_program(real_exported_program: ExportedProgram) -> ExportedProgram:
    """Create a fake exported program. This uses fake tensors for the state dict
    to prevent mutation, and points to the real constants, to avoid large memory
    usage from copying when constants are large.

    Args:
        real_exported_program: the original exported program
    Returns:
        A new exported program, with fake tensors.
    """
    fake_mode = detect_fake_mode(
        tuple(
            node.meta["val"]
            for node in real_exported_program.graph.nodes
            if node.op == "placeholder"
        )
    )
    if fake_mode is None:
        raise AssertionError(
            "Could not detect fake mode for graph: ", real_exported_program.graph
        )

    new_state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] = {}

    for key, tensor in real_exported_program.state_dict.items():
        fake = fake_mode.from_tensor(tensor, static_shapes=True)
        new_state_dict[key] = fake

    gm = copy.deepcopy(real_exported_program.graph_module)
    fake_exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=copy.deepcopy(real_exported_program.graph_signature),
        state_dict=new_state_dict,
        range_constraints=copy.deepcopy(real_exported_program.range_constraints),
        module_call_graph=copy.deepcopy(real_exported_program.module_call_graph),
        verifier=real_exported_program.verifier,
        constants=real_exported_program.constants,
    )
    return fake_exported_program


def update_to_real_program(
    fake_exported_program: ExportedProgram, real_exported_program: ExportedProgram
) -> None:
    """Update the fake exported program to point to the real state dict. Modifies the
    fake exported program in-place.
    """
    fake_exported_program._state_dict = real_exported_program.state_dict
```
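As a reading aid, here is a hypothetical sketch of how these two helpers are meant to be combined, mirroring the `to_backend` change in exir/backend/backend_api.py above. The names `partition_without_deepcopy` and `partitioner_instance` are illustrative, not part of this commit.

```python
# Hypothetical usage sketch; mirrors the pattern in exir/backend/backend_api.py.
import copy
import logging

from executorch.exir.program._fake_program import (
    get_fake_program,
    update_to_real_program,
)


def partition_without_deepcopy(edge_program, partitioner_instance):
    # Swap the state dict for FakeTensors so the partitioner cannot mutate
    # real weights and no large constants are copied.
    try:
        fake_edge_program = get_fake_program(edge_program)
    except AssertionError as e:
        # No fake mode detected on the graph; fall back to the old deepcopy path.
        logging.warning("No fake mode found, falling back to deepcopy: %s", e)
        fake_edge_program = copy.deepcopy(edge_program)

    partitioner_result = partitioner_instance(fake_edge_program)

    # Before lowering, point the tagged (fake) program back at the real weights.
    update_to_real_program(
        partitioner_result.tagged_exported_program, edge_program
    )
    return partitioner_result
```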

exir/program/test/TARGETS

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,6 +6,7 @@ python_unittest(
     # @autodeps-skip pybindings don't work well with autodeps
     name = "test_program",
     srcs = [
+        "test_fake_program.py",
         "test_program.py",
     ],
     deps = [
```

exir/program/test/test_fake_program.py

Lines changed: 76 additions & 0 deletions
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import sys
import unittest

import torch

from executorch.exir.program._fake_program import (
    get_fake_program,
    update_to_real_program,
)
from torch.export import export, ExportedProgram


def get_exported_program() -> ExportedProgram:
    class Linear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)
            self.register_buffer("buf", torch.randn(10, 10), persistent=False)

        def forward(self, arg) -> torch.Tensor:
            return self.linear(arg) + self.buf

    linear = Linear()
    exported_program = export(
        linear,
        args=(torch.randn(10, 10),),
    ).run_decompositions()
    return exported_program


class TestFakeProgram(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()

    def test_fake_program(self) -> None:
        exported_program = get_exported_program()
        fake_program = get_fake_program(exported_program)
        print(f"Exported program size: {sys.getsizeof(exported_program.state_dict)}")
        print(f"Fake program size: {sys.getsizeof(fake_program.state_dict)}")

        # Fake program deep copies attributes besides verifier, state_dict and constants.
        self.assertEqual(exported_program.graph_signature, fake_program.graph_signature)
        self.assertNotEqual(
            id(exported_program.graph_signature), id(fake_program.graph_signature)
        )
        self.assertEqual(
            exported_program.module_call_graph, fake_program.module_call_graph
        )
        self.assertNotEqual(
            id(exported_program.module_call_graph), id(fake_program.module_call_graph)
        )

        # Verifier is static.
        self.assertEqual(exported_program.verifier, fake_program.verifier)
        self.assertEqual(id(exported_program.verifier), id(fake_program.verifier))

        # Fake program uses fake tensors for the state dict. Size should be smaller.
        self.assertLess(
            sys.getsizeof(fake_program.state_dict),
            sys.getsizeof(exported_program.state_dict),
        )

        # Do not copy constants.
        self.assertEqual(exported_program.constants, fake_program.constants)
        self.assertEqual(id(exported_program.constants), id(fake_program.constants))

        update_to_real_program(fake_program, exported_program)
        self.assertEqual(exported_program.state_dict, fake_program.state_dict)
        self.assertEqual(id(exported_program.state_dict), id(fake_program.state_dict))
```
