Allow delegate to consume buffer mutations #4830
@@ -26,6 +26,10 @@

from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
    ExecutorBackend,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AllNodesPartitionerDemo,
)
from executorch.exir.backend.utils import get_delegates, tag_constant_data

from executorch.exir.dialects._ops import ops as exir_ops

@@ -619,3 +623,111 @@ def partition(
            and node.target == torch.ops.aten.copy_.default
        ]
        self.assertEqual(len(copy_node), 1)

    def test_buffer_mutation1(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("b", torch.ones(3, 3))

            def forward(self, x):
                self.b.add_(x)
                return x + self.b

        model_inputs = (torch.ones(3, 3),)
        orig_res = TestModule()(*model_inputs)
        edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
        lowered = edge_program.to_backend(AddAttributePartitionerDemo())

        self.assertTrue(
            torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res)
        )

        self.assertEqual(
            len(lowered.exported_program().graph_signature.buffers_to_mutate),
            0,
        )
        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # get call delegate node
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        self.assertEqual(len(call_delegate_node.args), 2)

Reviewer comment: nit: add a comment on what the two args are. I presume the first one is the delegate blob and the second is the input?

        lower_module = getattr(
            lowered.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module

        self.assertEqual(len(delegated_ep.state_dict), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

    def test_buffer_mutation_llama_repro(self):
        SHAPE = (2, 3)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32))

            def forward(self, q, k_val, input_pos):
                q_T = q.transpose(0, 1)
                k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
                attn = k.mm(q_T)
                return attn

        q = torch.rand(1, 3)
        k = torch.rand(1, 3)
        example_inputs = (q, k, torch.tensor([1, 1]))

        model = Model()
        model.eval()

        exir_program_aten = torch.export.export(model, example_inputs)
        exir_program_aten.module()(*example_inputs)
        edge_program_manager = exir.to_edge(exir_program_aten)
        lowered = edge_program_manager.to_backend(AllNodesPartitionerDemo())

        self.assertEqual(
            len(lowered.exported_program().graph_signature.buffers_to_mutate),
            0,
        )
        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
        self.assertEqual(len(lowered_module_nodes), 1)
        lowered_module_node = lowered_module_nodes[0]

        # get call delegate node
        call_delegate_node = list(lowered_module_node.users.keys())[0]
        self.assertEqual(len(call_delegate_node.args), 4)

        lower_module = getattr(
            lowered.exported_program().graph_module, lowered_module_node.name
        )
        delegated_ep = lower_module.original_module

        self.assertEqual(len(delegated_ep.state_dict), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)

    def test_buffer_mutation_unsupported(self):

Reviewer comment: Nice. Thanks for this test case.

        SHAPE = (2, 3)

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))

            def forward(self, x):
                add = self.state_1.add_(x)
                return add

        model = Model()
        model.eval()

        example_inputs = (torch.randn(SHAPE),)
        exir_program_aten = torch.export.export(model, example_inputs)
        edge_program_manager = exir.to_edge(exir_program_aten)
        with self.assertRaises(AssertionError):
            edge_program_manager.to_backend(AddAttributePartitionerDemo())
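
For context on what these tests inspect: after torch.export, an in-place update to a registered buffer is functionalized and recorded in graph_signature.buffers_to_mutate, and the tests above expect that map to be empty at the toplevel once the delegate consumes the mutation. A minimal standalone sketch (not part of this diff; node names are illustrative):

import torch


class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x):
        # In-place mutation of a registered buffer.
        self.count.add_(1)
        return x + self.count


ep = torch.export.export(Counter(), (torch.zeros(1),))

# The mutation is functionalized: the updated buffer value becomes an extra
# graph output, and buffers_to_mutate maps that output node's name to the
# buffer's fully qualified name, e.g. {"add": "count"}.
print(ep.graph_signature.buffers_to_mutate)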
@@ -8,6 +8,7 @@

import copy
import operator
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch

@@ -488,8 +489,12 @@ def _get_new_signature( # noqa: C901
        else {}
    )

    toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
    if not is_submodule:
        for output_spec in old_signature.output_specs:
            toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)

    for node in gm.graph.nodes:
        is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
        if node.op == "placeholder":

            if node.name not in input_node_to_sig:

@@ -507,7 +512,7 @@ def _get_new_signature( # noqa: C901
            if not isinstance(orig_input_spec.arg, TensorArgument):
                input_specs.append(orig_input_spec)

            elif is_tagged:
            elif node.meta.get("delegation_tag", None) == tag:
                input_specs.append(orig_input_spec)

                if orig_input_spec.kind == InputKind.USER_INPUT:

@@ -551,11 +556,72 @@ def _get_new_signature( # noqa: C901
                )

        if node.op == "output":
            output_nodes = pytree.tree_leaves((node.args, node.kwargs))
            buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
            for user in call_module_node.users.keys():

Reviewer comment: This is checking whether the partitioned subgraph's call_module node is returning a mutated buffer, right? If so, do we plan to remove those from the call signature of the submodule but also from the top level?

Reply: We want to remove them from the toplevel module, and in the submodule, set the signature as a "buffer mutation".

                if user.name in toplevel_output_node_to_sig:
                    assert (
                        user.op == "call_function" and user.target == operator.getitem
                    ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
                    getitem_idx = user.args[1]
                    assert isinstance(
                        getitem_idx, int
                    ), f"Invalid getitem type: {type(getitem_idx)}"
                    buffer_mutation_idxs[getitem_idx].extend(
                        toplevel_output_node_to_sig[user.name]
                    )

            for output_node in output_nodes:
            for i, output_node in enumerate(node.args[0]):
                if i in buffer_mutation_idxs:
                    assert isinstance(output_node, torch.fx.Node)
                    orig_output_specs = buffer_mutation_idxs[i]

                    if any(
                        orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                        and orig_output_spec.target in new_state_dict
                        for orig_output_spec in orig_output_specs
                    ):
                        # If the delegate wants to consume the buffer, then the
                        # delegate should also consume the buffer mutation
                        # (output spec would be a BUFFER_MUTATION). Otherwise
                        # the delegate will just return the result of the
                        # mutation as a USER_OUTPUT.

                        orig_output_spec = [
                            orig_output_spec
                            for orig_output_spec in orig_output_specs
                            if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                            and orig_output_spec.target in new_state_dict
                        ][0]

                        assert len(orig_output_specs) == 1, (
                            f"Constant {orig_output_spec.target} was tagged to be "
                            "consumed by the buffer, and was found to also contain "
                            "a buffer mutation. However this buffer mutation node "
                            "was found to also be used as other types of outputs "
                            "which is currently not supported. Please file an "
                            "issue on Github. \n\n"
                            f"The toplevel program: {original_program}\n"
                        )
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.BUFFER_MUTATION,
                                arg=TensorArgument(name=output_node.name),
                                target=orig_output_spec.target,
                            )
                        )
                        output_specs_to_delete[orig_output_spec.arg.name] = (
                            orig_output_spec
                        )
                    else:
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.USER_OUTPUT,
                                arg=TensorArgument(name=output_node.name),
                                target=None,
                            )
                        )

                if not isinstance(output_node, torch.fx.Node):
                elif not isinstance(output_node, torch.fx.Node):
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
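
To restate the reviewer exchange above in concrete terms, a hedged sketch of what this hunk does to the signatures (the node and buffer names below are hypothetical, not taken from a real export): the BUFFER_MUTATION output spec is recorded in output_specs_to_delete so it can be dropped from the toplevel program, while an equivalent spec is emitted for the delegated submodule, so the runtime still knows the delegate updates that buffer in place.

from torch.export.graph_signature import OutputKind, OutputSpec, TensorArgument

# Before lowering, the toplevel signature reports the mutation itself.
mutation_spec = OutputSpec(
    kind=OutputKind.BUFFER_MUTATION,
    arg=TensorArgument(name="aten_index_put_default"),
    target="cache",
)
user_spec = OutputSpec(
    kind=OutputKind.USER_OUTPUT, arg=TensorArgument(name="mm"), target=None
)
toplevel_output_specs = [mutation_spec, user_spec]

# After the delegate consumes the buffer, the mutation spec moves into the
# submodule's signature and is dropped from the toplevel one.
submodule_output_specs = [mutation_spec]
toplevel_output_specs = [
    s for s in toplevel_output_specs if s.kind is not OutputKind.BUFFER_MUTATION
]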

@@ -630,6 +696,9 @@ def create_exported_program_from_submodule(
    in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
    out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]

    print(submodule.graph)
    print(subgraph_signature)

    return (
        ExportedProgram(
            root=submodule,

@@ -774,7 +843,7 @@ def get_lowered_backend_modules(
    return lowered_programs


def _unsafe_adjust_original_program(
def _unsafe_adjust_original_program( # noqa: C901
    original_program: ExportedProgram,
    call_delegate_node: torch.fx.Node,
    input_specs_to_delete: Dict[str, InputSpec],

@@ -830,3 +899,50 @@ def _unsafe_adjust_original_program(
            del original_program._constants[input_spec.target]
        else:
            raise RuntimeError(f"Invalid input spec {input_spec} received")

    # Delete buffer mutations from the output which were consumed by the delegate
    toplevel_output_node = None
    for node in reversed(original_program.graph.nodes):
        if node.op == "output":
            toplevel_output_node = node
            break

    assert toplevel_output_node is not None
    assert (
        len(toplevel_output_node.args) == 1
    ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"

    new_output_args = [
        arg
        for arg in toplevel_output_node.args[0]
        if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
    ]
    toplevel_output_node.args = (tuple(new_output_args),)

    # Delete the buffer mutation getitem nodes
    getitem_idxs: List[int] = []
    user_nodes = list(call_delegate_node.users.keys())
    for user in user_nodes:
        if user.name in output_specs_to_delete:
            assert (
                user.op == "call_function" and user.target == operator.getitem
            ), f"Invalid user {user}, node.op is {node.op} and node.target is {node.target}"
            user_idx = user.args[1]
            assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
            getitem_idxs.append(user_idx)
            original_program.graph.erase_node(user)

    getitem_idxs.sort(reverse=True)

    # Adjust all the getitem indices after the deleted getitems
    user_nodes = list(call_delegate_node.users.keys())
    for user in user_nodes:
        assert user.op == "call_function" and user.target == operator.getitem
        user_idx = user.args[1]
        assert isinstance(user_idx, int)
        for i, idx in enumerate(getitem_idxs):
            if user_idx > idx:
                user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
                break

    original_program._validate()
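
The index arithmetic in the final loop above can be hard to follow. A standalone sketch of the same shift (hypothetical helper, not part of the PR): every surviving getitem index drops by the number of deleted indices below it, which is exactly what subtracting len(getitem_idxs) - i achieves after the descending sort.

from typing import List


def shift_surviving_indices(survivors: List[int], deleted: List[int]) -> List[int]:
    # Each surviving getitem index decreases by the number of deleted
    # (buffer-mutation) outputs that preceded it in the delegate's output list.
    return [idx - sum(1 for d in deleted if d < idx) for idx in survivors]


# The delegate originally returned four outputs; outputs 0 and 2 were buffer
# mutations now consumed by the delegate, so user outputs 1 and 3 shift to 0 and 1.
assert shift_surviving_indices([1, 3], [0, 2]) == [0, 1]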

Reviewer comment: Clarification: I thought lowered_module_node itself is the call_delegate node? If not, what is it?

Reply: lowered_module_node is the getattr node that points to the actual saved lowered module. The call_delegate_node is the node that calls the delegate.
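
A short sketch of the node relationship described in this reply, reusing names from the tests in this PR (it assumes lowered is the result of to_backend, as in test_buffer_mutation1):

from executorch.exir.backend.utils import get_delegates

# The get_attr node that references the saved LoweredBackendModule.
lowered_module_node = get_delegates(lowered.exported_program().graph)[0]

# Its user is the call_function node that invokes the delegate; its args are
# the lowered module followed by the runtime inputs (hence 2 args in
# test_buffer_mutation1 and 4 in test_buffer_mutation_llama_repro).
call_delegate_node = next(iter(lowered_module_node.users))

# Resolving the get_attr node gives the LoweredBackendModule object, and
# original_module is the ExportedProgram it wraps, whose signature now owns
# the buffer mutation.
lowered_backend_module = getattr(
    lowered.exported_program().graph_module, lowered_module_node.name
)
delegated_ep = lowered_backend_module.original_module
print(delegated_ep.graph_signature.buffers_to_mutate)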