Skip to content

Commit bad2fa9

Browse files
Arm backend: Refactor any, bitwise, logical tests (#9499)
- Rename bitwise and logical tests with full aten op name - Refactor the tests with test_pipeline and new Xfail decorator - Fix the naming error in test_any Signed-off-by: Yufeng Shi <yufeng.shi@arm.com>
1 parent 69cc7fa commit bad2fa9

File tree

3 files changed

+308
-300
lines changed

3 files changed

+308
-300
lines changed

backends/arm/test/ops/test_any.py

Lines changed: 7 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from typing import List, Tuple
88

9+
import pytest
910
import torch
1011
from executorch.backends.arm.test import common
1112
from executorch.backends.arm.test.tester.test_pipeline import (
@@ -117,25 +118,6 @@ def forward(self, x: torch.Tensor):
117118
}
118119

119120

120-
fvp_xfails = {
121-
"any_rank1": "MLETORCH-706 Support ScalarType::Bool in EthosUBackend.",
122-
"any_rank1_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
123-
"any_rank2": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
124-
"any_rank2_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
125-
"any_rank2_dims": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
126-
"any_rank2_dims_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
127-
"any_rank3_dims_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
128-
"any_rank4": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
129-
"any_rank4_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
130-
"any_rank4_dims": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
131-
"any_rank4_dims_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
132-
"any_rank1_reduce_all": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
133-
"any_rank2_reduce_all": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
134-
"any_rank3_reduce_all": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
135-
"any_rank4_reduce_all": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.",
136-
}
137-
138-
139121
@common.parametrize("test_data", test_data)
140122
def test_any_tosa_MI(test_data: input_t1):
141123
op, test_input = test_data
@@ -147,13 +129,13 @@ def test_any_tosa_MI(test_data: input_t1):
147129
def test_any_tosa_BI(test_data: input_t1):
148130
op, test_input = test_data
149131
pipeline = TosaPipelineBI[input_t1](op, test_input, op.aten_op, op.exir_op)
150-
pipeline.pop_stage(pipeline.find_pos("quantize") + 1)
151132
pipeline.pop_stage("quantize")
133+
pipeline.pop_stage("check.quant_nodes")
152134
pipeline.run()
153135

154136

155137
@common.parametrize("test_data", test_data)
156-
def test_logical_u55_BI(test_data: input_t1):
138+
def test_any_u55_BI(test_data: input_t1):
157139
# Tests that we don't delegate these ops since they are not supported on U55.
158140
op, test_input = test_data
159141
pipeline = OpNotSupportedPipeline[input_t1](
@@ -163,23 +145,13 @@ def test_logical_u55_BI(test_data: input_t1):
163145

164146

165147
@common.parametrize("test_data", test_data)
166-
def test_floor_u85_BI(test_data: input_t1):
167-
op, test_input = test_data
168-
pipeline = EthosU85PipelineBI[input_t1](
169-
op, test_input, op.aten_op, op.exir_op, run_on_fvp=False
170-
)
171-
pipeline.pop_stage(pipeline.find_pos("quantize") + 1)
172-
pipeline.pop_stage("quantize")
173-
pipeline.run()
174-
175-
176-
@common.parametrize("test_data", test_data, fvp_xfails)
177-
@common.SkipIfNoCorstone320
178-
def test_floor_u85_BI_on_fvp(test_data: input_t1):
148+
@pytest.mark.xfail(reason="MLETORCH-706: Support ScalarType::Bool in EthosUBackend.")
149+
@common.XfailIfNoCorstone320
150+
def test_any_u85_BI(test_data: input_t1):
179151
op, test_input = test_data
180152
pipeline = EthosU85PipelineBI[input_t1](
181153
op, test_input, op.aten_op, op.exir_op, run_on_fvp=True
182154
)
183-
pipeline.pop_stage(pipeline.find_pos("quantize") + 1)
184155
pipeline.pop_stage("quantize")
156+
pipeline.pop_stage("check.quant_nodes")
185157
pipeline.run()

0 commit comments

Comments
 (0)