55
66from typing import Tuple
77
8- import pytest
98import torch
109from executorch .backends .arm .test import common
1110
1615 TosaPipelineMI ,
1716)
1817
19- aten_op = "torch.ops.aten.gt.Tensor"
20- exir_op = "executorch_exir_dialects_edge__ops_aten_gt_Tensor"
2118
2219input_t = Tuple [torch .Tensor ]
2320
2421
2522class Greater (torch .nn .Module ):
23+ aten_op_tensor = "torch.ops.aten.gt.Tensor"
24+ aten_op_scalar = "torch.ops.aten.gt.Scalar"
25+ exir_op = "executorch_exir_dialects_edge__ops_aten_gt_Tensor"
26+
2627 def __init__ (self , input , other ):
2728 super ().__init__ ()
2829 self .input_ = input
@@ -31,106 +32,143 @@ def __init__(self, input, other):
3132 def forward (
3233 self ,
3334 input_ : torch .Tensor ,
34- other_ : torch .Tensor ,
35+ other_ : torch .Tensor | int | float ,
3536 ):
3637 return input_ > other_
3738
3839 def get_inputs (self ):
3940 return (self .input_ , self .other_ )
4041
4142
42- op_gt_rank1_ones = Greater (
43+ op_gt_tensor_rank1_ones = Greater (
4344 torch .ones (5 ),
4445 torch .ones (5 ),
4546)
46- op_gt_rank2_rand = Greater (
47+ op_gt_tensor_rank2_rand = Greater (
4748 torch .rand (4 , 5 ),
4849 torch .rand (1 , 5 ),
4950)
50- op_gt_rank3_randn = Greater (
51+ op_gt_tensor_rank3_randn = Greater (
5152 torch .randn (10 , 5 , 2 ),
5253 torch .randn (10 , 5 , 2 ),
5354)
54- op_gt_rank4_randn = Greater (
55+ op_gt_tensor_rank4_randn = Greater (
5556 torch .randn (3 , 2 , 2 , 2 ),
5657 torch .randn (3 , 2 , 2 , 2 ),
5758)
5859
59- test_data_common = {
60- "gt_rank1_ones" : op_gt_rank1_ones ,
61- "gt_rank2_rand" : op_gt_rank2_rand ,
62- "gt_rank3_randn" : op_gt_rank3_randn ,
63- "gt_rank4_randn" : op_gt_rank4_randn ,
60+ op_gt_scalar_rank1_ones = Greater (torch .ones (5 ), 1.0 )
61+ op_gt_scalar_rank2_rand = Greater (torch .rand (4 , 5 ), 0.2 )
62+ op_gt_scalar_rank3_randn = Greater (torch .randn (10 , 5 , 2 ), - 0.1 )
63+ op_gt_scalar_rank4_randn = Greater (torch .randn (3 , 2 , 2 , 2 ), 0.3 )
64+
65+ test_data_tensor = {
66+ "gt_tensor_rank1_ones" : op_gt_tensor_rank1_ones ,
67+ "gt_tensor_rank2_rand" : op_gt_tensor_rank2_rand ,
68+ "gt_tensor_rank3_randn" : op_gt_tensor_rank3_randn ,
69+ "gt_tensor_rank4_randn" : op_gt_tensor_rank4_randn ,
70+ }
71+
72+ test_data_scalar = {
73+ "gt_scalar_rank1_ones" : op_gt_scalar_rank1_ones ,
74+ "gt_scalar_rank2_rand" : op_gt_scalar_rank2_rand ,
75+ "gt_scalar_rank3_randn" : op_gt_scalar_rank3_randn ,
76+ "gt_scalar_rank4_randn" : op_gt_scalar_rank4_randn ,
6477}
6578
6679
67- @common .parametrize ("test_module" , test_data_common )
68- def test_gt_tosa_MI (test_module ):
80+ @common .parametrize ("test_module" , test_data_tensor )
81+ def test_gt_tensor_tosa_MI (test_module ):
82+ pipeline = TosaPipelineMI [input_t ](
83+ test_module , test_module .get_inputs (), Greater .aten_op_tensor , Greater .exir_op
84+ )
85+ pipeline .run ()
86+
87+
88+ @common .parametrize ("test_module" , test_data_scalar )
89+ def test_gt_scalar_tosa_MI (test_module ):
6990 pipeline = TosaPipelineMI [input_t ](
70- test_module , test_module .get_inputs (), aten_op , exir_op
91+ test_module , test_module .get_inputs (), Greater .aten_op_scalar , Greater .exir_op
92+ )
93+ pipeline .run ()
94+
95+
96+ @common .parametrize ("test_module" , test_data_tensor )
97+ def test_gt_tensor_tosa_BI (test_module ):
98+ pipeline = TosaPipelineBI [input_t ](
99+ test_module , test_module .get_inputs (), Greater .aten_op_tensor , Greater .exir_op
71100 )
72101 pipeline .run ()
73102
74103
75- @common .parametrize ("test_module" , test_data_common )
76- def test_gt_tosa_BI (test_module ):
104+ @common .parametrize ("test_module" , test_data_scalar )
105+ def test_gt_scalar_tosa_BI (test_module ):
77106 pipeline = TosaPipelineBI [input_t ](
78- test_module , test_module .get_inputs (), aten_op , exir_op
107+ test_module , test_module .get_inputs (), Greater . aten_op_tensor , Greater . exir_op
79108 )
80109 pipeline .run ()
81110
82111
83- @common .parametrize ("test_module" , test_data_common )
84- def test_gt_u55_BI (test_module ):
85- # GREATER is not supported on U55.
112+ @common .parametrize ("test_module" , test_data_tensor )
113+ @common .XfailIfNoCorstone300
114+ def test_gt_tensor_u55_BI (test_module ):
115+ # Greater is not supported on U55.
86116 pipeline = OpNotSupportedPipeline [input_t ](
87117 test_module ,
88118 test_module .get_inputs (),
89119 "TOSA-0.80+BI+u55" ,
90- {exir_op : 1 },
120+ {Greater . exir_op : 1 },
91121 )
92122 pipeline .run ()
93123
94124
95- @common .parametrize ("test_module" , test_data_common )
96- def test_gt_u85_BI (test_module ):
97- pipeline = EthosU85PipelineBI [input_t ](
125+ @common .parametrize ("test_module" , test_data_scalar )
126+ @common .XfailIfNoCorstone300
127+ def test_gt_scalar_u55_BI (test_module ):
128+ # Greater is not supported on U55.
129+ pipeline = OpNotSupportedPipeline [input_t ](
98130 test_module ,
99131 test_module .get_inputs (),
100- aten_op ,
101- exir_op ,
102- run_on_fvp = False ,
103- use_to_edge_transform_and_lower = True ,
132+ "TOSA-0.80+BI+u55" ,
133+ {Greater .exir_op : 1 },
134+ n_expected_delegates = 1 ,
104135 )
105136 pipeline .run ()
106137
107138
108- @common .parametrize ("test_module" , test_data_common )
109- @pytest .mark .skip (reason = "The same as test_gt_u55_BI" )
110- def test_gt_u55_BI_on_fvp (test_module ):
111- # GREATER is not supported on U55.
112- pipeline = OpNotSupportedPipeline [input_t ](
139+ @common .parametrize (
140+ "test_module" ,
141+ test_data_tensor ,
142+ xfails = {
143+ "gt_tensor_rank4_randn" : "MLETORCH-847: Boolean eq result unstable on U85" ,
144+ },
145+ )
146+ @common .XfailIfNoCorstone320
147+ def test_gt_tensor_u85_BI (test_module ):
148+ pipeline = EthosU85PipelineBI [input_t ](
113149 test_module ,
114150 test_module .get_inputs (),
115- "TOSA-0.80+BI+u55" ,
116- {exir_op : 1 },
151+ Greater .aten_op_tensor ,
152+ Greater .exir_op ,
153+ run_on_fvp = True ,
117154 )
118155 pipeline .run ()
119156
120157
121158@common .parametrize (
122159 "test_module" ,
123- test_data_common ,
124- xfails = {"gt_rank4_randn" : "4D fails because boolean Tensors can't be subtracted" },
160+ test_data_scalar ,
161+ xfails = {
162+ "gt_scalar_rank4_randn" : "MLETORCH-847: Boolean eq result unstable on U85" ,
163+ },
125164)
126- @common .SkipIfNoCorstone320
127- def test_gt_u85_BI_on_fvp (test_module ):
165+ @common .XfailIfNoCorstone320
166+ def test_gt_scalar_u85_BI (test_module ):
128167 pipeline = EthosU85PipelineBI [input_t ](
129168 test_module ,
130169 test_module .get_inputs (),
131- aten_op ,
132- exir_op ,
170+ Greater . aten_op_tensor ,
171+ Greater . exir_op ,
133172 run_on_fvp = True ,
134- use_to_edge_transform_and_lower = True ,
135173 )
136174 pipeline .run ()
0 commit comments