12
12
13
13
import executorch .backends .cadence .aot .ops_registrations # noqa
14
14
import torch
15
- from executorch .backends .cadence .aot .compiler import export_to_edge
16
15
from executorch .backends .cadence .aot .graph_builder import single_op_builder
17
16
from executorch .backends .cadence .aot .pass_utils import count_node
18
17
from executorch .backends .cadence .aot .simplify_ops import (
@@ -40,82 +39,47 @@ def test_simplify_slice_scatter_op(
40
39
end : Optional [int ] = None ,
41
40
step : int = 1 ,
42
41
):
43
- class SliceScatter (torch .nn .Module ):
44
- def __init__ (
45
- self , dim : int , start : Optional [int ], end : Optional [int ], step : int
46
- ):
47
- super ().__init__ ()
48
- self .dim = dim
49
- self .start = start
50
- self .end = end
51
- self .step = step
52
-
53
- def forward (self , x : torch .Tensor , y : torch .Tensor ):
54
- return torch .slice_scatter (
55
- x , y , self .dim , self .start , self .end , self .step
56
- )
57
-
58
- model = SliceScatter (dim , start , end , step )
59
- x = torch .randn (in_shape )
60
- y = torch .randn (src_shape )
61
- graph_module = export_to_edge (model , (x , y )).exported_program ().graph_module
62
-
63
- p = SimplifySliceOpPass ()
64
-
65
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
66
-
67
- self .assertEqual (
68
- count_node (graph_after_passes , exir_ops .edge .aten .slice_scatter .default ), 0
42
+ x = torch .randn (* in_shape )
43
+ y = torch .randn (* src_shape )
44
+ gm = single_op_builder (
45
+ placeholders = (x , y ),
46
+ op = exir_ops .edge .aten .slice_scatter .default ,
47
+ args = (x , y , dim , start , end , step ),
69
48
)
49
+ p = SimplifySliceOpPass ()
50
+ gm = cast (PassResult , p (gm )).graph_module
51
+ self .assertEqual (count_node (gm , exir_ops .edge .aten .slice_scatter .default ), 0 )
70
52
71
53
@parameterized .expand (
72
54
[
73
- [(3 , 16 , 5 ), ( 3 , 0 , 5 ), 1 , 15 , 3 , 3 ],
55
+ [(3 , 16 , 5 ), 1 , 15 , 3 , 3 ],
74
56
]
75
57
)
76
58
@torch .no_grad ()
77
59
def test_simplify_slice_op (
78
60
self ,
79
61
in_shape : Tuple [int ],
80
- src_shape : Tuple [int ],
81
62
dim : int ,
82
63
start : Optional [int ] = None ,
83
64
end : Optional [int ] = None ,
84
65
step : int = 1 ,
85
66
):
86
- class SliceCopy (torch .nn .Module ):
87
- def __init__ (
88
- self , dim : int , start : Optional [int ], end : Optional [int ], step : int
89
- ):
90
- super ().__init__ ()
91
- self .dim = dim
92
- self .start = start
93
- self .end = end
94
- self .step = step
95
-
96
- def forward (self , x : torch .Tensor ) -> torch .Tensor :
97
- return torch .slice_copy (
98
- x , dim = self .dim , start = self .start , end = self .end , step = self .step
99
- )
100
-
101
- # Create a model with single slice copy op.
102
- model = SliceCopy (dim , start , end , step )
103
- x = torch .randn (in_shape )
104
- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
105
- self .assertEqual (
106
- count_node (graph_module , exir_ops .edge .aten .slice_copy .Tensor ), 1
67
+ x = torch .randn (* in_shape )
68
+ gm = single_op_builder (
69
+ placeholders = (x ,),
70
+ op = exir_ops .edge .aten .slice_copy .Tensor ,
71
+ args = (
72
+ x ,
73
+ dim ,
74
+ start ,
75
+ end ,
76
+ step ,
77
+ ),
107
78
)
108
-
109
79
p = SimplifySliceOpPass ()
110
-
111
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
112
-
113
- self .assertEqual (
114
- count_node (graph_after_passes , exir_ops .edge .aten .slice_copy .Tensor ), 0
115
- )
116
- self .assertEqual (
117
- count_node (graph_after_passes , exir_ops .edge .aten .full .default ), 1
118
- )
80
+ gm = cast (PassResult , p (gm )).graph_module
81
+ self .assertEqual (count_node (gm , exir_ops .edge .aten .slice_copy .Tensor ), 0 )
82
+ self .assertEqual (count_node (gm , exir_ops .edge .aten .full .default ), 1 )
119
83
120
84
def test_simplify_slice_op_args (self ) -> None :
121
85
x = torch .rand (4 , 5 )
@@ -125,24 +89,10 @@ def test_simplify_slice_op_args(self) -> None:
125
89
args = (x , 1 ),
126
90
kwargs = {"end" : 3 },
127
91
)
128
- self .assertEqual (
129
- [
130
- (n .args [1 :], n .kwargs )
131
- for n in gm .graph .find_nodes (
132
- op = "call_function" , target = exir_ops .edge .aten .slice_copy .Tensor
133
- )
134
- ],
135
- [((1 ,), {"end" : 3 })],
136
- )
137
-
92
+ original_slice_copy = list (gm .graph .nodes )[1 ]
93
+ self .assertEqual (original_slice_copy .args [1 :], (1 ,))
94
+ self .assertEqual (original_slice_copy .kwargs , {"end" : 3 })
138
95
gm = BindOptionalArgsPass ().call (gm ).graph_module
139
-
140
- self .assertEqual (
141
- [
142
- (n .args [1 :], n .kwargs )
143
- for n in gm .graph .find_nodes (
144
- op = "call_function" , target = exir_ops .edge .aten .slice_copy .Tensor
145
- )
146
- ],
147
- [((1 , None , 3 , 1 ), {})],
148
- )
96
+ modified_slice_copy = list (gm .graph .nodes )[1 ]
97
+ self .assertEqual (modified_slice_copy .args [1 :], (1 , None , 3 , 1 ))
98
+ self .assertEqual (modified_slice_copy .kwargs , {})
0 commit comments