Skip to content

[ExecuTorch][Weight Sharing] Track Named Data Store in EdgeProgramManager #9151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/backend/backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def to_backend(
backend_id=backend_id,
processed_bytes=preprocess_result.processed_bytes,
compile_specs=compile_specs,
named_data_store_output=preprocess_result.data_store_output,
)
lowered_module.meta = {
"debug_handle_map": preprocess_result.debug_handle_map
Expand Down
7 changes: 7 additions & 0 deletions exir/backend/backend_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from typing import Dict, List, Optional, Tuple, Union

from executorch.exir._serialize._named_data_store import NamedDataStoreOutput

from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram

Expand All @@ -24,6 +26,11 @@ class PreprocessResult:
debug_handle_map: Optional[Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]] = (
None
)
# Data Store output created from NamedDataStore.

# Named Data store contains all the named data that is stored in the PTE file,
# but retrieveable by delegates via the NamedDataMap at runtime.
data_store_output: Optional[NamedDataStoreOutput] = None


"""
Expand Down
56 changes: 56 additions & 0 deletions exir/backend/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,62 @@ python_library(
],
)

python_library(
name = "backend_with_named_data_map",
srcs = [
"backend_with_named_data_map.py",
],
visibility = [
"//executorch/...",
"//executorch/test/...",
],
deps = [
"//caffe2:torch",
"//caffe2/functorch:functorch_src",
"//executorch/exir:delegate",
"//executorch/exir:graph_module",
"//executorch/exir:lib",
"//executorch/exir:lowered_backend_module",
"//executorch/exir:print_program",
"//executorch/exir:schema",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
"//executorch/exir/dialects:lib",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/extension/pytree:pylib",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
],
)

python_unittest(
name = "test_backend_with_named_data_map",
srcs = [
"test_backend_with_named_data_map.py",
],
visibility = [
"//executorch/...",
"//executorch/test/...",
],
deps = [
"//caffe2:torch",
"//caffe2/functorch:functorch_src",
"//executorch/exir:delegate",
"//executorch/exir:graph_module",
"//executorch/exir:lib",
"//executorch/exir:lowered_backend_module",
"//executorch/exir:print_program",
"//executorch/exir:schema",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
"//executorch/exir/dialects:lib",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/extension/pytree:pylib",
":backend_with_named_data_map",
],
)

python_library(
name = "qnn_backend_demo",
srcs = [
Expand Down
115 changes: 115 additions & 0 deletions exir/backend/test/backend_with_named_data_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, final, List, Tuple

import torch
from executorch.exir._serialize._named_data_store import NamedDataStore

from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_pattern_op_partitions,
)

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.operator_support import OperatorSupportBase


# Backend details are final (cannot be subclassed).
@final
class BackendWithNamedDataMap(BackendDetails):
"""
Test Backend for Named Data Map Functionality

