Skip to content

[executorch][aot] Remove deepcopy #2502

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions exir/backend/backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
LoweredBackendModule,
)
from executorch.exir.pass_base import ExportPass
from executorch.exir.program._fake_program import (
get_fake_program,
update_to_real_program,
)
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import ExportedProgram

Expand Down Expand Up @@ -343,8 +347,14 @@ def to_backend(
Returns:
ExportedProgram: The input program, with some portions targeted for delegation.
"""
copied_edge_program = copy.deepcopy(edge_program)
partitioner_result = partitioner_instance(copied_edge_program)
# Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
# Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
try:
fake_edge_program = get_fake_program(edge_program)
except AssertionError as e:
logging.warning(f"No fake mode found for {edge_program.graph_module}: {e}")
fake_edge_program = copy.deepcopy(edge_program)
partitioner_result = partitioner_instance(fake_edge_program)
tagged_exported_program = partitioner_result.tagged_exported_program

# Check that the partitioner did not modify the original graph
Expand All @@ -360,6 +370,7 @@ def to_backend(
partitioner_result.partition_tags is not None
), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

update_to_real_program(tagged_exported_program, edge_program)
tagged_graph_module = _partition_and_lower(
tagged_exported_program.graph_module, partitioner_result, edge_program
)
Expand Down
11 changes: 11 additions & 0 deletions exir/program/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ python_library(
"__init__.py",
],
deps = [
":fake_program",
":program",
],
)
Expand Down Expand Up @@ -38,3 +39,13 @@ python_library(
"//executorch/exir/verification:verifier",
],
)

# Library target for the fake-program helpers in _fake_program.py.
# Its only declared dependency is torch.
python_library(
name = "fake_program",
srcs = [
"_fake_program.py",
],
deps = [
"//caffe2:torch",
],
)
3 changes: 3 additions & 0 deletions exir/program/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

from executorch.exir.program._fake_program import get_fake_program
from executorch.exir.program._program import (
_to_edge,
edge_to_executorch_passes,
Expand All @@ -24,4 +25,6 @@
"edge_to_executorch_passes",
"EdgeProgramManager",
"ExecutorchProgramManager",
"get_fake_program",
"get_real_program",
]
64 changes: 64 additions & 0 deletions exir/program/_fake_program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
from typing import Dict, Union

import torch

from torch._guards import detect_fake_mode
from torch.export import ExportedProgram


def get_fake_program(real_exported_program: ExportedProgram) -> ExportedProgram:
"""Create a fake exported program. This uses fake tensors for the state dict
to prevent mutation, and points to the real constants, to avoid large memory
usage from copying when constants are large.

Args:
real_exported_program: the original exported program
Returns:
A new exported program, with fake tensors.
"""
fake_mode = detect_fake_mode(
tuple(
node.meta["val"]
for node in real_exported_program.graph.nodes
if node.op == "placeholder"
)
)
if fake_mode is None:
raise AssertionError(
"Could not detect fake mode for graph: ", real_exported_program.graph
)

new_state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]] = {}

for key, tensor in real_exported_program.state_dict.items():
fake = fake_mode.from_tensor(tensor, static_shapes=True)
new_state_dict[key] = fake

gm = copy.deepcopy(real_exported_program.graph_module)
fake_exported_program = ExportedProgram(
root=gm,
graph=gm.graph,
graph_signature=copy.deepcopy(real_exported_program.graph_signature),
state_dict=new_state_dict,
range_constraints=copy.deepcopy(real_exported_program.range_constraints),
module_call_graph=copy.deepcopy(real_exported_program.module_call_graph),
verifier=real_exported_program.verifier,
constants=real_exported_program.constants,
)
return fake_exported_program


def update_to_real_program(
fake_exported_program: ExportedProgram, real_exported_program: ExportedProgram
) -> None:
"""Update the fake exported program to point to the real state dict. Modifies the
fake exported program in-place.
"""
fake_exported_program._state_dict = real_exported_program.state_dict
1 change: 1 addition & 0 deletions exir/program/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ python_unittest(
# @autodeps-skip pybindings don't work well with autodeps
name = "test_program",
srcs = [
"test_fake_program.py",
"test_program.py",
],
deps = [
Expand Down
76 changes: 76 additions & 0 deletions exir/program/test/test_fake_program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import sys
import unittest

import torch

from executorch.exir.program._fake_program import (
get_fake_program,
update_to_real_program,
)
from torch.export import export, ExportedProgram


def get_exported_program() -> ExportedProgram:
class Linear(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf", torch.randn(10, 10), persistent=False)

def forward(self, arg) -> torch.Tensor:
return self.linear(arg) + self.buf

linear = Linear()
exported_program = export(
linear,
args=(torch.randn(10, 10),),
).run_decompositions()
return exported_program


class TestFakeProgram(unittest.TestCase):
    """Checks what get_fake_program copies, what it shares with the original
    program, and that update_to_real_program restores the real state dict."""

    def test_fake_program(self) -> None:
        # Imported locally so this test does not depend on a file-level import.
        from torch._subclasses.fake_tensor import FakeTensor

        exported_program = get_exported_program()
        fake_program = get_fake_program(exported_program)

        # Fake program deep copies attributes besides verifier, state_dict and constants.
        self.assertEqual(exported_program.graph_signature, fake_program.graph_signature)
        self.assertIsNot(
            exported_program.graph_signature, fake_program.graph_signature
        )
        self.assertEqual(
            exported_program.module_call_graph, fake_program.module_call_graph
        )
        self.assertIsNot(
            exported_program.module_call_graph, fake_program.module_call_graph
        )

        # Verifier is static.
        self.assertEqual(exported_program.verifier, fake_program.verifier)
        self.assertIs(exported_program.verifier, fake_program.verifier)

        # Fake program uses fake tensors for the state dict. sys.getsizeof on a
        # dict is shallow and says nothing about tensor storage, so check the
        # entries directly: same keys, and every value is a FakeTensor.
        self.assertEqual(
            set(exported_program.state_dict), set(fake_program.state_dict)
        )
        for name, tensor in fake_program.state_dict.items():
            self.assertIsInstance(tensor, FakeTensor, name)

        # Do not copy constants.
        self.assertEqual(exported_program.constants, fake_program.constants)
        self.assertIs(exported_program.constants, fake_program.constants)

        # After the update the fake program holds the very same state dict object.
        update_to_real_program(fake_program, exported_program)
        self.assertIs(exported_program.state_dict, fake_program.state_dict)