Commit b4bd556

Use GraphBuilder in memory passes unit tests. # 2
Differential Revision: D75698567
Pull Request resolved: #11292
1 parent 02e5c58 commit b4bd556

backends/cadence/aot/tests/test_memory_passes.py

Lines changed: 213 additions & 102 deletions
@@ -224,11 +224,13 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
 
     # Initializes the nodes metadata and runs the GenerateMemoryViewConstraints,
    # GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes.
-    def run_memory_planning(self, original, alloc_graph_input=True) -> GraphModule:
+    def run_memory_planning(
+        self, original, opt_level=2, alloc_graph_input=True
+    ) -> GraphModule:
         graph_module = SpecPropPass().call(original).graph_module
         return CadenceMemoryPlanning(
             get_default_memory_config(),
-            opt_level=2,
+            opt_level=opt_level,
             mem_algo=1, # greedy_by_size_for_offset_calculation_with_hierarchy
             alloc_graph_input=alloc_graph_input,
         )(graph_module).graph_module
@@ -535,130 +537,239 @@ def test_optimize_cat_with_slice_infeasible(self) -> None:
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)
 
-    def test_optimize_slice_Tensor(self) -> None:
-        class SliceTensor(torch.nn.Module):
-            def forward(self, x, y, z):
-                x1 = torch.add(x, 2.4, 3.1)
-                # This slice should always be optimized, since x1 is not placeholder
-                # and the slice is along the outermost dim
-                t1 = torch.ops.aten.slice(x1, 0, 1, 2)
-                # This slice should not be optimized when alloc_graph_input=False,
-                # since y is a placeholder node
-                t2 = torch.ops.aten.slice(y, 0, 0, 1)
-                # This slice should be always optimized, since the dims before
-                # sliced dims are 1
-                z1 = torch.add(z, 2.4, 3.1)
-                t3 = torch.ops.aten.slice(z1, 1, 4, 5)
-                return (t1 + t2) * t3
-
-        x = torch.ones(3, 6)
-        y = torch.ones(2, 6)
-        z = torch.ones(1, 6)
-        # Run the memory planning pass and get the graph module
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(
-                SliceTensor(),
-                (x, y, z),
-                opt_level=2,
-                mem_algo=1,
-                alloc_graph_input=False,
-            )
-            .exported_program()
-            .graph_module
+    def test_optimize_slice_outermost(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([3, 6], 123.0),
+            kwargs={"dtype": torch.float32},
+        )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 0.0),
+            kwargs={"dtype": torch.float32},
         )
+        # This slice should always be optimized, since add_x is not placeholder
+        # and the slice is along the outermost dim
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.slice_copy.Tensor_out,
+            args=(
+                add_x,
+                0, # dim
+                1, # start
+                2, # end
+                1, # step
+            ),
+            kwargs={"out": slice_out},
+        )
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original, alloc_graph_input=False)
         graph_module.graph.eliminate_dead_code()
-        # Assert that t2 is not optimized away
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten.slice_copy.Tensor_out), 1
+            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1
+        )
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_slice_non_outermost(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(1, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 123.0),
+            kwargs={"dtype": torch.float32},
+        )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 2], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This slice should be always optimized, since the dims before
+        # sliced dims are 1.
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.slice_copy.Tensor_out,
+            args=(
+                add_x,
+                1, # dim
+                4, # start
+                6, # end
+                1, # step
+            ),
+            kwargs={"out": slice_out},
         )
-        # Assert that t1 and t3 are optimized to slice_copy_nop veresion
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original, alloc_graph_input=False)
+        graph_module.graph.eliminate_dead_code()
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
+            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1
         )
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_slice_depending_on_opt_level(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(2, 6, dtype=torch.float32))
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This slice should not be optimized when alloc_graph_input=False,
+        # since y is a placeholder node
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.slice_copy.Tensor_out,
+            args=(
+                x,
+                0, # dim
+                0, # start
+                1, # end
+                1, # step
+            ),
+            kwargs={"out": slice_out},
+        )
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(
+            original, opt_level=2, alloc_graph_input=False
+        )
+        graph_module.graph.eliminate_dead_code()
+        self.assertEqual(
+            count_node(graph_module, torch.ops.aten.slice_copy.Tensor_out), 1
+        )
+        self.verify_nop_memory_alloc(graph_module)
+
         # When we compile with alloc_graph_input=True, all the slice ops must
-        # be optimized.
-        # Optimizing cat ops is only at opt_level 2+, and requires the memory planning
-        # pass to run:
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(
-                SliceTensor(),
-                (x, y, z),
-                opt_level=3,
-                mem_algo=1,
-                alloc_graph_input=True,
-            )
-            .exported_program()
-            .graph_module
+        # be optimized, which is available only at opt_level 2+.
+        graph_module = self.run_memory_planning(
+            original, opt_level=3, alloc_graph_input=True
         )
         graph_module.graph.eliminate_dead_code()
-        self.assertFalse(count_node(graph_module, torch.ops.aten.slice_copy.Tensor_out))
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
+            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1
         )
         self.verify_nop_memory_alloc(graph_module)
 
-    def test_optimize_select_Tensor(self) -> None:
-        class SelectTensor(torch.nn.Module):
-            def forward(self, x, y, z):
-                x1 = torch.add(x, 2.4, 3.1)
-                # This select should always be optimized, since x1 is not
-                # placeholder, and the select is along the outermost dim
-                t1 = torch.select_copy(x1, 0, 1)
-                # This select should not be optimized if alloc_graph_input=False,
-                # since y is a placeholder node.
-                t2 = torch.select_copy(y, 0, 0)
-                # This select should always be optimized, since the dims before
-                # select dims are 1
-                z1 = torch.add(z, 2.4, 3.1)
-                t3 = torch.select(z1, 1, 4)
-                return (t1 + t2) * t3
-
-        x = torch.ones(3, 6)
-        y = torch.ones(2, 6)
-        z = torch.ones(1, 6)
-        # Optimizing select ops is only at opt_level 2+, and requires the memory planning
-        # pass to run:
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(
-                SelectTensor(),
-                (x, y, z),
-                opt_level=2,
-                mem_algo=1,
-                alloc_graph_input=False,
-            )
-            .exported_program()
-            .graph_module
+    def test_optimize_select_outermost(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([3, 6], 123.0),
+            kwargs={"dtype": torch.float32},
         )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This select should always be optimized, since add_x is not placeholder
+        # and the select is along the outermost dim
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.select_copy.int_out,
+            args=(
+                add_x,
+                0, # dim
+                1, # index
+            ),
+            kwargs={"out": slice_out},
+        )
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original, alloc_graph_input=False)
         graph_module.graph.eliminate_dead_code()
-        # Assert that t2 is not optimized away
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten.select_copy.int_out), 1
+            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 1
+        )
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_select_non_outermost(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(1, 6, dtype=torch.float32))
+        to_add_to_x = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 123.0),
+            kwargs={"dtype": torch.float32},
+        )
+        add_x = builder.call_operator(
+            op=exir_ops.edge.aten.add.Tensor,
+            args=(x, to_add_to_x),
+        )
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 2], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This select should always be optimized, since the dims before
+        # select dims are 1
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.select_copy.int_out,
+            args=(
+                add_x,
+                1, # dim
+                4, # index
+            ),
+            kwargs={"out": slice_out},
         )
-        # Assert that t1 and t3 are optimized to select_copy_nop veresion
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(original, alloc_graph_input=False)
+        graph_module.graph.eliminate_dead_code()
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 2
+            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 1
         )
-        # When we compile with alloc_graph_input=True, all the select ops must
-        # be optimized.
-        # Optimizing select ops is only at opt_level 2+, and requires the memory planning
-        # pass to run:
-        graph_module = (
-            compiler.export_to_executorch_gen_etrecord(
-                SelectTensor(),
-                (x, y, z),
-                opt_level=3,
-                mem_algo=1,
-                alloc_graph_input=True,
-            )
-            .exported_program()
-            .graph_module
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_select_depending_on_opt_level(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.ones(2, 6, dtype=torch.float32))
+        slice_out = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1, 6], 0.0),
+            kwargs={"dtype": torch.float32},
+        )
+        # This select should not be optimized if alloc_graph_input=False,
+        # since y is a placeholder node.
+        slice_result = builder.call_operator(
+            op=torch.ops.aten.select_copy.int_out,
+            args=(
+                x,
+                0, # dim
+                0, # index
+            ),
+            kwargs={"out": slice_out},
+        )
+        builder.output([slice_result])
+        original = builder.get_graph_module()
+        graph_module = self.run_memory_planning(
+            original, opt_level=2, alloc_graph_input=False
         )
         graph_module.graph.eliminate_dead_code()
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten.select_copy.int_out), 0
+            count_node(graph_module, torch.ops.aten.select_copy.int_out), 1
         )
+        self.verify_nop_memory_alloc(graph_module)
+
+        # When we compile with alloc_graph_input=True, all the slice ops must
+        # be optimized, which is available only at opt_level 2+.
+        graph_module = self.run_memory_planning(
+            original, opt_level=3, alloc_graph_input=True
+        )
+        graph_module.graph.eliminate_dead_code()
         self.assertEqual(
-            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
+            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 1
         )
         self.verify_nop_memory_alloc(graph_module)
 
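Aside (not part of the commit): the comments in the tests above lean on the rule that a slice or select can become a nop when it is taken along the outermost dimension, or when every dimension before the sliced/selected one has size 1, because the result then occupies a single contiguous block of the input's storage and no data needs to move. A minimal, self-contained sketch of that contiguity argument in plain PyTorch (the tensor names below are illustrative only, not taken from the test file):

import torch

# Slicing dim 1 of a (1, 6) tensor keeps one contiguous block of storage,
# so the copy could be replaced by a view (a "nop" slice).
z = torch.arange(6, dtype=torch.float32).reshape(1, 6)
assert z[:, 4:6].is_contiguous()

# Selecting along the outermost dim of a contiguous tensor is likewise one
# contiguous block (an entire row).
x = torch.arange(18, dtype=torch.float32).reshape(3, 6)
assert x.select(0, 1).is_contiguous()

# By contrast, slicing dim 1 of a (2, 6) tensor picks a strided region that is
# not a single contiguous block, so a real copy would be needed to pack it.
w = torch.arange(12, dtype=torch.float32).reshape(2, 6)
assert not w[:, 4:6].is_contiguous()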
