Skip to content

Commit

Permalink
Search graph for quantization parameters (#6690)
Browse files Browse the repository at this point in the history
* Search graph for quantization nodes

Generalizes the search for quantization parameters.
The idea is to make a graph like this a valid quantized graph:

dq -> view -> transpose -> some_op
                           ^
                          /
dq ------> expand -------/

For a subset of operations 'passable_op' it is allowed to "pass
through" the op when searching for qparams. If multiple qparams
are encountered in one search, they are asserted to be equal.

Signed-off-by: Erik Lundell <erik.lundell@arm.com>
  • Loading branch information
Erik-Lundell authored Nov 11, 2024
1 parent 793f17e commit 7fcd0af
Show file tree
Hide file tree
Showing 26 changed files with 317 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
get_first_fake_tensor,
insert_q_dq_pair,
)
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs):
return args[0]


register_passable_op(torch.ops.passthrough_to_tosa._transpose)


class AnnotateChannelsLastDimOrder(ExportPass):
"""
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
Expand Down
14 changes: 1 addition & 13 deletions backends/arm/_passes/insert_squeeze_after_sum_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

import torch
import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair

from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Expand All @@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass):
sum(dims, keep_dim = False)
After pass:
sum(dims, keep_dim = True)
(q)
(dq)
squeeze(dim = dims)
"""

Expand All @@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule):
continue

dim_list = cast(list[int], sum_node.args[1])
quantized = is_quant_node(sum_node)
if quantized:
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
qparams = qparams + (torch.int8,)
else:
qparams = None

# Add keep_dim = True arg to sum node.
sum_node.args = sum_node.args[0:2] + (True,)
Expand All @@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule):
)
sum_node.replace_all_uses_with(squeeze_node)
squeeze_node.args = (sum_node, dim_list)
if quantized:
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/size_adjust_conv2d_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import cast, Optional

