4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
# pyre-unsafe
7
- from typing import List
7
+ from typing import Any , List
8
8
9
9
import torch .fx
10
10
11
- import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
12
11
from executorch .backends .arm ._passes .fold_qdq_with_annotated_qparams_pass import (
13
12
get_input_qparams ,
14
13
get_output_qparams ,
21
20
from executorch .backends .arm .tosa_mapping import TosaArg
22
21
23
22
24
- def get_negate_zero_points (node : torch .fx .Node , dtype : ts . DType ) -> tuple [int , int ]:
23
+ def get_negate_zero_points (node : torch .fx .Node , is_int8 : bool ) -> tuple [int , int ]:
25
24
"""
26
25
Returns (input1_zp, output_zp) for TOSA NEGATE.
27
26
Must be zero for non-int8 types.
28
27
"""
29
- if dtype == ts . DType . INT8 :
28
+ if is_int8 :
30
29
return (
31
30
get_input_qparams (node )[0 ].zp ,
32
31
get_output_qparams (node )[0 ].zp ,
@@ -35,38 +34,43 @@ def get_negate_zero_points(node: torch.fx.Node, dtype: ts.DType) -> tuple[int, i
35
34
36
35
37
36
@register_node_visitor
38
- class NegVisitor (NodeVisitor ):
37
+ class NegVisitor_0_80 (NodeVisitor ):
39
38
target = "aten.neg.default"
40
39
41
- supported_dtypes = {
42
- ts .DType .INT8 ,
43
- ts .DType .INT16 ,
44
- ts .DType .INT32 ,
45
- ts .DType .FP16 ,
46
- ts .DType .BF16 ,
47
- ts .DType .FP32 ,
48
- }
40
+ tosa_specs = NodeVisitor .tosa_specs_0_80
49
41
50
42
def __init__ (self , * args ):
51
43
super ().__init__ (* args )
52
44
53
45
def define_node (
54
46
self ,
55
47
node : torch .fx .Node ,
56
- tosa_graph : ts . TosaSerializer ,
48
+ tosa_graph : Any ,
57
49
inputs : List [TosaArg ],
58
50
output : TosaArg ,
59
51
) -> None :
52
+ import tosa_tools .v0_80 .serializer .tosa_serializer as ts # type: ignore
53
+
54
+ supported_dtypes = {
55
+ ts .DType .INT8 ,
56
+ ts .DType .INT16 ,
57
+ ts .DType .INT32 ,
58
+ ts .DType .FP16 ,
59
+ ts .DType .BF16 ,
60
+ ts .DType .FP32 ,
61
+ }
60
62
61
- if inputs [0 ].dtype not in self . supported_dtypes :
63
+ if inputs [0 ].dtype not in supported_dtypes :
62
64
raise ValueError (f"Unsupported dtype for NEGATE: { inputs [0 ].dtype } " )
63
65
64
66
if inputs [0 ].dtype != output .dtype :
65
67
raise ValueError (
66
68
"All inputs and output need same dtype."
67
69
f"Got { inputs [0 ].dtype = } , { output .dtype = } "
68
70
)
69
- input_zp , output_zp = get_negate_zero_points (node , inputs [0 ].dtype )
71
+ input_zp , output_zp = get_negate_zero_points (
72
+ node , inputs [0 ].dtype == ts .DType .INT8
73
+ )
70
74
71
75
attr = ts .TosaSerializerAttribute ()
72
76
attr .NegateAttribute (input1_zp = input_zp , output_zp = output_zp )
@@ -76,3 +80,57 @@ def define_node(
76
80
[output .name ],
77
81
attributes = attr ,
78
82
)
83
+
84
+
85
+ @register_node_visitor
86
+ class NegVisitor (NodeVisitor ):
87
+ target = "aten.neg.default"
88
+
89
+ tosa_specs = NodeVisitor .tosa_specs_1_00
90
+
91
+ def __init__ (self , * args ):
92
+ super ().__init__ (* args )
93
+
94
+ def define_node (
95
+ self ,
96
+ node : torch .fx .Node ,
97
+ tosa_graph : Any ,
98
+ inputs : List [TosaArg ],
99
+ output : TosaArg ,
100
+ ) -> None :
101
+ import serializer .tosa_serializer as ts # type: ignore
102
+
103
+ supported_dtypes = {
104
+ ts .DType .INT8 ,
105
+ ts .DType .INT16 ,
106
+ ts .DType .INT32 ,
107
+ ts .DType .FP16 ,
108
+ ts .DType .BF16 ,
109
+ ts .DType .FP32 ,
110
+ }
111
+
112
+ if inputs [0 ].dtype not in supported_dtypes :
113
+ raise ValueError (f"Unsupported dtype for NEGATE: { inputs [0 ].dtype } " )
114
+
115
+ if inputs [0 ].dtype != output .dtype :
116
+ raise ValueError (
117
+ "All inputs and output need same dtype."
118
+ f"Got { inputs [0 ].dtype = } , { output .dtype = } "
119
+ )
120
+ input_zp , output_zp = get_negate_zero_points (
121
+ node , inputs [0 ].dtype == ts .DType .INT8
122
+ )
123
+
124
+ input_zp_tensor = tosa_graph .addConst (
125
+ (1 ,), inputs [0 ].dtype , [input_zp ], name = output .name + "_input_zp"
126
+ )
127
+
128
+ output_zp_tensor = tosa_graph .addConst (
129
+ (1 ,), output .dtype , [output_zp ], name = output .name + "_output_zp"
130
+ )
131
+
132
+ tosa_graph .addOperator (
133
+ ts .TosaOp .Op ().NEGATE ,
134
+ [inputs [0 ].name , input_zp_tensor .name , output_zp_tensor .name ],
135
+ [output .name ],
136
+ )
0 commit comments