Skip to content

Commit 556142c

Browse files
Arm backend: Make NSS unit tests public (#14891)
NSS has now been made public so unit tests have been moved to public executorch repository. Signed-off-by: Michiel Olieslagers <michiel.olieslagers@arm.com>
1 parent 7a785cb commit 556142c

File tree

3 files changed

+152
-1
lines changed

3 files changed

+152
-1
lines changed

backends/arm/requirements-arm-models-test.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
diffusers[torch] == 0.33.1
6+
diffusers[torch] == 0.33.1

backends/arm/scripts/install_models_for_test.sh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,16 @@
66

77
set -e
88
pip install -r backends/arm/requirements-arm-models-test.txt
9+
10+
# Install model gym repository
11+
git clone https://github.com/arm/neural-graphics-model-gym.git
12+
cd neural-graphics-model-gym
13+
# Remove model-converter installation from model-gym repository (to prevent overwriting executorch version)
14+
if [[ "$(uname)" == "Darwin" ]]; then
15+
sed -i '' 's/^model-converter = "ng_model_gym.bin.model_converter_launcher:main"/#&/' pyproject.toml
16+
else
17+
sed -i 's/^model-converter = "ng_model_gym.bin.model_converter_launcher:main"/#&/' pyproject.toml
18+
fi
19+
pip install . --no-deps
20+
cd ..
21+
rm -rf neural-graphics-model-gym
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import pytest
9+
import torch
10+
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineINT,
14+
EthosU85PipelineINT,
15+
TosaPipelineFP,
16+
TosaPipelineINT,
17+
VgfPipeline,
18+
)
19+
20+
from huggingface_hub import hf_hub_download
21+
22+
from ng_model_gym.usecases.nss.model.model_blocks import AutoEncoderV1
23+
24+
input_t = Tuple[torch.Tensor] # Input x
25+
26+
27+
class NSS(torch.nn.Module):
28+
def __init__(self, *args, **kwargs):
29+
super().__init__(*args, **kwargs)
30+
self.auto_encoder = AutoEncoderV1()
31+
32+
33+
def nss() -> AutoEncoderV1:
34+
"""Get an instance of NSS with weights loaded."""
35+
36+
weights = hf_hub_download(
37+
repo_id="Arm/neural-super-sampling", filename="nss_v0.1.0_fp32.pt"
38+
)
39+
40+
nss_model = NSS()
41+
nss_model.load_state_dict(
42+
torch.load(weights, map_location=torch.device("cpu"), weights_only=True),
43+
strict=False,
44+
)
45+
return nss_model.auto_encoder
46+
47+
48+
def example_inputs():
49+
return (torch.randn((1, 12, 544, 960)),)
50+
51+
52+
def test_nss_tosa_FP():
53+
pipeline = TosaPipelineFP[input_t](
54+
nss().eval(),
55+
example_inputs(),
56+
aten_op=[],
57+
exir_op=[],
58+
use_to_edge_transform_and_lower=True,
59+
)
60+
pipeline.add_stage_after("export", pipeline.tester.dump_operator_distribution)
61+
pipeline.run()
62+
63+
64+
def test_nss_tosa_INT():
65+
pipeline = TosaPipelineINT[input_t](
66+
nss().eval(),
67+
example_inputs(),
68+
aten_op=[],
69+
exir_op=[],
70+
use_to_edge_transform_and_lower=True,
71+
)
72+
pipeline.run()
73+
74+
75+
@pytest.mark.skip(reason="No support for aten_upsample_nearest2d_vec on U55")
76+
@common.XfailIfNoCorstone300
77+
def test_nss_u55_INT():
78+
pipeline = EthosU55PipelineINT[input_t](
79+
nss().eval(),
80+
example_inputs(),
81+
aten_ops=[],
82+
exir_ops=[],
83+
run_on_fvp=True,
84+
use_to_edge_transform_and_lower=True,
85+
)
86+
pipeline.run()
87+
88+
89+
@pytest.mark.skip(
90+
reason="Fails at input memory allocation for input shape: [1, 12, 544, 960]"
91+
)
92+
@common.XfailIfNoCorstone320
93+
def test_nss_u85_INT():
94+
pipeline = EthosU85PipelineINT[input_t](
95+
nss().eval(),
96+
example_inputs(),
97+
aten_ops=[],
98+
exir_ops=[],
99+
run_on_fvp=True,
100+
use_to_edge_transform_and_lower=True,
101+
)
102+
pipeline.run()
103+
104+
105+
@pytest.mark.xfail(
106+
reason="[MLETORCH-1430]: Double types are not supported in buffers in MSL"
107+
)
108+
@common.SkipIfNoModelConverter
109+
def test_nss_vgf_FP():
110+
pipeline = VgfPipeline[input_t](
111+
nss().eval(),
112+
example_inputs(),
113+
aten_op=[],
114+
exir_op=[],
115+
tosa_version="TOSA-1.0+FP",
116+
use_to_edge_transform_and_lower=True,
117+
run_on_vulkan_runtime=True,
118+
)
119+
pipeline.run()
120+
121+
122+
@common.SkipIfNoModelConverter
123+
def test_nss_vgf_INT():
124+
pipeline = VgfPipeline[input_t](
125+
nss().eval(),
126+
example_inputs(),
127+
aten_op=[],
128+
exir_op=[],
129+
tosa_version="TOSA-1.0+INT",
130+
symmetric_io_quantization=True,
131+
use_to_edge_transform_and_lower=True,
132+
run_on_vulkan_runtime=True,
133+
)
134+
pipeline.run()
135+
136+
137+
ModelUnderTest = nss().eval()
138+
ModelInputs = example_inputs()

0 commit comments

Comments
 (0)