Skip to content

[ExecuTorch][to_backend] Enable to_backend API to leverage preprocess_all #9811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
332 changes: 331 additions & 1 deletion exir/backend/backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import logging
from contextlib import contextmanager, nullcontext
from functools import singledispatch
from typing import Generator, List
from typing import Generator, List, Dict
from dataclasses import dataclass

import torch

Expand Down Expand Up @@ -417,3 +418,332 @@ def to_backend(
constants=tagged_exported_program.constants,
verifiers=[tagged_exported_program.verifier],
)


def _create_partitions_in_graph_module(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool,
) -> Dict[str, List[torch.fx.Node]]:
    """
    Fuse each tagged group of nodes in ``tagged_graph_module`` into its own
    call_module submodule and record, per backend id, the call_module nodes
    created for it.

    For every partition tag, the nodes carrying that tag are collapsed into a
    single submodule, an ExportedProgram is built for it, and everything needed
    to lower the partition later (backend id, compile specs, submodule program,
    top-level signature specs to delete) is stashed on the call_module node's
    meta so lowering can happen in one batched preprocess_all call.

    Args:
        tagged_graph_module: graph module whose nodes carry partition tags.
        partition_result: partitioner output mapping tags to delegation specs.
        owning_program: the ExportedProgram that owns ``tagged_graph_module``.
        is_submodule: True when ``tagged_graph_module`` is a control-flow
            submodule of another graph module (skips the replace hook and
            top-level signature fixups).

    Returns:
        Mapping from backend id to the list of call_module nodes, as found in
        the final graph module, created for that backend.
    """
    backend_id_to_submodule_name: Dict[str, List[str]] = {}
    for tag, delegation_spec in partition_result.partition_tags.items():
        # Create partition with nodes containing this tag. There should only be
        # one contained submodule per tag.
        node_list = _get_node_list_with_same_tag(
            tagged_graph_module, tag, owning_program
        )

        if len(node_list) == 0:
            logging.debug(f"Did not find any nodes for tag {tag}")
            continue

        logging.debug(f"For tag {tag}, found nodes {node_list}")
        # Tag the nodes that are params as buffers, so we can order the
        # submodule as (Params + Buffers) (User Inputs)

        # Only install the replace hook for the top-level module; submodules
        # do not own the graph signature being patched.
        replace_ctx = (
            tagged_graph_module._set_replace_hook(
                owning_program.graph_signature.get_replace_hook()
            )
            if not is_submodule
            else nullcontext()
        )
        with replace_ctx:
            submodule, call_module_node = create_submodule_from_nodes(
                tagged_graph_module, node_list, tag
            )

        tagged_graph_module_output_node = [
            node for node in tagged_graph_module.graph.nodes if node.op == "output"
        ][0]
        submodule_output_node = [
            node for node in submodule.graph.nodes if node.op == "output"
        ][0]
        # Copy the output node meta from the original output node, because
        # create_submodule_from_nodes doesn't cover the meta field
        submodule_output_node.meta = tagged_graph_module_output_node.meta
        logging.debug(f"Partitioned graph module: {tagged_graph_module}")
        (
            submodule_program,
            toplevel_input_specs_to_delete,
            toplevel_output_specs_to_delete,
        ) = create_exported_program_from_submodule(
            submodule,
            owning_program,
            tag,
            call_module_node,
            is_submodule,
        )
        # Everything the deferred lowering step needs travels on the node meta.
        call_module_node.meta["backend_id"] = delegation_spec.backend_id
        call_module_node.meta["compile_spec"] = delegation_spec.compile_specs
        call_module_node.meta["submodule_program"] = submodule_program
        call_module_node.meta["toplevel_input_specs_to_delete"] = (
            toplevel_input_specs_to_delete
        )
        call_module_node.meta["toplevel_output_specs_to_delete"] = (
            toplevel_output_specs_to_delete
        )
        call_module_node.meta["is_submodule"] = is_submodule

        # The call_module_node created here might not be the same node instance
        # as the one in the final graph module, because the node might be
        # replaced in future edits to the graph. As a result, we just keep
        # track of the node's name and at the end we search for this node in
        # our final graph module.
        backend_id_to_submodule_name.setdefault(
            delegation_spec.backend_id, []
        ).append(call_module_node.target)

    created_submodule_nodes: Dict[str, List[torch.fx.Node]] = {
        key: [] for key in backend_id_to_submodule_name
    }
    for backend_id, submodule_names in backend_id_to_submodule_name.items():
        # Set lookup keeps this scan linear in the number of graph nodes.
        name_set = set(submodule_names)
        for node in tagged_graph_module.graph.nodes:
            if node.op == "call_module" and node.target in name_set:
                created_submodule_nodes[backend_id].append(node)

    # Sanity check: every recorded submodule name was found in the final graph.
    for backend_id in created_submodule_nodes:
        assert len(created_submodule_nodes[backend_id]) == len(
            backend_id_to_submodule_name[backend_id]
        ), f"Mismatched submodule count for backend {backend_id}"

    return created_submodule_nodes

