Skip to content

Commit e1738cc

Browse files
authored
Arm backend: Update NEGATE with TOSA 1.0 support (#10845)
### Summary Add the serialization to TOSA 1.0 where the attributes has moved to input tensors instead. ### Test plan Tested on internal and external CI. Signed-off-by: Per Åstrand <per.astrand@arm.com>
1 parent f785386 commit e1738cc

File tree

1 file changed

+74
-16
lines changed

1 file changed

+74
-16
lines changed

backends/arm/operators/op_neg.py

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch.fx
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1211
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1312
get_input_qparams,
1413
get_output_qparams,
@@ -21,12 +20,12 @@
2120
from executorch.backends.arm.tosa_mapping import TosaArg
2221

2322

24-
def get_negate_zero_points(node: torch.fx.Node, dtype: ts.DType) -> tuple[int, int]:
23+
def get_negate_zero_points(node: torch.fx.Node, is_int8: bool) -> tuple[int, int]:
2524
"""
2625
Returns (input1_zp, output_zp) for TOSA NEGATE.
2726
Must be zero for non-int8 types.
2827
"""
29-
if dtype == ts.DType.INT8:
28+
if is_int8:
3029
return (
3130
get_input_qparams(node)[0].zp,
3231
get_output_qparams(node)[0].zp,
@@ -35,38 +34,43 @@ def get_negate_zero_points(node: torch.fx.Node, dtype: ts.DType) -> tuple[int, i
3534

3635

3736
@register_node_visitor
38-
class NegVisitor(NodeVisitor):
37+
class NegVisitor_0_80(NodeVisitor):
3938
target = "aten.neg.default"
4039

41-
supported_dtypes = {
42-
ts.DType.INT8,
43-
ts.DType.INT16,
44-
ts.DType.INT32,
45-
ts.DType.FP16,
46-
ts.DType.BF16,
47-
ts.DType.FP32,
48-
}
40+
tosa_specs = NodeVisitor.tosa_specs_0_80
4941

5042
def __init__(self, *args):
5143
super().__init__(*args)
5244

5345
def define_node(
5446
self,
5547
node: torch.fx.Node,
56-
tosa_graph: ts.TosaSerializer,
48+
tosa_graph: Any,
5749
inputs: List[TosaArg],
5850
output: TosaArg,
5951
) -> None:
52+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
53+
54+
supported_dtypes = {
55+
ts.DType.INT8,
56+
ts.DType.INT16,
57+
ts.DType.INT32,
58+
ts.DType.FP16,
59+
ts.DType.BF16,
60+
ts.DType.FP32,
61+
}
6062

61-
if inputs[0].dtype not in self.supported_dtypes:
63+
if inputs[0].dtype not in supported_dtypes:
6264
raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")
6365

6466
if inputs[0].dtype != output.dtype:
6567
raise ValueError(
6668
"All inputs and output need same dtype."
6769
f"Got {inputs[0].dtype=}, {output.dtype=}"
6870
)
69-
input_zp, output_zp = get_negate_zero_points(node, inputs[0].dtype)
71+
input_zp, output_zp = get_negate_zero_points(
72+
node, inputs[0].dtype == ts.DType.INT8
73+
)
7074

7175
attr = ts.TosaSerializerAttribute()
7276
attr.NegateAttribute(input1_zp=input_zp, output_zp=output_zp)
@@ -76,3 +80,57 @@ def define_node(
7680
[output.name],
7781
attributes=attr,
7882
)
83+
84+
85+
@register_node_visitor
86+
class NegVisitor(NodeVisitor):
87+
target = "aten.neg.default"
88+
89+
tosa_specs = NodeVisitor.tosa_specs_1_00
90+
91+
def __init__(self, *args):
92+
super().__init__(*args)
93+
94+
def define_node(
95+
self,
96+
node: torch.fx.Node,
97+
tosa_graph: Any,
98+
inputs: List[TosaArg],
99+
output: TosaArg,
100+
) -> None:
101+
import serializer.tosa_serializer as ts # type: ignore
102+
103+
supported_dtypes = {
104+
ts.DType.INT8,
105+
ts.DType.INT16,
106+
ts.DType.INT32,
107+
ts.DType.FP16,
108+
ts.DType.BF16,
109+
ts.DType.FP32,
110+
}
111+
112+
if inputs[0].dtype not in supported_dtypes:
113+
raise ValueError(f"Unsupported dtype for NEGATE: {inputs[0].dtype}")
114+
115+
if inputs[0].dtype != output.dtype:
116+
raise ValueError(
117+
"All inputs and output need same dtype."
118+
f"Got {inputs[0].dtype=}, {output.dtype=}"
119+
)
120+
input_zp, output_zp = get_negate_zero_points(
121+
node, inputs[0].dtype == ts.DType.INT8
122+
)
123+
124+
input_zp_tensor = tosa_graph.addConst(
125+
(1,), inputs[0].dtype, [input_zp], name=output.name + "_input_zp"
126+
)
127+
128+
output_zp_tensor = tosa_graph.addConst(
129+
(1,), output.dtype, [output_zp], name=output.name + "_output_zp"
130+
)
131+
132+
tosa_graph.addOperator(
133+
ts.TosaOp.Op().NEGATE,
134+
[inputs[0].name, input_zp_tensor.name, output_zp_tensor.name],
135+
[output.name],
136+
)

0 commit comments

Comments
 (0)