10 changes: 9 additions & 1 deletion backends/arm/_passes/rewrite_upsample.py
@@ -11,6 +11,7 @@
create_node,
get_first_fake_tensor,
)
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.arm.tosa.utils import get_resize_parameters
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
@@ -52,7 +53,9 @@ def call(self, graph_module):
node.replace_all_uses_with(tosa_resize_node)
graph_module.graph.erase_node(node)
input_dtype = get_first_fake_tensor(x).dtype
if input_dtype == torch.int8 and resize_mode == "bilinear":
if (
input_dtype == torch.int8 or input_dtype == torch.int16
) and resize_mode == "bilinear":
input_size = get_first_fake_tensor(x).shape
input_size_xy = input_size[2:]
output_size = get_first_fake_tensor(node).shape
@@ -71,6 +74,11 @@ def call(self, graph_module):
exir_ops.backend.tosa.RESCALE.default,
)
tosa_resize_node.replace_all_uses_with(rescale_node)
if input_dtype == torch.int16:
tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = (
TosaSpecialDtype.INT48
)

rescale_node.args = (
tosa_resize_node,
output_dtype,
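Context for the change above: TOSA's bilinear RESIZE on integer inputs produces a widened accumulator (INT32 for int8 inputs, INT48 for int16 inputs), which is why the pass appends a RESCALE back to the narrow dtype and, for int16, tags the resize node with TosaSpecialDtype.INT48. The following sketch is illustrative only, not part of the diff, and just spells out the dtype relationship the pass relies on.

import torch

def expected_resize_output_dtype(input_dtype: torch.dtype, resize_mode: str) -> str:
    # Nearest-neighbour resize keeps the input dtype; bilinear resize on
    # quantized inputs widens the accumulator per the TOSA spec.
    if resize_mode != "bilinear":
        return str(input_dtype)
    if input_dtype == torch.int8:
        return "INT32"  # rescaled back to int8 by the inserted RESCALE
    if input_dtype == torch.int16:
        return "INT48"  # hence the TosaSpecialDtype.INT48 tag above
    return str(input_dtype)  # float bilinear stays in the float dtype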
4 changes: 2 additions & 2 deletions backends/arm/operators/op_avg_pool2d.py
@@ -118,11 +118,11 @@ def define_node(
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
output.tosa_spec,
)

if inputs[0].dtype == ts.DType.INT8:
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
accumulator_type = ts.DType.INT32
input_qargs = get_input_qparams(node)
input_zp = input_qargs[0].get_zp_per_tensor()
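A note on the int16 branch above: it reuses the INT32 accumulator already used for int8, and the TOSA spec only allows non-zero zero points on AVG_POOL2D for int8 tensors, so the int16 quantization parameters fetched here are expected to carry zero points of 0. A defensive check along these lines (illustrative, not in the diff) would surface bad int16 qparams early:

def check_avg_pool_zero_points(dtype: str, input_zp: int, output_zp: int) -> None:
    # TOSA AVG_POOL2D: only INT8 tensors may use non-zero zero points.
    if dtype != "INT8" and (input_zp != 0 or output_zp != 0):
        raise ValueError(
            f"AVG_POOL2D with {dtype} requires zero points of 0, "
            f"got input_zp={input_zp}, output_zp={output_zp}"
        )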
4 changes: 3 additions & 1 deletion backends/arm/operators/op_clamp.py
@@ -74,6 +74,8 @@ def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes:
return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist()
elif dtype == torch.int8:
return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist()
elif dtype == torch.int16:
return np.frombuffer(np.int16(value).tobytes(), dtype=np.uint8).tolist()
else:
raise ValueError(f"Unsupported dtype for to_bytes: {dtype}")

@@ -89,7 +91,7 @@ def define_node(
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.FP16, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP16, ts.DType.FP32],
output.tosa_spec,
)

6 changes: 6 additions & 0 deletions backends/arm/operators/op_constant_pad_nd.py
@@ -50,6 +50,7 @@ def define_node(
[inputs[0], output],
[
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
ts.DType.FP32,
ts.DType.BOOL,
@@ -62,6 +63,11 @@
qargs = input_qparams[0]
pad_const_val = qargs.quantize_value(inputs[2].number).item()
pad_const_dtype = ts.DType.INT8
elif inputs[0].dtype == ts.DType.INT16:
input_qparams = get_input_qparams(node)
qargs = input_qparams[0]
pad_const_val = qargs.quantize_value(inputs[2].number).item()
pad_const_dtype = ts.DType.INT16
else:
pad_const_val = inputs[2].number
pad_const_dtype = inputs[0].dtype
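The int16 branch above mirrors the int8 one: the float fill value is pushed through the input's quantization parameters so the padded region lands on the same affine grid as the tensor being padded. As a rough sketch of what qargs.quantize_value is assumed to do for int16 (illustrative only; the real helper lives in the backend's quantization utilities):

def quantize_pad_const(value: float, scale: float, zero_point: int,
                       qmin: int = -32768, qmax: int = 32767) -> int:
    # Affine-quantize the fill value into the int16 domain and clamp it to
    # the representable range.
    q = round(value / scale) + zero_point
    return int(max(qmin, min(qmax, q)))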
2 changes: 1 addition & 1 deletion backends/arm/operators/op_eq.py
@@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
2 changes: 1 addition & 1 deletion backends/arm/operators/op_ge.py
@@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
2 changes: 1 addition & 1 deletion backends/arm/operators/op_gt.py
@@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
2 changes: 1 addition & 1 deletion backends/arm/operators/op_le.py
@@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
2 changes: 1 addition & 1 deletion backends/arm/operators/op_lt.py
@@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
2 changes: 1 addition & 1 deletion backends/arm/operators/op_max_pool2d.py
@@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
output.tosa_spec,
)

18 changes: 16 additions & 2 deletions backends/arm/operators/op_tosa_resize.py
@@ -46,13 +46,27 @@ def define_node(
resize_mode = ts.ResizeMode.NEAREST
align_corners = False
validate_same_dtype(self.target, [inputs[0], output], ts)

valid_dtypes = []
if self.tosa_spec.support_integer():
valid_dtypes.extend(
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.INT48]
)

if self.tosa_spec.support_float():
valid_dtypes.extend(
[
ts.DType.FP16,
ts.DType.FP32,
]
)

validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP16, ts.DType.FP32],
valid_dtypes,
output.tosa_spec,
)

# tosa_shape output is NHWC, take HW
input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[
1:3
1 change: 1 addition & 0 deletions backends/arm/quantizer/__init__.py
@@ -12,6 +12,7 @@
from .quantization_config import QuantizationConfig # noqa # usort: skip
from .arm_quantizer import ( # noqa
EthosUQuantizer,
get_symmetric_a16w8_quantization_config,
get_symmetric_quantization_config,
TOSAQuantizer,
VgfQuantizer,
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
@@ -51,6 +51,7 @@
"TOSAQuantizer",
"EthosUQuantizer",
"VgfQuantizer",
"get_symmetric_a16w8_quantization_config",
"get_symmetric_quantization_config",
]

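With get_symmetric_a16w8_quantization_config now exported, a TOSA int16 flow can be configured much like the existing int8 one. The sketch below is a hedged usage example: the TosaSpecification import path and the spec string are assumptions based on existing Arm backend examples, and the config's defaults are assumed to mirror get_symmetric_quantization_config.

import torch
from executorch.backends.arm.quantizer import (
    TOSAQuantizer,
    get_symmetric_a16w8_quantization_config,
)
from executorch.backends.arm.tosa_specification import TosaSpecification  # assumed path
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class SmallModel(torch.nn.Module):
    def forward(self, x):
        return torch.clamp(x, -1.0, 1.0)


example_inputs = (torch.randn(1, 3, 8, 8),)
exported = torch.export.export_for_training(SmallModel(), example_inputs)

# The "+int16" extension string is an assumption; pick the spec matching your target.
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT+int16")
quantizer = TOSAQuantizer(tosa_spec)
quantizer.set_global(get_symmetric_a16w8_quantization_config())

prepared = prepare_pt2e(exported.module(), quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)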
50 changes: 50 additions & 0 deletions backends/arm/test/ops/test_adaptive_avg_pool2d.py
@@ -136,6 +136,20 @@ def test_adaptive_avg_pool2d_tosa_INT(test_module):
pipeline.run()


@common.parametrize("test_module", test_modules)
def test_adaptive_avg_pool2d_tosa_INT_a16w8(test_module):
"""Test adaptive_avg_pool2d with int16 I/O quantization for TOSA INT."""
model, input_tensor = test_module()
pipeline = TosaPipelineINT[input_t](
model,
input_tensor,
aten_op=[],
exir_op=exir_op,
tosa_extensions=["int16"],
)
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone300
def test_adaptive_avg_pool2d_u55_INT(test_module):
@@ -150,6 +164,27 @@ def test_adaptive_avg_pool2d_u55_INT(test_module):
pipeline.run()


# Remove high_channel_count & output_1x1_from_19 due to 2MB SRAM access on U55.
# Copy test_modules so the U55-only exclusions do not leak into the U85/VGF runs below.
u55_test_modules = {
    k: v
    for k, v in test_modules.items()
    if k not in ("high_channel_count", "output_1x1_from_19")
}


@common.parametrize("test_module", u55_test_modules)
@common.XfailIfNoCorstone300
def test_adaptive_avg_pool2d_16a8w_u55_INT16(test_module):
"""Test adaptive_avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
model, input_tensor = test_module()
pipeline = EthosU55PipelineINT[input_t](
model,
input_tensor,
aten_ops=[],
exir_ops=exir_op,
a16w8_quantization=True,
)
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone320
def test_adaptive_avg_pool2d_u85_INT(test_module):
@@ -164,6 +199,21 @@ def test_adaptive_avg_pool2d_u85_INT(test_module):
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone320
def test_adaptive_avg_pool2d_16a8w_u85_INT16(test_module):
"""Test adaptive_avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
model, input_tensor = test_module()
pipeline = EthosU85PipelineINT[input_t](
model,
input_tensor,
aten_ops=[],
exir_ops=exir_op,
a16w8_quantization=True,
)
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.SkipIfNoModelConverter
def test_adaptive_avg_pool2d_vgf_FP(test_module):
49 changes: 49 additions & 0 deletions backends/arm/test/ops/test_avg_pool2d.py
@@ -141,6 +141,21 @@ def test_avg_pool2d_tosa_INT(test_module):
pipeline.run()


@common.parametrize("test_module", test_modules)
def test_avg_pool2d_tosa_INT_a16w8(test_module):
"""Test avg_pool2d operation with int16 I/O quantization for TOSA INT."""
model, input_tensor = test_module()
pipeline = TosaPipelineINT[input_t](
model,
input_tensor,
aten_op,
exir_op,
tosa_extensions=["int16"],
run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
)
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone300
def test_avg_pool2d_u55_INT(test_module):
@@ -155,6 +170,23 @@ def test_avg_pool2d_u55_INT(test_module):
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone300
def test_avg_pool2d_16a8w_u55_INT16(test_module):
"""Test avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
model, input_tensor = test_module()
pipeline = EthosU55PipelineINT[input_t](
model,
input_tensor,
aten_op,
exir_op,
per_channel_quantization=False,
a16w8_quantization=True,
use_to_edge_transform_and_lower=True,
)
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone320
def test_avg_pool2d_u85_INT(test_module):
@@ -169,6 +201,23 @@ def test_avg_pool2d_u85_INT(test_module):
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone320
def test_avg_pool2d_16a8w_u85_INT16(test_module):
"""Test avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
model, input_tensor = test_module()
pipeline = EthosU85PipelineINT[input_t](
model,
input_tensor,
aten_op,
exir_op,
per_channel_quantization=False,
a16w8_quantization=True,
use_to_edge_transform_and_lower=True,
)
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.SkipIfNoModelConverter
def test_avg_pool2d_vgf_FP(test_module):
54 changes: 54 additions & 0 deletions backends/arm/test/ops/test_clamp.py
@@ -84,6 +84,22 @@ def test_clamp_tosa_INT(test_data):
pipeline.run()


@common.parametrize("test_data", test_data_suite)
def test_clamp_tosa_INT_a16w8(test_data):
"""Test clamp operation with int16 I/O quantization for TOSA INT."""
input_tensor, min_val, max_val = test_data()
model = Clamp(min_val, max_val)
pipeline = TosaPipelineINT[input_t](
model,
(input_tensor,),
aten_op,
exir_op,
tosa_extensions=["int16"],
)
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
def test_clamp_u55_INT(test_data):
@@ -102,6 +118,25 @@ def test_clamp_u55_INT(test_data):
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
def test_clamp_16a8w_u55_INT16(test_data):
"""Test clamp operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
input_tensor, min_val, max_val = test_data()
model = Clamp(min_val, max_val)
pipeline = EthosU55PipelineINT[input_t](
model,
(input_tensor,),
aten_op,
exir_op,
per_channel_quantization=False,
a16w8_quantization=True,
use_to_edge_transform_and_lower=True,
)
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
def test_clamp_u85_INT(test_data):
@@ -120,6 +155,25 @@ def test_clamp_u85_INT(test_data):
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
def test_clamp_16a8w_u85_INT16(test_data):
"""Test clamp operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
input_tensor, min_val, max_val = test_data()
model = Clamp(min_val, max_val)
pipeline = EthosU85PipelineINT[input_t](
model,
(input_tensor,),
aten_op,
exir_op,
per_channel_quantization=False,
a16w8_quantization=True,
use_to_edge_transform_and_lower=True,
)
pipeline.change_args("run_method_and_compare_outputs", qtol=1)
pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_clamp_vgf_FP(test_data):