def _create_partitions(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool = False,
) -> Dict[str, List[torch.fx.Node]]:
    """
    Partition ``tagged_graph_module`` and, recursively, every control-flow
    submodule it contains.

    Returns a mapping from backend id to all call_module nodes created for
    that backend across this module and its nested submodules.
    """
    merged = _create_partitions_in_graph_module(
        tagged_graph_module, partition_result, owning_program, is_submodule
    )

    # Fold in the partitions produced inside control-flow submodules.
    for _, control_flow_submod, _ in get_control_flow_submodules(tagged_graph_module):
        nested = _create_partitions(
            control_flow_submod, partition_result, owning_program, is_submodule=True
        )
        for backend_id, nested_nodes in nested.items():
            if backend_id in merged:
                merged[backend_id].extend(nested_nodes)
            else:
                merged[backend_id] = nested_nodes

    return merged

def lower_all_submodules_to_backend(
    backend_id: str,
    method_to_submodules_nodes: Dict[str, List[torch.fx.Node]],
    method_to_tagged_edge_program: Dict[str, ExportedProgram],
) -> None:
    """
    Lower all submodule nodes given in ``method_to_submodules_nodes`` to
    ``backend_id`` via one batched ``preprocess_all`` call.

    Args:
        backend_id: class name of the BackendDetails subclass to lower to.
        method_to_submodules_nodes: method name -> call_module nodes created by
            _create_partitions; each node's meta carries the submodule program,
            compile specs, and signature-fixup bookkeeping.
        method_to_tagged_edge_program: method name -> tagged owning program.

    Raises:
        NotImplementedError: if no BackendDetails subclass named
            ``backend_id`` exists.
    """
    # The created exported program for the submodules are in the call_module
    # node's meta data. We just map method_to_submodules_nodes directly to the
    # per-method partitioned programs and compile specs.
    method_to_partitioned_program = {
        method_name: [node.meta["submodule_program"] for node in call_submodule_nodes]
        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
    }
    method_to_compile_specs = {
        method_name: [node.meta["compile_spec"] for node in call_submodule_nodes]
        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
    }

    # Find the backend by class name; initialize the result map so it is
    # always bound even before a match is found.
    method_to_preprocess_result: Dict[str, List[PreprocessResult]] = {}
    backend_found = False
    for cls in BackendDetails.__subclasses__():
        if backend_id == cls.__name__:
            method_to_preprocess_result = cls.preprocess_all(
                method_to_partitioned_program,
                method_to_compile_specs,
            )
            backend_found = True
            break  # class names are unique; no need to keep scanning

    if not backend_found:
        raise NotImplementedError(f"Backend {backend_id} was not found.")

    def generate_debug_handle(ep: ExportedProgram) -> int:
        """Return a debug handle one larger than any already present in ``ep``."""
        debug_handle = 0
        for node in ep.graph_module.graph.nodes:
            debug_handle = max(debug_handle, node.meta.get("debug_handle", 0))
        return debug_handle + 1

    for method_name in method_to_preprocess_result.keys():
        owning_program = method_to_tagged_edge_program[method_name]
        list_of_preprocess_results = method_to_preprocess_result[method_name]
        list_of_call_submodule_nodes = method_to_submodules_nodes[method_name]
        list_of_compile_specs = method_to_compile_specs[method_name]
        # Bug fix: this was `assert (cond, msg)` — an assert on a non-empty
        # tuple, which is always truthy and therefore never fired.
        assert len(list_of_preprocess_results) == len(
            list_of_call_submodule_nodes
        ), (
            f"Expected {len(list_of_call_submodule_nodes)} preprocessed results "
            f"for method {method_name} but got {len(list_of_preprocess_results)}"
        )
        for preprocess_result, call_submodule_node, compile_spec in zip(
            list_of_preprocess_results,
            list_of_call_submodule_nodes,
            list_of_compile_specs,
        ):
            submodule_program = call_submodule_node.meta["submodule_program"]
            lowered_module = LoweredBackendModule(
                edge_program=submodule_program,
                backend_id=backend_id,
                processed_bytes=preprocess_result.processed_bytes,
                compile_specs=compile_spec,
            )
            owning_graph_module = call_submodule_node.graph.owning_module
            is_submodule = call_submodule_node.meta["is_submodule"]
            toplevel_input_specs_to_delete = call_submodule_node.meta[
                "toplevel_input_specs_to_delete"
            ]
            toplevel_output_specs_to_delete = call_submodule_node.meta[
                "toplevel_output_specs_to_delete"
            ]
            # call delegate args should only use user_inputs; preserve the
            # input order given by the submodule's graph signature.
            call_delegate_args = []
            for inp_name in submodule_program.graph_signature.user_inputs:
                for inp_node in call_submodule_node.all_input_nodes:
                    if inp_node.name == inp_name:
                        call_delegate_args.append(inp_node)
                        break

            # Replace the partitioned submodule with a lowered submodule:
            # a get_attr for the lowered module plus a call_function to
            # executorch_call_delegate (which dispatches to "forward").
            with owning_graph_module.graph.inserting_before(call_submodule_node):
                lowered_name = get_lowered_module_name(
                    owning_graph_module, lowered_module
                )
                lowered_node = owning_graph_module.graph.get_attr(lowered_name)
                call_delegate_node = owning_graph_module.graph.call_function(
                    executorch_call_delegate,
                    (lowered_node,) + tuple(call_delegate_args),
                    call_submodule_node.kwargs,
                )
                call_delegate_node.meta["debug_handle"] = generate_debug_handle(
                    owning_program
                )
                call_delegate_node.meta["val"] = call_submodule_node.meta["val"]
                call_submodule_node.replace_all_uses_with(call_delegate_node)
                owning_graph_module.graph.erase_node(call_submodule_node)

            if is_submodule:
                # Control-flow submodules never carry top-level signature fixups.
                assert len(toplevel_input_specs_to_delete) == 0
                assert len(toplevel_output_specs_to_delete) == 0
            elif (
                len(toplevel_input_specs_to_delete) > 0
                or len(toplevel_output_specs_to_delete) > 0
            ):
                _unsafe_adjust_original_program(
                    owning_program,
                    call_delegate_node,
                    toplevel_input_specs_to_delete,
                    toplevel_output_specs_to_delete,
                )

