Skip to content

Commit a33053f

Browse files
pytorchbotmcr229
authored and committed
[ExecuTorch][Weight Sharing] Track Named Data Store in EdgeProgramManager (pytorch#9293)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: pytorch#9151 by @mcr229 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/mcr229/7/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/7/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/mcr229/7/orig @diff-train-skip-merge Co-authored-by: Max Ren <maxren@meta.com>
1 parent 8a5a62b commit a33053f

7 files changed

+301
-2
lines changed

exir/backend/backend_api.py

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def to_backend(
120120
backend_id=backend_id,
121121
processed_bytes=preprocess_result.processed_bytes,
122122
compile_specs=compile_specs,
123+
named_data_store_output=preprocess_result.data_store_output,
123124
)
124125
lowered_module.meta = {
125126
"debug_handle_map": preprocess_result.debug_handle_map

exir/backend/backend_details.py

+7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from typing import Dict, List, Optional, Tuple, Union
1111

12+
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
13+
1214
from executorch.exir.backend.compile_spec_schema import CompileSpec
1315
from torch.export.exported_program import ExportedProgram
1416

@@ -24,6 +26,11 @@ class PreprocessResult:
2426
debug_handle_map: Optional[Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]] = (
2527
None
2628
)
29+
# Data Store output created from NamedDataStore.
30+
31+
# Named Data store contains all the named data that is stored in the PTE file,
32+
# but retrievable by delegates via the NamedDataMap at runtime.
33+
data_store_output: Optional[NamedDataStoreOutput] = None
2734

2835

2936
"""

exir/backend/test/TARGETS

+56
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,62 @@ python_library(
3838
],
3939
)
4040

41+
# Library target wrapping backend_with_named_data_map.py so that other
# targets (such as the unit test in this file) can depend on it.
python_library(
    name = "backend_with_named_data_map",
    srcs = ["backend_with_named_data_map.py"],
    visibility = [
        "//executorch/...",
        "//executorch/test/...",
    ],
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch_src",
        "//executorch/exir:delegate",
        "//executorch/exir:graph_module",
        "//executorch/exir:lib",
        "//executorch/exir:lowered_backend_module",
        "//executorch/exir:print_program",
        "//executorch/exir:schema",
        "//executorch/exir/backend:backend_api",
        "//executorch/exir/backend:compile_spec_schema",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/dialects:lib",
        "//executorch/extension/pybindings:portable_lib",  # @manual
        "//executorch/extension/pytree:pylib",
        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
    ],
)
68+
69+
# Unit-test target for test_backend_with_named_data_map.py; it pulls in the
# sibling ":backend_with_named_data_map" library defined in this file.
python_unittest(
    name = "test_backend_with_named_data_map",
    srcs = ["test_backend_with_named_data_map.py"],
    visibility = [
        "//executorch/...",
        "//executorch/test/...",
    ],
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch_src",
        "//executorch/exir:delegate",
        "//executorch/exir:graph_module",
        "//executorch/exir:lib",
        "//executorch/exir:lowered_backend_module",
        "//executorch/exir:print_program",
        "//executorch/exir:schema",
        "//executorch/exir/backend:backend_api",
        "//executorch/exir/backend:compile_spec_schema",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/dialects:lib",
        "//executorch/extension/pybindings:portable_lib",  # @manual
        "//executorch/extension/pytree:pylib",
        ":backend_with_named_data_map",
    ],
)
96+
4197
python_library(
4298
name = "qnn_backend_demo",
4399
srcs = [
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict, final, List, Tuple
8+
9+
import torch
10+
from executorch.exir._serialize._named_data_store import NamedDataStore
11+
12+
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
13+
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
14+
generate_pattern_op_partitions,
15+
)
16+
17+
from executorch.exir.backend.compile_spec_schema import CompileSpec
18+
from executorch.exir.backend.partitioner import (
19+
DelegationSpec,
20+
Partitioner,
21+
PartitionResult,
22+
)
23+
from executorch.exir.dialects._ops import ops as exir_ops
24+
from executorch.exir.graph_module import get_control_flow_submodules
25+
from torch.export.exported_program import ExportedProgram
26+
from torch.fx.passes.operator_support import OperatorSupportBase
27+
28+
29+
# Backend details are final (cannot be subclassed).
30+
@final
class BackendWithNamedDataMap(BackendDetails):
    """
    Test backend for Named Data Map functionality.

    ``preprocess`` emits empty ``processed_bytes``. Instead, every supported
    ``call_function`` node contributes one entry to a ``NamedDataStore``:
    the key is the op's ``__name__`` and the payload is ``bytes(code)``,
    where ``code`` is the integer assigned to that op below.

    NOTE: ``bytes(n)`` with an integer argument produces ``n`` zero bytes,
    so the payload encodes the op code by its *length* (``sin`` -> ``b""``,
    ``add`` -> ``b"\\x00"``, ...), not as a single byte holding the value.
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Record the supported ops of ``edge_program`` in a named data store.

        Args:
            edge_program: Program whose graph nodes are scanned.
            compile_specs: Unused by this backend.

        Returns:
            A ``PreprocessResult`` with empty ``processed_bytes``, an empty
            ``debug_handle_map`` and the collected named data store output.
        """
        code_for_op = {
            exir_ops.edge.aten.sin.default: 0,
            exir_ops.edge.aten.add.Tensor: 1,
            exir_ops.edge.aten.sub.Tensor: 2,
            exir_ops.edge.aten.mul.Tensor: 3,
            exir_ops.edge.aten.div.Tensor: 4,
        }

        store = NamedDataStore()
        for node in edge_program.graph.nodes:
            # Only function calls to one of the known ops are recorded.
            if node.op != "call_function" or node.target not in code_for_op:
                continue
            # bytes(int) -> that many zero bytes (see class docstring).
            payload = bytes(code_for_op[node.target])
            store.add_named_data(node.target.__name__, payload)

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=store.get_named_data_store_output(),
        )
65+
66+
67+
class SimpleOperatorSupport(OperatorSupportBase):
    """Operator support check accepting only sin/add/sub/mul/div edge ops."""

    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        """
        Return True when ``node`` is a ``call_function`` node whose target is
        one of the edge ops listed below; ``submodules`` is not consulted.
        """
        if node.op != "call_function":
            return False

        delegatable_ops = [
            exir_ops.edge.aten.sin.default,
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.div.Tensor,
        ]
        return node.target in delegatable_ops
76+
77+
78+
@final
class BackendWithNDMPartitioner(Partitioner):
    """
    Partitioner that tags nodes accepted by ``SimpleOperatorSupport`` for
    delegation to ``BackendWithNamedDataMap``, recursing into control-flow
    submodules.
    """

    def __init__(self) -> None:
        self._op_support = SimpleOperatorSupport()
        self.backend_id = BackendWithNamedDataMap.__name__

    def _partition_gm(
        self, graph_module: torch.fx.GraphModule, id_start: int = 0
    ) -> Tuple[int, Dict[str, DelegationSpec]]:
        """
        Tag the partitions of ``graph_module`` and of its control-flow
        submodules.

        Args:
            graph_module: Graph module to partition; node ``meta`` is mutated
                in place with a ``"delegation_tag"`` entry.
            id_start: Offset added to each partition id when building the
                ``tag_<n>`` name.

        Returns:
            A tuple of (next offset to use after this module and all of its
            submodules, mapping from delegation tag to ``DelegationSpec``).
        """
        partition_tags: Dict[str, DelegationSpec] = {}
        partition_list = generate_pattern_op_partitions(
            graph_module, op_support=self._op_support
        )

        for partition in partition_list:
            curr_par_id = partition.id or 0
            delegation_tag = f"tag_{curr_par_id + id_start}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
            partition_tags[delegation_tag] = DelegationSpec(self.backend_id, [])

        # The next offset must include id_start. Previously it was reset to
        # len(partition_list), so a recursive call returned too small a value
        # and sibling/nested submodules could be handed already-used tags.
        # NOTE(review): this assumes partition ids lie in
        # [0, len(partition_list)) -- confirm against the partitioner helper.
        next_free_id = id_start + len(partition_list)
        for _, submodule, _ in get_control_flow_submodules(graph_module):
            next_free_id, submodule_tags = self._partition_gm(
                submodule, next_free_id
            )
            partition_tags.update(submodule_tags)

        return next_free_id, partition_tags

    def partition(self, edge_program: ExportedProgram) -> PartitionResult:
        """
        Tag ``edge_program`` (in place) and return it together with the
        collected tag -> ``DelegationSpec`` mapping.
        """
        _, partition_tags = self._partition_gm(edge_program.graph_module)
        return PartitionResult(
            tagged_exported_program=edge_program,
            partition_tags=partition_tags,
        )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from executorch.exir import to_edge
12+
from executorch.exir.backend.backend_api import to_backend
13+
14+
from executorch.exir.backend.test.backend_with_named_data_map import (
15+
BackendWithNamedDataMap,
16+
BackendWithNDMPartitioner,
17+
)
18+
19+
20+
class TestBackendWithNamedDataMap(unittest.TestCase):
    """Checks that named data emitted by a backend survives lowering."""

    def test_lowered_backend_module_has_output(self):
        """Direct ``to_backend`` call: output is attached to the lowered module."""

        class M(torch.nn.Module):
            def forward(self, x):
                return x + x

        ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),)))
        lowered = to_backend(
            BackendWithNamedDataMap.__name__, ep.exported_program(), []
        )

        output = lowered.named_data_store_output
        # The property is Optional; fail with a clear message if it is unset.
        self.assertIsNotNone(output)
        buffer_entries = output.buffers
        stored_data = output.pte_data

        self.assertEqual(len(buffer_entries), 1)
        self.assertIn("aten.add.Tensor", stored_data)
        # The backend stores bytes(op_code); add has code 1, and bytes(1) is
        # a single zero byte.
        self.assertEqual(buffer_entries[0].buffer, bytes(1))

    def test_named_data_with_partitioner(self):
        """Partitioned lowering: data from every partition is collected."""

        class M(torch.nn.Module):
            def forward(self, x):
                y = x + x
                y = torch.cos(y)
                y = y + y
                y = torch.sin(y)
                return y - y

        ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),)))
        ep.to_backend(BackendWithNDMPartitioner())

        ndm_output = ep._named_data_store.get_named_data_store_output()
        buffer_entries = ndm_output.buffers
        stored_data = ndm_output.pte_data

        self.assertEqual(len(buffer_entries), 3)
        for key in ("aten.add.Tensor", "aten.sub.Tensor", "aten.sin.default"):
            self.assertIn(key, stored_data)

    def test_named_data_with_control_flow(self):
        """Control flow: data from ``torch.cond`` branches is collected too."""

        class M(torch.nn.Module):
            def true_branch(self, x):
                y = x * x
                y = torch.cos(y)
                return torch.sin(y)

            def false_branch(self, x):
                return torch.sin(x)

            def forward(self, x, y):
                z = x / y
                z = torch.cond(z.sum() > 0, self.true_branch, self.false_branch, [x])
                return z - z

        ep = to_edge(torch.export.export(M(), (torch.randn(1, 2), torch.randn(1, 2))))
        ep.to_backend(BackendWithNDMPartitioner())

        ndm_output = ep._named_data_store.get_named_data_store_output()
        buffer_entries = ndm_output.buffers
        stored_data = ndm_output.pte_data

        self.assertEqual(len(buffer_entries), 4)
        for key in (
            "aten.sub.Tensor",
            "aten.div.Tensor",
            "aten.sin.default",
            "aten.mul.Tensor",
        ):
            self.assertIn(key, stored_data)

exir/lowered_backend_module.py

+14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
import torch.utils._pytree as pytree
1616
from executorch.exir._serialize import _serialize_pte_binary
17+
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
1718
from executorch.exir.backend.compile_spec_schema import CompileSpec
1819
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
1920
from executorch.exir.emit import emit_program
@@ -62,19 +63,24 @@ class LoweredBackendModule(torch.nn.Module):
6263
CompileSpec
6364
] # A list of backend-specific objects with static metadata to configure the "compilation" process.
6465
_original_exported_program: ExportedProgram # The original EXIR module
66+
_named_data_store_output: Optional[
67+
NamedDataStoreOutput
68+
] # Named Data serialized by the backend
6569

6670
def __init__(
    self,
    edge_program: ExportedProgram,
    backend_id: str,
    processed_bytes: bytes,
    compile_specs: List[CompileSpec],
    named_data_store_output: Optional[NamedDataStoreOutput] = None,
) -> None:
    """
    Store the lowering artifacts on the module.

    Args:
        edge_program: Kept as the original exported program.
        backend_id: Identifier of the backend that produced the payload.
        processed_bytes: Payload returned by the backend.
        compile_specs: Compile specs used for the lowering.
        named_data_store_output: Optional named data serialized by the
            backend; defaults to ``None``.
    """
    super().__init__()
    self._named_data_store_output = named_data_store_output
    self._compile_specs = compile_specs
    self._processed_bytes = processed_bytes
    self._backend_id = backend_id
    self._original_exported_program = edge_program
7884

7985
# pyre-ignore
8086
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
@@ -134,6 +140,13 @@ def original_module(self) -> ExportedProgram:
134140
"""
135141
return self._original_exported_program
136142

143+
@property
def named_data_store_output(self) -> Optional[NamedDataStoreOutput]:
    """
    Named data store output held by this module, or ``None`` if unset.
    """
    return self._named_data_store_output
149+
137150
# TODO(chenlai): consolidate the serialization config with serialize_to_flatbuffer api
138151
def buffer(
139152
self,
@@ -154,6 +167,7 @@ def buffer(
154167
segment_alignment=segment_alignment,
155168
constant_tensor_alignment=constant_tensor_alignment,
156169
delegate_alignment=delegate_alignment,
170+
named_data=self.named_data_store_output,
157171
)
158172
)
159173
return out

0 commit comments

Comments
 (0)