This backend returns no processed_bytes, instead it uses
the named data store and serializes the name of the op
as the key and the data as its code value
"""

@staticmethod
def preprocess(
edge_program: ExportedProgram,
compile_specs: List[CompileSpec],
) -> PreprocessResult:
op_codes = {
exir_ops.edge.aten.sin.default: 0,
exir_ops.edge.aten.add.Tensor: 1,
exir_ops.edge.aten.sub.Tensor: 2,
exir_ops.edge.aten.mul.Tensor: 3,
exir_ops.edge.aten.div.Tensor: 4,
}
ndm = NamedDataStore()
for node in edge_program.graph.nodes:
if node.op == "call_function":
if node.target in op_codes.keys():
ndm.add_named_data(
node.target.__name__, bytes(op_codes[node.target])
)

return PreprocessResult(
processed_bytes=bytes(b""),
debug_handle_map={},
data_store_output=ndm.get_named_data_store_output(),
)


class SimpleOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
exir_ops.edge.aten.sin.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.div.Tensor,
]


@final
class BackendWithNDMPartitioner(Partitioner):
def __init__(self) -> None:
self._op_support = SimpleOperatorSupport()
self.backend_id = BackendWithNamedDataMap.__name__

def _partition_gm(
self, graph_module: torch.fx.GraphModule, id_start: int = 0
) -> Tuple[int, Dict[str, DelegationSpec]]:
partition_tags: Dict[str, DelegationSpec] = {}
partition_list = generate_pattern_op_partitions(
graph_module, op_support=self._op_support
)

num_partitions_in_gm = len(partition_list)
for partition in partition_list:
curr_par_id = partition.id or 0
delegation_tag = f"tag_{curr_par_id + id_start}"
for node in partition.nodes:
node.meta["delegation_tag"] = delegation_tag
delegation_spec = DelegationSpec(self.backend_id, [])
partition_tags[delegation_tag] = delegation_spec

start_idx_for_submodules = num_partitions_in_gm
for _, submodule, _ in get_control_flow_submodules(graph_module):
start_idx_for_submodules, ret_partition_tags = self._partition_gm(
submodule, start_idx_for_submodules
)
partition_tags.update(ret_partition_tags)

return start_idx_for_submodules, partition_tags

def partition(self, edge_program: ExportedProgram) -> PartitionResult:
_, partition_tags = self._partition_gm(edge_program.graph_module)
return PartitionResult(
tagged_exported_program=edge_program,
partition_tags=partition_tags,
)
83 changes: 83 additions & 0 deletions exir/backend/test/test_backend_with_named_data_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend

from executorch.exir.backend.test.backend_with_named_data_map import (
BackendWithNamedDataMap,
BackendWithNDMPartitioner,
)


class TestBackendWithNamedDataMap(unittest.TestCase):
def test_lowered_backend_module_has_output(self):
class M(torch.nn.Module):
def forward(self, x):
return x + x

ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),)))
lowered = to_backend(
BackendWithNamedDataMap.__name__, ep.exported_program(), []
)

buffer_entries = lowered.named_data_store_output.buffers
self.assertTrue(len(buffer_entries) == 1)
stored_data = lowered.named_data_store_output.pte_data

self.assertTrue("aten.add.Tensor" in stored_data)
self.assertTrue(buffer_entries[0].buffer == bytes(1))

def test_named_data_with_partitioner(self):
class M(torch.nn.Module):
def forward(self, x):
y = x + x
y = torch.cos(y)
y = y + y
y = torch.sin(y)
return y - y

ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),)))
ep.to_backend(BackendWithNDMPartitioner())

ndm_output = ep._named_data_store.get_named_data_store_output()
buffer_entries = ndm_output.buffers
stored_data = ndm_output.pte_data
self.assertEqual(len(buffer_entries), 3)
self.assertTrue("aten.add.Tensor" in stored_data)
self.assertTrue("aten.sub.Tensor" in stored_data)
self.assertTrue("aten.sin.default" in stored_data)

def test_named_data_with_control_flow(self):
class M(torch.nn.Module):
def true_branch(self, x):
y = x * x
y = torch.cos(y)
return torch.sin(y)

def false_branch(self, x):
return torch.sin(x)

def forward(self, x, y):
z = x / y
z = torch.cond(z.sum() > 0, self.true_branch, self.false_branch, [x])
return z - z

ep = to_edge(torch.export.export(M(), (torch.randn(1, 2), torch.randn(1, 2))))
ep.to_backend(BackendWithNDMPartitioner())

ndm_output = ep._named_data_store.get_named_data_store_output()
buffer_entries = ndm_output.buffers
stored_data = ndm_output.pte_data
self.assertEqual(len(buffer_entries), 4)
self.assertTrue("aten.sub.Tensor" in stored_data)
self.assertTrue("aten.div.Tensor" in stored_data)
self.assertTrue("aten.sin.default" in stored_data)
self.assertTrue("aten.mul.Tensor" in stored_data)
14 changes: 14 additions & 0 deletions exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
from executorch.exir.emit import emit_program
Expand Down Expand Up @@ -62,19 +63,24 @@ class LoweredBackendModule(torch.nn.Module):
CompileSpec
] # A list of backend-specific objects with static metadata to configure the "compilation" process.
_original_exported_program: ExportedProgram # The original EXIR module
_named_data_store_output: Optional[
NamedDataStoreOutput
] # Named Data serialized by the backend

def __init__(
self,
edge_program: ExportedProgram,
backend_id: str,
processed_bytes: bytes,
compile_specs: List[CompileSpec],
named_data_store_output: Optional[NamedDataStoreOutput] = None,
) -> None:
super().__init__()
self._original_exported_program = edge_program
self._backend_id = backend_id
self._processed_bytes = processed_bytes
self._compile_specs = compile_specs
self._named_data_store_output = named_data_store_output

# pyre-ignore
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
Expand Down Expand Up @@ -134,6 +140,13 @@ def original_module(self) -> ExportedProgram:
"""
return self._original_exported_program

@property
def named_data_store_output(self) -> Optional[NamedDataStoreOutput]:
"""
Returns the Named Data Store Output
"""
return self._named_data_store_output

# TODO(chenlai): consolidate the seriailization config with serialize_to_flatbuffer api
def buffer(
self,
Expand All @@ -154,6 +167,7 @@ def buffer(
segment_alignment=segment_alignment,
constant_tensor_alignment=constant_tensor_alignment,
delegate_alignment=delegate_alignment,
named_data=self.named_data_store_output,
)
)
return out
Expand Down
Loading
Loading