Skip to content

Commit 9e9b0ec

Browse files
committed
Arm Backend: Fixes related to pytorch updates
- Update to error message for unsupported types - Check file extension in test_debug_feats - Adding failures for some unsupported operators in tests
1 parent 4987f0b commit 9e9b0ec

File tree

4 files changed

+22
-12
lines changed

4 files changed

+22
-12
lines changed

backends/arm/operators/op_table.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def define_node(
4141

4242
if inputs[0].dtype not in (ts.DType.INT8, ts.DType.INT16):
4343
raise ValueError(
44-
f"TOSA.TABLE only supports int8 or int16 inputs, got {ts.DTypeNames[inputs[0]]}"
44+
f"TOSA.TABLE only supports int8 or int16 inputs, got {ts.DTypeNames[inputs[0].dtype]}"
4545
)
4646

4747
table = self._exported_program.state_dict[node.name] # type: ignore[union-attr]

backends/arm/test/misc/test_debug_feats.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -192,16 +192,13 @@ def test_collate_tosa_BI_tests(self):
192192
.to_edge_transform_and_lower()
193193
.to_executorch()
194194
)
195+
196+
test_collate_dir = "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
195197
# test that the output directory is created and contains the expected files
196-
assert os.path.exists(
197-
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
198-
)
199-
assert os.path.exists(
200-
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag6_TOSA-0.80+BI.tosa"
201-
)
202-
assert os.path.exists(
203-
"test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag6_TOSA-0.80+BI.json"
204-
)
198+
assert os.path.exists(test_collate_dir)
199+
200+
for file in os.listdir(test_collate_dir):
201+
assert file.endswith(("TOSA-0.80+BI.json", "TOSA-0.80+BI.tosa"))
205202

206203
os.environ.pop("TOSA_TESTCASES_BASE_PATH")
207204
shutil.rmtree("test_collate_tosa_tests", ignore_errors=True)

backends/arm/test/misc/test_partition_decomposed_quantized_ops.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,12 @@ def test_linear_residaul_tosa_MI(test_data: input_t1):
145145
pipeline.run()
146146

147147

148-
@common.parametrize("test_data", test_data)
148+
@common.parametrize(
149+
"test_data",
150+
test_data,
151+
{"3d_rand": "MLETORCH-855: Issue with Quantization folding."},
152+
strict=False,
153+
)
149154
def test_linear_residual_tosa_BI(test_data: input_t1):
150155
pipeline = TosaPipelineBI[input_t1](
151156
LinearResidualModule(),

backends/arm/test/models/test_nn_functional.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,15 @@ def test_nn_functional_MI(test_data):
104104
raise e
105105

106106

107-
@parametrize("test_data", module_tests)
107+
x_fails = {
108+
"normalize": "MLETORCH-852: Support aten.index_put.default",
109+
"cosine_similarity": "MLETORCH-854: Support aten.linalg_vector_norm.default",
110+
"unfold": "Int64 input && MLETORCH-827: Support aten.index.Tensor",
111+
"fold": "Int64 input && MLETORCH-827: Support aten.index_put.default",
112+
}
113+
114+
115+
@parametrize("test_data", module_tests, x_fails, strict=False)
108116
def test_nn_functional_BI(test_data):
109117
module, inputs = test_data
110118
pipeline = TosaPipelineBI[input_t](

0 commit comments

Comments
 (0)