@dataclass
class MethodProgramsPartitionerSpec:
    """
    Pairs each method name with its edge program and the partitioner to apply.

    Since single dispatch for to_backend requires the first argument to be a
    valid class, we create this dataclass spec to hold the dictionaries
    mapping the method name to the corresponding program and partitioner.
    """

    # method name -> program in Edge dialect to be partitioned and lowered
    method_to_edge_program: Dict[str, ExportedProgram]
    # method name -> partitioner that tags that program's nodes for delegation
    method_to_partitioner: Dict[str, Partitioner]

@to_backend.register
def _(
    method_edge_program_partitioners: MethodProgramsPartitionerSpec
) -> Dict[str, ExportedProgram]:
    """
    Add overloaded implementations for to_backend:

    ::

        def to_backend(
            method_edge_program_partitioners: MethodProgramsPartitionerSpec
        ) -> Dict[str, ExportedProgram]:

    Returns a semantically-equivalent dictionary of programs to the programs
    given as input (represented as graph modules in Edge dialect), but with
    portions of each program targeted for delegation as determined by the
    partitioner. Partitions for the same backend across all methods are lowered
    together, so backends implementing preprocess_all can share work between
    methods.

    Args:
        method_edge_program_partitioners: contains two mappings,
            - method_to_edge_program: mapping of method names to their
              respective programs in Edge dialect.
            - method_to_partitioner: mapping of method names to an instance of
              the partitioner, in charge of tagging portions of the specified
              program for delegation. A valid partitioner must return a
              PartitionerResult including both the tagged exported program and
              partitioner_tag: Dict[str, DelegationSpec], where each key is a
              tag name; the nodes with the same tag will be fused into one
              subgraph and delegated to the backend specified in the
              delegation spec.

    Returns:
        Dict[str, ExportedProgram]: for each method name, the input program
        with some portions targeted for delegation.
    """
    method_to_edge_program = method_edge_program_partitioners.method_to_edge_program
    method_to_partitioner = method_edge_program_partitioners.method_to_partitioner

    partitioned_and_lowered_exported_programs = {}
    # backend id -> {method name -> call_module nodes to lower to that backend}
    backend_id_to_method_submodules_map = {}
    # method name -> partitioner-tagged program (only for partitioned methods)
    method_to_tagged_exported_program = {}

    for method_name, partitioner_instance in method_to_partitioner.items():
        assert (
            method_name in method_to_edge_program
        ), f"Partitioner for method {method_name} is not provided"
        edge_program = method_to_edge_program[method_name]
        edge_program._validate()

        # Use fake program, with FakeTensors in the state dict, to avoid
        # copying large constant values. Fall back to deepcopy if no fake mode
        # is found. TODO(T182910699): Remove this fallback.
        try:
            fake_edge_program = get_fake_program(edge_program)
        except Exception as e:
            logging.warning(
                f"Error in get_fake_program for graph {edge_program.graph_module}, fallback to deepcopy: {e}"
            )
            fake_edge_program = copy.deepcopy(edge_program)
        partitioner_result = partitioner_instance(fake_edge_program)
        tagged_exported_program = partitioner_result.tagged_exported_program
        method_to_tagged_exported_program[method_name] = tagged_exported_program

        # Check that the partitioner did not modify the original graph
        if _ENABLE_VALIDATION:
            assert is_identical_graph(
                tagged_exported_program.graph_module,
                edge_program.graph_module,
            ), f"The partitioner {partitioner_instance} should not modify the graph module"
        else:
            logging.warning("Disabled validating the partitioner.")

        assert (
            partitioner_result.partition_tags is not None
        ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

        # Swap the fake tensors back for the real program's values before
        # creating partitions from the tagged graph.
        update_to_real_program(tagged_exported_program, edge_program)

        for tag, _ in partitioner_result.partition_tags.items():
            _maybe_duplicate_constant_nodes(tagged_exported_program, tag)

        # Create (but do not yet lower) the partition submodules; lowering is
        # deferred so all methods for one backend can be batched together.
        backend_id_to_call_submodule_nodes = _create_partitions(
            tagged_exported_program.graph_module,
            partitioner_result,
            tagged_exported_program,
        )
        for backend_id, call_submodule_nodes in backend_id_to_call_submodule_nodes.items():
            if backend_id not in backend_id_to_method_submodules_map:
                backend_id_to_method_submodules_map[backend_id] = {}
            backend_id_to_method_submodules_map[backend_id][method_name] = call_submodule_nodes

    # Lower all partitions per backend in one batched call, covering every
    # method at once.
    for backend_id, method_to_submodule_nodes in backend_id_to_method_submodules_map.items():
        lower_all_submodules_to_backend(
            backend_id,
            method_to_submodule_nodes,
            method_to_tagged_exported_program,
        )

    for method_name in method_to_edge_program.keys():
        if method_name in method_to_tagged_exported_program:
            tagged_exported_program = method_to_tagged_exported_program[method_name]
            # Rebuild an ExportedProgram from the now-lowered graph module so
            # the returned program carries the updated graph and signature.
            partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
                root=tagged_exported_program.graph_module,
                graph=tagged_exported_program.graph_module.graph,
                graph_signature=tagged_exported_program.graph_signature,
                state_dict=tagged_exported_program.state_dict,
                range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
                module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
                example_inputs=None,
                constants=tagged_exported_program.constants,
                verifiers=[tagged_exported_program.verifier],
            )
        else:
            # this edge program wasn't partitioned, so we can just return it as is
            partitioned_and_lowered_exported_programs[method_name] = method_to_edge_program[method_name]

    return partitioned_and_lowered_exported_programs
Loading
Loading