Refactor delegation code
Differential Revision: D60813405

Pull Request resolved: pytorch#4566
angelayi authored Aug 15, 2024
1 parent ae299cf commit 3e4508a
Showing 4 changed files with 252 additions and 185 deletions.
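Based on the diff below, the refactor threads an is_submodule flag through _partition_and_lower so control-flow submodules are handled differently from the top-level program, wraps submodule creation in the owning program's replace hook (or a nullcontext when lowering a submodule), and copies the submodule output node's meta["val"] onto the call_delegate node. It also rebuilds the returned ExportedProgram from the tagged program instead of recomputing a signature with _get_new_signature, and hands cleanup of consumed top-level inputs/outputs to _unsafe_adjust_original_program instead of popping parameters and buffers by hand. A minimal, standard-library-only sketch of the conditional context-manager pattern introduced here (the names replace_hook and mutate_graph are illustrative, not from this repository):

from contextlib import contextmanager, nullcontext

@contextmanager
def replace_hook():
    # Stand-in for graph_signature.get_replace_hook(): bookkeeping that should
    # run around a graph mutation, but only for the top-level program.
    print("hook installed")
    try:
        yield
    finally:
        print("hook removed")

def mutate_graph(is_submodule: bool) -> None:
    # Nested (control-flow) submodules get a no-op context instead of the hook.
    ctx = nullcontext() if is_submodule else replace_hook()
    with ctx:
        print("creating submodule from tagged nodes")

mutate_graph(is_submodule=False)  # hook messages wrap the mutation
mutate_graph(is_submodule=True)   # only the mutation message is printed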
119 changes: 60 additions & 59 deletions exir/backend/backend_api.py
@@ -6,7 +6,7 @@
 
 import copy
 import logging
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import singledispatch
 from typing import Generator, List
 
@@ -25,12 +25,11 @@
 
 from executorch.exir.graph_module import get_control_flow_submodules
 from executorch.exir.lowered_backend_module import (
-    _get_new_signature,
+    _unsafe_adjust_original_program,
     create_exported_program_from_submodule,
     create_submodule_from_nodes,
     LoweredBackendModule,
 )
-from executorch.exir.pass_base import ExportPass
 from executorch.exir.program._fake_program import (
     get_fake_program,
     update_to_real_program,
@@ -193,6 +192,7 @@ def _partition_and_lower_one_graph_module(
     tagged_graph_module: torch.fx.GraphModule,
     partition_result: PartitionResult,
     owning_program: ExportedProgram,
+    is_submodule: bool,
 ) -> torch.fx.GraphModule:
     """
     Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.
@@ -210,21 +210,40 @@
 
         logging.debug(f"For tag {tag}, found nodes {node_list}")
         # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)
-        submodule, call_module_node = create_submodule_from_nodes(
-            tagged_graph_module, node_list, tag
-        )
+
+        replace_ctx = (
+            tagged_graph_module._set_replace_hook(
+                owning_program.graph_signature.get_replace_hook()
+            )
+            if not is_submodule
+            else nullcontext()
+        )
+        with replace_ctx:
+            submodule, call_module_node = create_submodule_from_nodes(
+                tagged_graph_module, node_list, tag
+            )
+
         tagged_graph_module_output_node = [
             node for node in tagged_graph_module.graph.nodes if node.op == "output"
-        ]
+        ][0]
         submodule_output_node = [
             node for node in submodule.graph.nodes if node.op == "output"
-        ]
-        # Copy the output node meta from the original output node, because create_submodule_from_nodes doesn't cover the meta field
-        submodule_output_node[0].meta = tagged_graph_module_output_node[0].meta
+        ][0]
+        # Copy the output node meta from the original output node, because
+        # create_submodule_from_nodes doesn't cover the meta field
+        submodule_output_node.meta = tagged_graph_module_output_node.meta
         logging.debug(f"Partitioned graph module: {tagged_graph_module}")
 
-        submodule_program = create_exported_program_from_submodule(
-            submodule, owning_program, tag
+        (
+            submodule_program,
+            toplevel_input_specs_to_delete,
+            toplevel_output_specs_to_delete,
+        ) = create_exported_program_from_submodule(
+            submodule,
+            owning_program,
+            tag,
+            call_module_node,
+            is_submodule,
         )
 
         lowered_submodule = to_backend(
@@ -257,64 +276,48 @@ def _partition_and_lower_one_graph_module(
             call_delegate_node.meta["debug_handle"] = len(
                 tagged_graph_module.graph.nodes
             )
+            call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
             call_module_node.replace_all_uses_with(call_delegate_node)
             tagged_graph_module.graph.erase_node(call_module_node)
 
-        # Delete all parameters/buffers consumed by the created exported program
-        toplevel_signature = owning_program.graph_signature
-        for node in tagged_graph_module.graph.nodes:
-            # Find placeholders consumed by the delegate
-            if node.op != "placeholder" or len(node.users) != 0:
-                continue
-
-            if node.name in toplevel_signature.inputs_to_buffers:
-                # Delete the consumed buffers
-                buffer_name = toplevel_signature.inputs_to_buffers.get(node.name)
-                if buffer_name in owning_program.state_dict:
-                    owning_program.state_dict.pop(buffer_name)
-                else:
-                    owning_program.constants.pop(buffer_name)
-                tagged_graph_module.graph.erase_node(node)
-            elif node.name in toplevel_signature.inputs_to_parameters:
-                # Delete the consumed parameters
-                param_name = toplevel_signature.inputs_to_parameters.get(node.name)
-                owning_program.state_dict.pop(param_name)
-                tagged_graph_module.graph.erase_node(node)
-
-        tagged_graph_module.recompile()
+        if is_submodule:
+            assert len(toplevel_input_specs_to_delete) == 0
+            assert len(toplevel_output_specs_to_delete) == 0
+        elif (
+            len(toplevel_input_specs_to_delete) > 0
+            or len(toplevel_output_specs_to_delete) > 0
+        ):
+            _unsafe_adjust_original_program(
+                owning_program,
+                call_delegate_node,
+                toplevel_input_specs_to_delete,
+                toplevel_output_specs_to_delete,
+            )
 
     return tagged_graph_module
 
 
 def _partition_and_lower(
     tagged_graph_module: torch.fx.GraphModule,
     partition_result: PartitionResult,
     owning_program: ExportedProgram,
+    is_submodule: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Partitions the graph module into submodules based on tags, and then lowered the nodes with the same tag as one lowered module, including the submodule from control flow
     """
 
     partitioned_module = _partition_and_lower_one_graph_module(
-        tagged_graph_module, partition_result, owning_program
+        tagged_graph_module, partition_result, owning_program, is_submodule
     )
 
     # Recursively partition and lower for submodules
     for name, submod, _node in get_control_flow_submodules(partitioned_module):
         partitioned_submodule = _partition_and_lower(
-            submod, partition_result, owning_program
+            submod, partition_result, owning_program, is_submodule=True
         )
         tagged_graph_module.add_module(name, partitioned_submodule)
 
-    # Run the export pass over the graph module so that the call delegate
-    # nodes will match Edge dialect
-    # TODO(angelayi): ExportPass will rerun the graph, however all we need
-    # here is to add metadata to the call delegate nodes to preserve Edge
-    # dialect. There's work going on to generate a random tensor from a
-    # fake tensor and possibly it can help to address the issue.
-    res = ExportPass()(tagged_graph_module)
-    assert res is not None
-    tagged_graph_module = res.graph_module
-
     return tagged_graph_module
 
 
@@ -349,6 +352,8 @@ def to_backend(
     Returns:
         ExportedProgram: The input program, with some portions targeted for delegation.
     """
+    edge_program._validate()
+
     # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
     # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
     try:
@@ -377,26 +382,22 @@
     update_to_real_program(tagged_exported_program, edge_program)
 
     for tag, _ in partitioner_result.partition_tags.items():
-        _maybe_duplicate_constant_nodes(tagged_exported_program, tag, edge_program)
+        _maybe_duplicate_constant_nodes(tagged_exported_program, tag)
 
     tagged_graph_module = _partition_and_lower(
-        tagged_exported_program.graph_module, partitioner_result, edge_program
+        tagged_exported_program.graph_module,
+        partitioner_result,
+        tagged_exported_program,
     )
 
-    # TODO(angelayi): Update this signature in a less manual way (maybe through
-    # retracing)
-    new_signature, new_state_dict, new_constants = _get_new_signature(
-        edge_program,
-        tagged_graph_module,
-    )
     return ExportedProgram(
         root=tagged_graph_module,
         graph=tagged_graph_module.graph,
-        graph_signature=new_signature,
-        state_dict=new_state_dict,
-        range_constraints=copy.deepcopy(edge_program.range_constraints),
-        module_call_graph=copy.deepcopy(edge_program.module_call_graph),
+        graph_signature=tagged_exported_program.graph_signature,
+        state_dict=tagged_exported_program.state_dict,
+        range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
+        module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
         example_inputs=None,
-        constants=new_constants,
-        verifiers=[edge_program.verifier],
+        constants=tagged_exported_program.constants,
+        verifiers=[tagged_exported_program.verifier],
     )
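Not ExecuTorch code, but a small torch.export illustration of the two output-node changes above: the output node is now taken directly (the [0] indexing moved into the list comprehension), and its meta["val"] is copied onto the call_delegate node so the delegate call carries the same fake-tensor metadata. The Add module is a made-up example:

import torch

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y, x - y

ep = torch.export.export(Add(), (torch.randn(2), torch.randn(2)))
output_node = [n for n in ep.graph_module.graph.nodes if n.op == "output"][0]
# On an exported program the output node's meta["val"] typically holds the
# fake tensors describing the program outputs.
print(output_node.meta.get("val"))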
24 changes: 1 addition & 23 deletions exir/backend/test/test_backends.py
@@ -35,10 +35,7 @@
 from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
-from executorch.exir.lowered_backend_module import (
-    _get_new_signature,
-    get_lowered_submodules,
-)
+from executorch.exir.lowered_backend_module import get_lowered_submodules
 from executorch.exir.print_program import print_program
 from executorch.exir.schema import (
     BackendDelegate,
@@ -63,7 +60,6 @@
     prepare_fx,
 )
 from torch.export import ExportedProgram
-from torch.export.exported_program import OutputKind, TensorArgument
 from torch.testing import FileCheck
 
 
@@ -1270,21 +1266,3 @@ def forward(self, x: List[torch.Tensor]):
 
         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
-
-    def test_get_new_signature(self):
-        class MyModule(torch.nn.Module):
-            def forward(self, x, y, z):
-                return x + y, y - z, z * x
-
-        ep = torch.export.export(
-            MyModule(), (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2))
-        )
-        sig, *_ = _get_new_signature(ep, ep.graph_module)
-        output_names = set()
-        self.assertEqual(len(sig.output_specs), 3)
-        for s in sig.output_specs:
-            self.assertEqual(s.kind, OutputKind.USER_OUTPUT)
-            self.assertIsInstance(s.arg, TensorArgument)
-            name = s.arg.name
-            self.assertNotIn(name, output_names)
-            output_names.add(name)
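The deleted test_get_new_signature exercised _get_new_signature directly; after this refactor the top-level signature is taken from tagged_exported_program, so the same user-output properties can be checked on any exported program's graph_signature. An illustrative, self-contained variant of the deleted check (it reuses the test's MyModule but does not call _get_new_signature):

import torch
from torch.export.exported_program import OutputKind, TensorArgument

class MyModule(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y, y - z, z * x

ep = torch.export.export(
    MyModule(), (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2))
)
assert len(ep.graph_signature.output_specs) == 3
for spec in ep.graph_signature.output_specs:
    assert spec.kind == OutputKind.USER_OUTPUT
    assert isinstance(spec.arg, TensorArgument)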
2 changes: 0 additions & 2 deletions exir/backend/utils.py
@@ -208,7 +208,6 @@ def _assign_new_tag(
 def _maybe_duplicate_constant_nodes(
     tagged_exported_program: ExportedProgram,
     tag: str,
-    owning_program: ExportedProgram,
 ) -> None:
     """
     If the constants node is shared by different tagged nodes, like
@@ -241,7 +240,6 @@ def _maybe_duplicate_constant_nodes(
         copied_nodes = copied_nodes.union(
             duplicate_constant_node(tagged_exported_program, candidate_node)
         )
-        duplicate_constant_node(owning_program, candidate_node)
     candidate_node_with_copies = candidate_nodes.union(copied_nodes)
     _assign_new_tag(tagged_exported_program, candidate_node_with_copies)
 
