Allow delegate to consume buffer mutations #4830

Merged: 1 commit merged on Aug 28, 2024

2 changes: 1 addition & 1 deletion backends/apple/mps/test/test_mps_utils.py
@@ -229,7 +229,7 @@ def lower_module_and_test_output(
compile_specs = [CompileSpec("use_fp16", bytes([use_fp16]))]

if use_partitioner:
logging.info(f"Edge IR graph:\n{edge_program.exported_program().graph}")
logging.info(f"Edge IR graph:\n{edge_program.exported_program()}")
delegated_program = edge_program
delegated_program = edge_program.to_backend(
MPSPartitioner(compile_specs=compile_specs)
3 changes: 3 additions & 0 deletions exir/backend/test/TARGETS
@@ -88,6 +88,8 @@ python_library(
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
"//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
"//executorch/exir/backend/test/demos/rpc:executor_backend_preprocess",
"//executorch/exir/dialects:lib",
],
)
@@ -290,6 +292,7 @@
"//executorch/exir/backend/test/demos/rpc:executor_backend_register",
],
deps = [
":op_partitioner_demo",
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/backend:backend_details",
50 changes: 50 additions & 0 deletions exir/backend/test/op_partitioner_demo.py
@@ -21,6 +21,9 @@
from executorch.exir.backend.test.backend_with_compiler_demo import (
BackendWithCompilerDemo,
)
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
ExecutorBackend,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
@@ -29,6 +32,11 @@
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


class AllOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function"


class AddOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
@@ -126,6 +134,48 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
)


@final
class AllNodesPartitionerDemo(Partitioner):
"""
Partitions all nodes
"""

def __init__(self) -> None:
self.op_support = AllOperatorSupport()
self.delegation_spec = DelegationSpec(ExecutorBackend.__name__, [])

def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
partition_tags = {}
partition_list = generate_pattern_op_partitions(
edge_exported_program.graph_module, op_support=self.op_support
)
for partition in partition_list:
for node in partition.nodes:
delegation_tag = f"tag{partition.id}"
partition_tags[delegation_tag] = self.delegation_spec

# Tag the add nodes
node.meta["delegation_tag"] = delegation_tag

for arg_node in node.args:
if not isinstance(arg_node, torch.fx.Node):
continue

is_get_attr = arg_node.op == "get_attr"
is_param_buffer = arg_node.op == "placeholder" and (
is_param(edge_exported_program, arg_node)
or is_buffer(edge_exported_program, arg_node)
or is_lifted_tensor_constant(edge_exported_program, arg_node)
)
if is_get_attr or is_param_buffer:
arg_node.meta["delegation_tag"] = delegation_tag
# Add to the list of partitioned nodes.

return PartitionResult(
tagged_exported_program=edge_exported_program, partition_tags=partition_tags
)


ops_not_to_decompose = [
torch.ops.aten.linear.default,
torch.ops.aten.scaled_dot_product_attention.default,
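For context, a minimal usage sketch of the new AllNodesPartitionerDemo (not part of this PR's diff). It mirrors the flow used by the tests below; the toy AddOne module and its inputs are made up for illustration.

# Illustrative sketch only, not part of the diff: lower a toy module with the
# all-nodes demo partitioner so that every call_function node is delegated to
# ExecutorBackend.
import torch

from executorch import exir
from executorch.exir.backend.test.op_partitioner_demo import AllNodesPartitionerDemo


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1


example_inputs = (torch.ones(2, 2),)
exported = torch.export.export(AddOne(), example_inputs)
edge = exir.to_edge(exported)

# Every call_function node is tagged and handed to ExecutorBackend.
lowered = edge.to_backend(AllNodesPartitionerDemo())
print(lowered.exported_program().graph)
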
112 changes: 112 additions & 0 deletions exir/backend/test/test_partitioner.py
@@ -26,6 +26,10 @@
from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
ExecutorBackend,
)
from executorch.exir.backend.test.op_partitioner_demo import (
AddAttributePartitionerDemo,
AllNodesPartitionerDemo,
)
from executorch.exir.backend.utils import get_delegates, tag_constant_data

from executorch.exir.dialects._ops import ops as exir_ops
@@ -619,3 +623,111 @@ def partition(
and node.target == torch.ops.aten.copy_.default
]
self.assertEqual(len(copy_node), 1)

def test_buffer_mutation1(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("b", torch.ones(3, 3))

def forward(self, x):
self.b.add_(x)
return x + self.b

model_inputs = (torch.ones(3, 3),)
orig_res = TestModule()(*model_inputs)
edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
lowered = edge_program.to_backend(AddAttributePartitionerDemo())

self.assertTrue(
torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res)
)

self.assertEqual(
len(lowered.exported_program().graph_signature.buffers_to_mutate),
0,
)
lowered_module_nodes = get_delegates(lowered.exported_program().graph)
self.assertEqual(len(lowered_module_nodes), 1)
lowered_module_node = lowered_module_nodes[0]

# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
Contributor: clarification: I thought lowered_module_node itself is the call_delegate node? Is that not what it is?

Contributor Author: lowered_module_node is the getattr node that points to the actual saved lowered module. The call_delegate_node is the node that calls the delegate.

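A small, self-contained sketch of the distinction described in the reply above (not part of this PR's diff). It reuses TestModule and the partitioner from this test, so the assertions only illustrate the expected node layout.

# Sketch: the get_attr node that stores the LoweredBackendModule vs. the
# call_function node that invokes the delegate at runtime.
import torch

from executorch import exir
from executorch.exir.backend.test.op_partitioner_demo import AddAttributePartitionerDemo
from executorch.exir.backend.utils import get_delegates


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("b", torch.ones(3, 3))

    def forward(self, x):
        self.b.add_(x)
        return x + self.b


inputs = (torch.ones(3, 3),)
lowered = exir.to_edge(torch.export.export(TestModule(), inputs)).to_backend(
    AddAttributePartitionerDemo()
)

# get_attr node pointing at the saved lowered module.
lowered_module_node = get_delegates(lowered.exported_program().graph)[0]
assert lowered_module_node.op == "get_attr"

# call_function node that actually calls the delegate; the get_attr node is
# passed to it as its first argument.
call_delegate_node = next(iter(lowered_module_node.users))
assert call_delegate_node.op == "call_function"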
self.assertEqual(len(call_delegate_node.args), 2)
Contributor: nit: add a comment on what the two args are. I presume the first one is the delegate blob and the second is the input? (feel free to ignore)


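Continuing the sketch above, a hedged reading of the assertion on the previous line, following the reviewer's presumption (not confirmed in this thread): the first argument is the get_attr node holding the lowered module, and the remaining argument is the delegate's runtime input.

# Assumption, for illustration only: unpack the two args of the
# executorch_call_delegate node checked above.
lowered_module_arg, runtime_input = call_delegate_node.args
assert lowered_module_arg is lowered_module_node  # the saved lowered module
# runtime_input is presumed to be the user input `x`; the mutated buffer is
# consumed by the delegate and so is not passed in here.
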
lower_module = getattr(
lowered.exported_program().graph_module, lowered_module_node.name
)
delegated_ep = lower_module.original_module

self.assertEqual(len(delegated_ep.state_dict), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

def test_buffer_mutation_llama_repro(self):
SHAPE = (2, 3)

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32))

def forward(self, q, k_val, input_pos):
q_T = q.transpose(0, 1)
k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
attn = k.mm(q_T)
return attn

q = torch.rand(1, 3)
k = torch.rand(1, 3)
example_inputs = (q, k, torch.tensor([1, 1]))

model = Model()
model.eval()

exir_program_aten = torch.export.export(model, example_inputs)
exir_program_aten.module()(*example_inputs)
edge_program_manager = exir.to_edge(exir_program_aten)
lowered = edge_program_manager.to_backend(AllNodesPartitionerDemo())

self.assertEqual(
len(lowered.exported_program().graph_signature.buffers_to_mutate),
0,
)
lowered_module_nodes = get_delegates(lowered.exported_program().graph)
self.assertEqual(len(lowered_module_nodes), 1)
lowered_module_node = lowered_module_nodes[0]

# get call delegate node
call_delegate_node = list(lowered_module_node.users.keys())[0]
self.assertEqual(len(call_delegate_node.args), 4)

lower_module = getattr(
lowered.exported_program().graph_module, lowered_module_node.name
)
delegated_ep = lower_module.original_module

self.assertEqual(len(delegated_ep.state_dict), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

def test_buffer_mutation_unsupported(self):
Contributor: Nice. Thanks for this test case.

SHAPE = (2, 3)

class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))

def forward(self, x):
add = self.state_1.add_(x)
return add

model = Model()
model.eval()

example_inputs = (torch.randn(SHAPE),)
exir_program_aten = torch.export.export(model, example_inputs)
edge_program_manager = exir.to_edge(exir_program_aten)
with self.assertRaises(AssertionError):
edge_program_manager.to_backend(AddAttributePartitionerDemo())
128 changes: 122 additions & 6 deletions exir/lowered_backend_module.py
@@ -8,6 +8,7 @@

import copy
import operator
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
@@ -488,8 +489,12 @@ def _get_new_signature( # noqa: C901
else {}
)

toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
if not is_submodule:
for output_spec in old_signature.output_specs:
toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)

for node in gm.graph.nodes:
is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
if node.op == "placeholder":

if node.name not in input_node_to_sig:
Expand All @@ -507,7 +512,7 @@ def _get_new_signature( # noqa: C901
if not isinstance(orig_input_spec.arg, TensorArgument):
input_specs.append(orig_input_spec)

elif is_tagged:
elif node.meta.get("delegation_tag", None) == tag:
input_specs.append(orig_input_spec)

if orig_input_spec.kind == InputKind.USER_INPUT:
@@ -551,11 +556,72 @@
)

if node.op == "output":
output_nodes = pytree.tree_leaves((node.args, node.kwargs))
buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
for user in call_module_node.users.keys():
Contributor: This is checking whether the partitioned subgraph's call_module node is returning a mutated buffer, right? If so, we plan to remove those from the call signature of the submodule but also from the top level?

Contributor Author: We want to remove them from the toplevel module, and in the submodule, set the signature as a "buffer mutation".

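To make the reply above concrete, a hedged sketch (not part of the diff) of the output spec that _get_new_signature attaches to the delegate's submodule when the delegate consumes a buffer mutation; the node name and buffer target below are hypothetical.

# Illustration only; these classes come from torch.export.graph_signature.
from torch.export.graph_signature import OutputKind, OutputSpec, TensorArgument

# Spec attached to the delegate's submodule when it consumes the mutation;
# the matching top-level spec is queued in output_specs_to_delete and later
# removed by _unsafe_adjust_original_program.
consumed_mutation_spec = OutputSpec(
    kind=OutputKind.BUFFER_MUTATION,
    arg=TensorArgument(name="aten_index_put_default"),  # hypothetical node name
    target="cache",  # hypothetical buffer FQN
)
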
if user.name in toplevel_output_node_to_sig:
assert (
user.op == "call_function" and user.target == operator.getitem
), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
getitem_idx = user.args[1]
assert isinstance(
getitem_idx, int
), f"Invalid getitem type: {type(getitem_idx)}"
buffer_mutation_idxs[getitem_idx].extend(
toplevel_output_node_to_sig[user.name]
)

for output_node in output_nodes:
for i, output_node in enumerate(node.args[0]):
if i in buffer_mutation_idxs:
assert isinstance(output_node, torch.fx.Node)
orig_output_specs = buffer_mutation_idxs[i]

if any(
orig_output_spec.kind == OutputKind.BUFFER_MUTATION
and orig_output_spec.target in new_state_dict
for orig_output_spec in orig_output_specs
):
# If the delegate wants to consume the buffer, then the
# delegate should also consume the buffer mutation
# (output spec would be a BUFFER_MUTATION). Otherwise
# the delegate will just return the result of the
# mutation as a USER_OUTPUT.

orig_output_spec = [
orig_output_spec
for orig_output_spec in orig_output_specs
if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
and orig_output_spec.target in new_state_dict
][0]

assert len(orig_output_specs) == 1, (
f"Constant {orig_output_spec.target} was tagged to be "
"consumed by the buffer, and was found to also contain "
"a buffer mutation. However this buffer mutation node "
"was found to also be used as other types of outputs "
"which is currently not supported. Please file an "
"issue on Github. \n\n"
f"The toplevel program: {original_program}\n"
)
output_specs.append(
OutputSpec(
kind=OutputKind.BUFFER_MUTATION,
arg=TensorArgument(name=output_node.name),
target=orig_output_spec.target,
)
)
output_specs_to_delete[orig_output_spec.arg.name] = (
orig_output_spec
)
else:
output_specs.append(
OutputSpec(
kind=OutputKind.USER_OUTPUT,
arg=TensorArgument(name=output_node.name),
target=None,
)
)

if not isinstance(output_node, torch.fx.Node):
elif not isinstance(output_node, torch.fx.Node):
output_specs.append(
OutputSpec(
kind=OutputKind.USER_OUTPUT,
@@ -630,6 +696,9 @@ def create_exported_program_from_submodule(
in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]

print(submodule.graph)
print(subgraph_signature)

return (
ExportedProgram(
root=submodule,
@@ -774,7 +843,7 @@ def get_lowered_backend_modules(
return lowered_programs


def _unsafe_adjust_original_program(
def _unsafe_adjust_original_program( # noqa: C901
original_program: ExportedProgram,
call_delegate_node: torch.fx.Node,
input_specs_to_delete: Dict[str, InputSpec],
@@ -830,3 +899,50 @@ def _unsafe_adjust_original_program(
del original_program._constants[input_spec.target]
else:
raise RuntimeError(f"Invalid input spec {input_spec} received")

# Delete buffer mutations from the output which were consumed by the delegate
toplevel_output_node = None
for node in reversed(original_program.graph.nodes):
if node.op == "output":
toplevel_output_node = node
break

assert toplevel_output_node is not None
assert (
len(toplevel_output_node.args) == 1
), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"

new_output_args = [
arg
for arg in toplevel_output_node.args[0]
if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
]
toplevel_output_node.args = (tuple(new_output_args),)

# Delete the buffer mutation getitem nodes
getitem_idxs: List[int] = []
user_nodes = list(call_delegate_node.users.keys())
for user in user_nodes:
if user.name in output_specs_to_delete:
assert (
user.op == "call_function" and user.target == operator.getitem
), f"Invalid user {user}, node.op is {node.op} and node.target is {node.target}"
user_idx = user.args[1]
assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
getitem_idxs.append(user_idx)
original_program.graph.erase_node(user)

getitem_idxs.sort(reverse=True)

# Adjust all the getitem indices after the deleted getitems
user_nodes = list(call_delegate_node.users.keys())
for user in user_nodes:
assert user.op == "call_function" and user.target == operator.getitem
user_idx = user.args[1]
assert isinstance(user_idx, int)
for i, idx in enumerate(getitem_idxs):
if user_idx > idx:
user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
break

original_program._validate()
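
As a side note, a small standalone sketch (not part of the diff) of the index-shifting rule implemented at the end of _unsafe_adjust_original_program: after some getitem outputs of the call_delegate node are erased, each remaining getitem index drops by the number of deleted indices below it.

# Illustration of the re-indexing rule above, in isolation.
from typing import List


def adjust_getitem_idx(user_idx: int, deleted_idxs: List[int]) -> int:
    """New index for a surviving getitem once `deleted_idxs` are removed."""
    return user_idx - sum(1 for idx in deleted_idxs if idx < user_idx)


assert adjust_getitem_idx(0, [1, 3]) == 0  # below every deletion: unchanged
assert adjust_getitem_idx(2, [1, 3]) == 1  # one deleted index below it
assert adjust_getitem_idx(4, [1, 3]) == 2  # two deleted indices below it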