
Commit 118f575

mcr229 authored and facebook-github-bot committed
fix bug with sequential backends
Summary:
There is a bug described in this PR comment: https://github.com/pytorch/executorch/pull/10584/files#r2070213706. This change adds tests and a fix to cover it.

When sequential partitions go through preprocess_all, the get_item nodes produced by the first partition in the sequence are not correctly mapped to the arguments fed into the second partition, because those nodes are renamed (the original node becomes a get_item node). Instead of matching on names, we now delete from the call-delegate arguments exactly the nodes we know must be deleted from the input specs.

Additionally, there is an issue with validation: _validate fails when call_module nodes are still in the graph. Since preprocess_multimethod lowers the call_submodule nodes one by one, calling _validate before all of them have been transformed into call_delegate nodes will fail. We therefore remove the _validate call from _unsafe_adjust_original_program and instead validate the original program after all submodule nodes have been converted to call_delegate.

Differential Revision: D74226258
1 parent 4dfddf5 commit 118f575
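As a rough illustration of the first fix (a simplified, self-contained model; the plain strings below are hypothetical stand-ins for torch.fx node names, not the real ExecuTorch data structures): once the first partition is lowered, an output that used to be named "sub" reaches the second partition as a "getitem" node, so matching call-delegate args against graph_signature.user_inputs by name silently drops it. Filtering out only the inputs known to be deleted keeps the renamed node and preserves argument order:

# Hypothetical, simplified model of the bug -- plain strings stand in
# for torch.fx node names.
call_submodule_args = ["x", "getitem", "c_const"]  # args after lowering partition one
user_inputs = ["x", "sub"]  # signature still holds the pre-lowering name "sub"

# Old approach: keep only args whose names match user_inputs.
old_args = [a for a in call_submodule_args if a in user_inputs]
print(old_args)  # ['x'] -- the renamed "getitem" input is silently dropped

# New approach: drop only inputs positively known to be deleted
# (e.g. constants owned by the delegate); keep everything else.
input_specs_to_delete = {"c_const"}
new_args = [a for a in call_submodule_args if a not in input_specs_to_delete]
print(new_args)  # ['x', 'getitem'] -- renamed input survives, order preserved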

File tree

4 files changed: +126 −20 lines

  exir/backend/backend_api.py
  exir/backend/test/backend_with_preprocess_all_demo.py
  exir/backend/test/test_to_backend_multi_method.py
  exir/lowered_backend_module.py

exir/backend/backend_api.py

12 additions, 6 deletions

@@ -204,12 +204,16 @@ def _insert_lowered_submodule(
     owning_graph_module = call_submodule_node.graph.owning_module
     # call delegate args should only use user_inputs
     call_delegate_args = []
-    # Preserve input order as user_inputs
-    for inp_name in submodule_program.graph_signature.user_inputs:
-        for inp_node in call_submodule_node.all_input_nodes:
-            if inp_node.name == inp_name:
-                call_delegate_args.append(inp_node)
-                break
+    # names of input_specs to delete
+    input_specs_to_delete = toplevel_input_specs_to_delete
+    # Delete owned constants from the call_submodule_node args
+    for call_sm_input in call_submodule_node.args:
+        if (
+            isinstance(call_sm_input, torch.fx.Node)
+            and call_sm_input.name in input_specs_to_delete.keys()
+        ):
+            continue
+        call_delegate_args.append(call_sm_input)

 def generate_debug_handle(ep: ExportedProgram) -> int:
     """

@@ -324,6 +328,7 @@ def _partition_and_lower_one_graph_module(
         toplevel_input_specs_to_delete,
         toplevel_output_specs_to_delete,
     )
+    owning_program._validate()

     return tagged_graph_module

@@ -742,6 +747,7 @@ def to_backend(
     for method_name in method_to_edge_program.keys():
         if method_name in method_to_tagged_exported_program:
             tagged_exported_program = method_to_tagged_exported_program[method_name]
+            tagged_exported_program._validate()
             partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
                 root=tagged_exported_program.graph_module,
                 graph=tagged_exported_program.graph_module.graph,
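
A minimal, self-contained sketch of why the _validate call had to move out of the per-submodule step (Program, lower_one, and validate below are hypothetical stand-ins for ExportedProgram, the per-submodule lowering, and ExportedProgram._validate):

from dataclasses import dataclass, field
from typing import List


@dataclass
class Program:
    # Op kinds in the graph; "call_module" marks a not-yet-lowered submodule.
    node_ops: List[str] = field(default_factory=list)

    def validate(self) -> None:
        # Like _validate: any leftover call_module node is an error.
        assert "call_module" not in self.node_ops, "unlowered submodule in graph"


def lower_one(program: Program, idx: int) -> None:
    # Replace one tagged submodule with a delegate call.
    program.node_ops[idx] = "call_delegate"


program = Program(node_ops=["call_module", "call_module"])
for i in range(len(program.node_ops)):
    lower_one(program, i)
    # Validating here (the old behavior) would raise on the first
    # iteration: the second call_module is still in the graph.
program.validate()  # new behavior: validate once, after the loop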

exir/backend/test/backend_with_preprocess_all_demo.py

43 additions, 12 deletions

@@ -21,10 +21,30 @@
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase


+def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
+    return (
+        is_param(exp_prog, node)
+        or is_buffer(exp_prog, node)
+        or is_lifted_tensor_constant(exp_prog, node)
+    )
+
+
+def get_total_num_ops_in_ep(edge_programs, supported_ops):
+    total_number_of_ops = 0
+    for edge_program in edge_programs.values():
+        for partitioned_program in edge_program:
+            for node in partitioned_program.graph.nodes:
+                if node.op == "call_function":
+                    if node.target in supported_ops:
+                        total_number_of_ops += 1
+    return total_number_of_ops
+
+
 def _preprocess_multimethod(
     edge_programs: Dict[str, List[ExportedProgram]],
     compile_specs: Dict[str, List[List[CompileSpec]]],

@@ -37,13 +57,7 @@ def _preprocess_multimethod(
     in testing for a partitioner which tags different partitions for different backends
     to be lowered to
     """
-    total_number_of_ops = 0
-    for edge_program in edge_programs.values():
-        for partitioned_program in edge_program:
-            for node in partitioned_program.graph.nodes:
-                if node.op == "call_function":
-                    if node.target in supported_ops:
-                        total_number_of_ops += 1
+    total_number_of_ops = get_total_num_ops_in_ep(edge_programs, supported_ops)
     all_processed_results = {key: [] for key in edge_programs.keys()}

     for method_name, partitioned_programs in edge_programs.items():

@@ -67,6 +81,8 @@ def _preprocess_multimethod(
                     raise RuntimeError(
                         f"{node.op} {node.target.__name__} is not supported in backend {backend_name}"
                     )
+                if is_param_node(partitioned_program, node):
+                    processed_bytes += f"CONST{node.name}:"

             processed_bytes += "#"
             for cs in compile_spec_for_partition:

@@ -171,14 +187,30 @@ def preprocess_multimethod(


 class AddSinOperatorSupport(OperatorSupportBase):
+    def __init__(self, original_program):
+        self.original_program = original_program
+        super().__init__()
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
-        return node.op == "call_function" and node.target in [
+        supported_targets = [
             exir_ops.edge.aten.add.Tensor,
             exir_ops.edge.aten.sin.default,
         ]
+        if node.op == "call_function" and node.target in supported_targets:
+            return True
+
+        if node.op == "placeholder" and is_param_node(self.original_program, node):
+            for user in node.users.keys():
+                if user.target in supported_targets:
+                    return True
+        return False


 class SubCosOperatorSupport(OperatorSupportBase):
+    def __init__(self, original_program):
+        self.original_program = original_program
+        super().__init__()
+
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         return node.op == "call_function" and node.target in [
             exir_ops.edge.aten.sub.Tensor,

@@ -199,11 +231,8 @@ class BackendWithPreprocessAllPartitioner(Partitioner):
     """

     def __init__(self) -> None:
-        self.add_sin_support = any_chain(AddSinOperatorSupport())
-        self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__
-
-        self.sub_cos_support = any_chain(SubCosOperatorSupport())
         self.sub_cos_backend_id = SecondBackendWithPreprocessAll.__name__
+        self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__

     def _partition_graph_module(
         self,

@@ -260,6 +289,8 @@ def _partition_graph_module(
         return partition_tags, start_idx_for_submodules

     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        self.add_sin_support = any_chain(AddSinOperatorSupport(exported_program))
+        self.sub_cos_support = any_chain(SubCosOperatorSupport(exported_program))
         partition_tags, _ = self._partition_graph_module(exported_program.graph_module)
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
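
The design change to the demo operator-support classes, in isolation: a parameter or buffer placeholder is claimed for a partition only when one of its users is a supported op, so owned constants travel with the delegate. A self-contained sketch under assumed simplifications (a plain Node dataclass instead of torch.fx.Node, string targets instead of edge ops):

from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    op: str  # "call_function" or "placeholder"
    target: str = ""
    users: List["Node"] = field(default_factory=list)


SUPPORTED = {"aten.add.Tensor", "aten.sin.default"}


def is_node_supported(node: Node, is_param: bool) -> bool:
    # Supported compute ops are claimed directly.
    if node.op == "call_function" and node.target in SUPPORTED:
        return True
    # A param/buffer placeholder is claimed only if one of its users is
    # a supported op, so the constant is lowered with that partition.
    if node.op == "placeholder" and is_param:
        return any(u.target in SUPPORTED for u in node.users)
    return False


add = Node(op="call_function", target="aten.add.Tensor")
const = Node(op="placeholder", users=[add])
print(is_node_supported(add, is_param=False))   # True
print(is_node_supported(const, is_param=True))  # True

This is also why the support objects are now built inside partition() rather than __init__: is_param_node needs the ExportedProgram, which is only available at partition time.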

exir/backend/test/test_to_backend_multi_method.py

71 additions, 0 deletions

@@ -392,6 +392,77 @@ def forward(self, x):
         }
         self._test(test_set)

+    def test_multi_method_to_backend_sequential_delegates(self):
+        class SequentialBackendModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z):
+                # delegate one
+                x = x - x
+                y = y - y
+                z = z - z
+                # graph break
+                a = x * y * z
+                # delegate two uses outputs from delegate one and the
+                # output from the graph break
+                b = x + a
+                b = b + z + a
+                b = b + y + a
+                return b
+
+        module = SequentialBackendModule()
+        example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
+        seq_edgeir_m = to_edge(torch.export.export(module, example_inputs))
+
+        test_set = {
+            "seq_edgeir": (
+                seq_edgeir_m.exported_program(),
+                BackendWithPreprocessAllPartitioner(),
+                [
+                    "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
+                    "FirstBackendWithPreprocessAll#5#aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
+                ],
+            ),
+        }
+        self._test(test_set)
+
+    def test_multi_method_to_backend_constants(self):
+        class SequentialBackendModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.const = torch.zeros(1)
+
+            def forward(self, x, y, z):
+                # delegate one
+                x = x - x
+                y = y - y
+                z = z - z
+                # graph break
+                a = x * y * z * self.const
+                # delegate two uses outputs from delegate one and the
+                # output from the graph break
+                b = x + self.const + a
+                b = z + a + b
+                b = y + a + b
+                return b
+
+        module = SequentialBackendModule()
+        example_inputs = (torch.ones(1), torch.ones(1), torch.ones(1))
+        seq_const_m = to_edge(torch.export.export(module, example_inputs))
+
+        test_set = {
+            "seq_const": (
+                seq_const_m.exported_program(),
+                BackendWithPreprocessAllPartitioner(),
+                [
+                    "SecondBackendWithPreprocessAll#3#aten.sub.Tensor:aten.sub.Tensor:aten.sub.Tensor:#sub:b'\\x02';sub:b'\\x02';sub:b'\\x02';",
+                    "FirstBackendWithPreprocessAll#6#CONSTc_const_copy_0:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:aten.add.Tensor:#add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';add:b'\\x00';",
+                ],
+            ),
+        }
+        self._test(test_set)
+
     def test_multi_method_to_backend_not_found(self):
         class SinModule(torch.nn.Module):
             def __init__(self):

exir/lowered_backend_module.py

0 additions, 2 deletions

@@ -958,5 +958,3 @@ def _unsafe_adjust_original_program(  # noqa: C901
                 if user_idx > idx:
                     user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
                     break
-
-    original_program._validate()