import torch.fx
from executorch.backends.arm.tosa_quant_utils import is_quant_node
from executorch.backends.arm.tosa_quant_utils import is_node_quantized
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._ops import OpOverload
Expand Down Expand Up @@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule):
slice_node = graph.create_node(
"call_function", self.slice_op, (last_node,) + args
)
if is_quant_node(last_node):
if is_node_quantized(last_node):
q_params = last_node.args[1:]
dq_node = insert_q_dq_pair(
graph_module.graph, slice_node, q_params
Expand Down
16 changes: 10 additions & 6 deletions backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
from executorch.backends.arm.tosa_quant_utils import (
build_rescale,
get_quant_arg_downstream,
get_quant_arg_upstream,
)
from executorch.backends.arm.tosa_utils import get_two_inputs
from serializer.tosa_serializer import TosaOp

Expand Down Expand Up @@ -42,8 +46,10 @@ def define_node(
# For INT8, we need to get the zero points and add an intermediate tensor
# for a later rescale.
if is_quant_node:
input0_zp = get_quant_node_args(input0).zp
input1_zp = get_quant_node_args(input1).zp
input0_q_params = get_quant_arg_upstream(input0)
input1_q_params = get_quant_arg_upstream(input1)
input0_zp = input0_q_params.zp
input1_zp = input1_q_params.zp
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
bmm_output_name = bmm_result.name
else:
Expand All @@ -63,9 +69,7 @@ def define_node(

# As INT8 accumulates into INT32, we need to rescale it back to INT8
if is_quant_node:
input0_q_params = get_quant_node_args(input0)
input1_q_params = get_quant_node_args(input1)
output_q_params = get_quant_node_args(list(node.users)[0])
output_q_params = get_quant_arg_downstream(list(node.users)[0])

final_output_scale = (
input0_q_params.scale * input1_q_params.scale
Expand Down
20 changes: 11 additions & 9 deletions backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import cast, List
from typing import List

import serializer.tosa_serializer as ts
import torch
Expand All @@ -15,9 +15,10 @@
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
build_rescale_conv_output,
get_quant_node_args,
get_quant_arg_downstream,
get_quant_arg_upstream,
)
from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape

from serializer.tosa_serializer import TosaOp

Expand Down Expand Up @@ -82,7 +83,7 @@ def define_node(
)

input_zp = (
get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0
get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
)

attr.ConvAttribute(
Expand Down Expand Up @@ -158,9 +159,10 @@ def define_node(
# integer value domain of the next op. Otherwise return float32 output.
if is_quant_node:
# Get scale_factor from input, weight, and output.
_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
output_qargs = get_quant_arg_downstream(list(node.users)[0])

build_rescale_conv_output(
tosa_graph,
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
Expand All @@ -169,6 +171,6 @@ def define_node(
actual_out_type,
input_scale,
weight_scale,
output_scale,
output_zp,
output_qargs.scale,
output_qargs.zp,
)
7 changes: 4 additions & 3 deletions backends/arm/operators/op_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from executorch.backends.arm.tosa_quant_utils import (
dequantize_value,
get_quant_node_args,
get_quant_arg_downstream,
get_quant_arg_upstream,
QuantArgs,
quantize_value,
)
Expand Down Expand Up @@ -48,9 +49,9 @@ def define_node(

# Create attribute for 8 bit table lookup.
input_node = node.all_input_nodes[0]
in_quantargs = get_quant_node_args(input_node)
in_quantargs = get_quant_arg_upstream(input_node)
output_node = list(node.users)[0]
out_quantargs = get_quant_node_args(output_node)
out_quantargs = get_quant_arg_downstream(output_node)

table = exp_table_8bit(in_quantargs, out_quantargs)
table_attr = ts.TosaSerializerAttribute()
Expand Down
11 changes: 6 additions & 5 deletions backends/arm/operators/op_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
from executorch.backends.arm.tosa_quant_utils import (
get_quant_arg_downstream,
quantize_value,
)
from executorch.backends.arm.tosa_utils import tosa_shape
from torch.fx import Node

Expand All @@ -39,10 +42,8 @@ def define_node(

value = inputs[1].number
if is_quant_node:
qargs = get_quant_node_args(list(node.users)[0])
qvalue = np.clip(
np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax
)
qargs = get_quant_arg_downstream(list(node.users)[0])
qvalue = quantize_value(value, qargs)
dtype = ts.DType.INT8
data = np.full(shape, qvalue, dtype=np.int8)
else:
Expand Down
13 changes: 7 additions & 6 deletions backends/arm/operators/op_hardtanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
from executorch.backends.arm.tosa_quant_utils import (
get_quant_arg_upstream,
quantize_value,
)
from serializer.tosa_serializer import TosaOp


Expand All @@ -37,12 +40,10 @@ def define_node(

if is_quant_node:
# Get quant parameters
scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0])
qargs = get_quant_arg_upstream(node.all_input_nodes[0])
# Convert to quantized representation
clamp_min_qs = round((inputs[1].number / scale) + zp)
clamp_min_qs = max(clamp_min_qs, qmin)
clamp_max_qs = round((inputs[2].number / scale) + zp)
clamp_max_qs = min(clamp_max_qs, qmax)
clamp_min_qs = quantize_value(inputs[1].number, qargs)
clamp_max_qs = quantize_value(inputs[2].number, qargs)
# Set fp values to 0.0 since they are not used
clamp_min_fp = 0.0
clamp_max_fp = 0.0
Expand Down
7 changes: 4 additions & 3 deletions backends/arm/operators/op_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from executorch.backends.arm.tosa_quant_utils import (
dequantize_value,
get_quant_node_args,
get_quant_arg_downstream,
get_quant_arg_upstream,
QuantArgs,
quantize_value,
)
Expand Down Expand Up @@ -49,9 +50,9 @@ def define_node(

# Create attribute for 8 bit table lookup.
input_node = node.all_input_nodes[0]
in_quantargs = get_quant_node_args(input_node)
in_quantargs = get_quant_arg_upstream(input_node)
output_node = list(node.users)[0]
out_quantargs = get_quant_node_args(output_node)
out_quantargs = get_quant_arg_downstream(output_node)

table = log_table_8bit(in_quantargs, out_quantargs)
table_attr = ts.TosaSerializerAttribute()
Expand Down
9 changes: 6 additions & 3 deletions backends/arm/operators/op_max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import get_quant_node_args
from executorch.backends.arm.tosa_utils import (
get_quant_arg_downstream,
get_quant_arg_upstream,
)

from serializer.tosa_serializer import TosaOp

Expand Down Expand Up @@ -54,8 +57,8 @@ def define_node(
output_zp = 0

if is_quant_node:
input_zp = get_quant_node_args(node.all_input_nodes[0]).zp
output_zp = get_quant_node_args(list(node.users)[0]).zp
input_zp = get_quant_arg_upstream(node.all_input_nodes[0]).zp
output_zp = get_quant_arg_downstream(list(node.users)[0]).zp

attr = ts.TosaSerializerAttribute()
attr.PoolAttribute(
Expand Down
16 changes: 10 additions & 6 deletions backends/arm/operators/op_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
from executorch.backends.arm.tosa_quant_utils import (
build_rescale,
get_quant_arg_downstream,
get_quant_arg_upstream,
)
from executorch.backends.arm.tosa_utils import (
build_reshape,
expand_dims,
Expand Down Expand Up @@ -54,8 +58,8 @@ def define_node(
# For INT8, we need to get the zero point, otherwise it is 0
input0_zp, input1_zp = 0, 0
if is_quant_node:
input0_zp = get_quant_node_args(input0).zp
input1_zp = get_quant_node_args(input1).zp
input0_zp = get_quant_arg_upstream(input0).zp
input1_zp = get_quant_arg_upstream(input1).zp

mat_mul_result = tosa_graph.addIntermediate(
output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype
Expand Down Expand Up @@ -86,9 +90,9 @@ def define_node(

# As INT8 accumulates into INT32, we need to rescale it back to INT8
if is_quant_node:
input0_q_params = get_quant_node_args(input0)
input1_q_params = get_quant_node_args(input1)
output_q_params = get_quant_node_args(list(node.users)[0])
input0_q_params = get_quant_arg_upstream(input0)
input1_q_params = get_quant_arg_upstream(input1)
output_q_params = get_quant_arg_downstream(list(node.users)[0])

final_output_scale = (
input0_q_params.scale * input1_q_params.scale
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def define_node(
if is_quant_node:
input_A = inputs[0]
input_B = inputs[1]
input_A_qargs = tqutils.get_quant_node_args(
input_A_qargs = tqutils.get_quant_arg_upstream(
cast(torch.fx.Node, node.args[0])
)
input_B_qargs = tqutils.get_quant_node_args(
input_B_qargs = tqutils.get_quant_arg_upstream(
cast(torch.fx.Node, node.args[1])
)

Expand Down
17 changes: 11 additions & 6 deletions backends/arm/operators/op_placeholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import torch.fx
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
get_quant_arg_dtype,
get_quant_node_args,
is_quant_arg,
get_quant_arg_upstream,
get_quantized_node_output_dtype,
is_node_quantized,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import (
is_bias_node_for_quantized_conv,
map_dtype,
tosa_shape,
)
from torch.export.exported_program import ExportedProgram
Expand All @@ -41,7 +42,11 @@ def process_inputs(
tensor = ts.TosaSerializerTensor(
inputs[0].name,
tosa_shape(input_shape, input_dim_order),
get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype,
(
map_dtype(get_quantized_node_output_dtype(node))
if is_node_quantized(node)
else inputs[0].dtype
),
data=None,
placeholderFilename=inputs[0].name + ".npy",
)
Expand All @@ -63,8 +68,8 @@ def process_quantized_bias(
_,
) = consumer_node.all_input_nodes

input_node_scale = get_quant_node_args(input_node).scale
weight_node_scale = get_quant_node_args(weight_node).scale
input_node_scale = get_quant_arg_upstream(input_node).scale
weight_node_scale = get_quant_arg_upstream(weight_node).scale
bias_values_quantized = (
(parameter_values / (input_node_scale * weight_node_scale))
.round()
Expand Down
Loading

0 comments on commit 7fcd0af

Please sign in to comment.