Add pass to properly check if q and dq nodes are implicit #49

Closed
13 changes: 4 additions & 9 deletions backends/xnnpack/operators/op_dequantize_per_tensor.py
@@ -11,6 +11,7 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNConvert,
XNNGraph,
@@ -23,12 +24,7 @@
@register_node_visitor
class OpDeQuantizePerTensor(NodeVisitor):
"""
Dequantize Per Tensor Node visitor. We only insert an XNNPACK node if
this op was found as a graph input or graph output. This is so we
dequantize the input going in. Every other instance of quantize per
tensor is only used as signaling for q params of node inputs, so
we ignore those. This is because xnnpack only supports entire graph
quantization
Dequantize Per Tensor Node visitor
"""

target = "quantized_decomposed.dequantize_per_tensor.default"
@@ -44,10 +40,9 @@ def define_node(
debug_handle: int,
) -> None:
"""
We only define a node if it is a graph output
We only define a node if it is not an implicit dq node
"""
# TODO:@maxren better handle in-graph quantization conversions, this is hacky
if self.is_graph_output(node):
if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
dq_input = get_input_node(node, 0)
input_quant_params = QuantParams.from_q_dq_node(node)
# fp32 output
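
The visitor above now defers to TagImplicitQDqPass.is_tagged_as_implicit_q_dq rather than checking whether the node is a graph output. The pass itself is not part of this diff, so the following is only a minimal sketch of how such a tag could be stored on and read back from an FX node; the meta key name and the class body are assumptions, not the actual implementation.

    from torch.fx import Node


    class TagImplicitQDqSketch:
        # Hypothetical meta key; the real pass may use a different name.
        IMPLICIT_Q_DQ_TAG = "_xnnpack_implicit_q_dq"

        @classmethod
        def tag_as_implicit_q_dq(cls, node: Node) -> None:
            # The pass would mark q/dq nodes that merely carry quant params
            # for an already-supported quantized operator.
            node.meta[cls.IMPLICIT_Q_DQ_TAG] = True

        @classmethod
        def is_tagged_as_implicit_q_dq(cls, node: Node) -> bool:
            # The node visitors skip serializing an XNNPACK op when this
            # returns True.
            return node.meta.get(cls.IMPLICIT_Q_DQ_TAG, False)
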
13 changes: 4 additions & 9 deletions backends/xnnpack/operators/op_quantize_per_tensor.py
@@ -11,6 +11,7 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
XNNConvert,
XNNGraph,
@@ -23,12 +24,7 @@
@register_node_visitor
class OpQuantizePerTensor(NodeVisitor):
"""
Quantize Per Tensor Node visitor. We only insert an XNNPACK node if
this op was found as a graph input or graph output. This is so we
quantize the input going in. Every other instance of quantize per
tensor is only used as signaling for q params of node inputs, so
we ignore those. This is because xnnpack only supports entire graph
quantization
Quantize Per Tensor Node visitor
"""

target = "quantized_decomposed.quantize_per_tensor.default"
@@ -44,11 +40,10 @@ def define_node(
debug_handle: int,
) -> None:
"""
We only define a node if it is a graph input
We only define a node if it is not an implicit q node
"""
# TODO:@maxren better handle in-graph quantization conversions, this is hacky
q_input = get_input_node(node, 0)
if self.is_graph_input(q_input):
if not TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node):
input_quant_params = QuantParams.from_q_dq_node(node)
# fp32 input
self.define_tensor(q_input, xnn_graph, vals_to_ids)
15 changes: 15 additions & 0 deletions backends/xnnpack/partition/TARGETS
@@ -25,6 +25,7 @@ runtime.python_library(
"@EXECUTORCH_CLIENTS",
],
deps = [
":configs",
":support_patterns",
"//executorch/backends/xnnpack:xnnpack_preprocess",
"//executorch/exir:delegate",
@@ -34,3 +35,17 @@
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
],
)

runtime.python_library(
name = "configs",
srcs = [
"configs.py",
],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//executorch/exir:lib",
],
)
122 changes: 122 additions & 0 deletions backends/xnnpack/partition/configs.py
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops

###
### Module based partitioners
###

SUPPORTED_OPS = [
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.floor.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.upsample_bilinear2d.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.max.dim,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.hardswish.default,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten._prelu_kernel.default,
exir_ops.edge.aten.slice_copy.Tensor,
]

SUPPORTED_MODULES = [
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.ReLU,
torch.nn.Sigmoid,
torch.nn.Softmax,
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.Linear,
torch.nn.functional.linear,
torch.nn.Hardtanh,
torch.nn.MaxPool2d,
torch.nn.LeakyReLU,
torch.nn.ELU,
torch.nn.AvgPool2d,
torch.nn.PReLU, # Without this, the PReLU weight becomes not a get_attr
torch.cat,
torch.concat,
torch.concatenate,
]

# TODO: delete this and use SUPPORTED_OPS instead once fp32 and quant support are aligned
SUPPORTED_QUANT_OPS = [
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.hardtanh.default, # TODO - which one module or op or both?
exir_ops.edge.aten.slice_copy.Tensor,
]

SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET = {
op.name()
for op in (
SUPPORTED_QUANT_OPS
+ [
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.aten.max_pool2d.default,
exir_ops.edge.aten.linear.default,
]
)
}

# TODO: delete this and use SUPPORTED_MODULES instead once fp32 and quant support are aligned
SUPPORTED_QUANT_MODULES = [
torch.clamp,
torch.mean,
torch.permute,
torch.permute_copy,
torch.cat,
torch.concat,
torch.concatenate,
torch.nn.Linear,
torch.nn.functional.linear,
# TODO - T158982884
# torch.ao.nn.quantized.reference.modules.linear.Linear,
torch.nn.MaxPool2d,
torch.nn.Conv1d,
torch.nn.functional.conv1d,
torch.ao.nn.quantized.reference.modules.conv.Conv1d,
torch.nn.Conv2d,
torch.nn.functional.conv2d,
torch.nn.functional.pad,
torch.nn.functional.elu,
torch.ao.nn.quantized.reference.modules.conv.Conv2d,
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.ConstantPad2d,
torch.nn.ELU,
torch.nn.Hardtanh,
torch.nn.ReLU,
torch.nn.functional.relu,
torch.nn.functional.relu_,
torch.nn.functional.leaky_relu,
torch.nn.functional.leaky_relu_,
torch.nn.LeakyReLU,
]

SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET = set(SUPPORTED_QUANT_MODULES)

# Modules which support dynamic quantization
SUPPORTED_DYN_QUANT_MODULES = [
torch.nn.Linear,
torch.nn.functional.linear,
]
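
SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET and SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET enumerate the operators around which quantize/dequantize nodes can be treated as implicit, i.e. they only communicate quant params and need no standalone XNNPACK node. A rough sketch of how a tagging pass might consult the op-name set is below; the helper name and the exact rule (checking only a node's users) are assumptions made for illustration.

    from executorch.backends.xnnpack.partition.configs import (
        SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET,
    )
    from torch.fx import Node


    def all_users_support_implicit_q_dq(node: Node) -> bool:
        # A q/dq node is a candidate for the implicit tag when every user
        # is an edge op whose name appears in the supported set.
        return all(
            user.op == "call_function"
            and hasattr(user.target, "name")
            and user.target.name() in SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET
            for user in node.users
        )
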
109 changes: 8 additions & 101 deletions backends/xnnpack/partition/xnnpack_partitioner.py
@@ -10,6 +10,14 @@
from typing import Any, Callable, cast, Dict, List, Optional, Union

import torch

from executorch.backends.xnnpack.partition.configs import (
SUPPORTED_DYN_QUANT_MODULES,
SUPPORTED_MODULES,
SUPPORTED_OPS,
SUPPORTED_QUANT_MODULES,
SUPPORTED_QUANT_OPS,
)
from executorch.backends.xnnpack.partition.support_patterns import (
get_add_graphs,
get_all_dynamically_quantized_linear_pattern,
@@ -522,107 +530,6 @@ def __init__(self):
)


###
### Module based partitioners
###

SUPPORTED_OPS = [
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.clamp.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.floor.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.constant_pad_nd.default,
exir_ops.edge.aten.upsample_bilinear2d.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.max.dim,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.hardswish.default,
exir_ops.edge.aten.neg.default,
exir_ops.edge.aten.pow.Tensor_Scalar,
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten._prelu_kernel.default,
exir_ops.edge.aten.slice_copy.Tensor,
]

SUPPORTED_MODULES = [
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.ReLU,
torch.nn.Sigmoid,
torch.nn.Softmax,
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.Linear,
torch.nn.functional.linear,
torch.nn.Hardtanh,
torch.nn.MaxPool2d,
torch.nn.LeakyReLU,
torch.nn.ELU,
torch.nn.AvgPool2d,
torch.nn.PReLU, # Without this, the PReLU weight becomes not a get_attr
torch.cat,
torch.concat,
torch.concatenate,
]

# TODO delete this and should use SUPPORTED_OPS instead once we align fp32 and quant support
SUPPORTED_QUANT_OPS = [
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.hardtanh.default, # TODO - which one module or op or both?
exir_ops.edge.aten.slice_copy.Tensor,
]

# TODO delete this and should use SUPPORTED_MODULES instead once we align fp32 and quant support
SUPPORTED_QUANT_MODULES = [
torch.clamp,
torch.mean,
torch.permute,
torch.permute_copy,
torch.cat,
torch.concat,
torch.concatenate,
torch.nn.Linear,
torch.nn.functional.linear,
# TODO - T158982884
# torch.ao.nn.quantized.reference.modules.linear.Linear,
torch.nn.MaxPool2d,
torch.nn.Conv1d,
torch.nn.functional.conv1d,
torch.ao.nn.quantized.reference.modules.conv.Conv1d,
torch.nn.Conv2d,
torch.nn.functional.conv2d,
torch.nn.functional.pad,
torch.nn.functional.elu,
torch.ao.nn.quantized.reference.modules.conv.Conv2d,
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.ConstantPad2d,
torch.nn.ELU,
torch.nn.Hardtanh,
torch.nn.ReLU,
torch.nn.functional.relu,
torch.nn.functional.relu_,
torch.nn.functional.leaky_relu,
torch.nn.functional.leaky_relu_,
torch.nn.LeakyReLU,
]

# Modules which support dynamic quantization
SUPPORTED_DYN_QUANT_MODULES = [
torch.nn.Linear,
torch.nn.functional.linear,
]


class XnnpackFloatingPointPartitioner(Partitioner):
"""
Module and Opname based partitioner for FP32 modules/ops listed in
2 changes: 2 additions & 0 deletions backends/xnnpack/passes/TARGETS
@@ -10,10 +10,12 @@ python_library(
"fuse_batch_norm_with_conv.py",
"prelu_reshape_pass.py",
"remove_getitem_op.py",
"tag_implicit_q_dq_pass.py",
],
deps = [
"//caffe2:torch",
"//executorch/backends/transforms:lib",
"//executorch/backends/xnnpack/partition:configs",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
2 changes: 2 additions & 0 deletions backends/xnnpack/passes/__init__.py
@@ -14,6 +14,7 @@
)
from executorch.backends.xnnpack.passes.prelu_reshape_pass import PReLUReshapePass
from executorch.backends.xnnpack.passes.remove_getitem_op import RemoveGetItemPass
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass

from executorch.exir.passes import PassManager
from executorch.exir.passes.const_prop_pass import ConstPropPass
@@ -27,5 +28,6 @@
Conv1dUnsqueezePass(),
PReLUReshapePass(),
ChannelsLastTaggedReshapePass(),
TagImplicitQDqPass(),
]
)
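
With TagImplicitQDqPass appended to the pipeline, later stages can simply ask whether a q/dq node was tagged instead of re-deriving graph input/output heuristics. The helper below is only an illustrative sketch (not part of this change) and assumes the pass pipeline has already run on the graph module.

    from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import (
        TagImplicitQDqPass,
    )
    from torch.fx import GraphModule


    def report_q_dq_nodes(graph_module: GraphModule) -> None:
        # Walk the graph and classify every (de)quantize_per_tensor node
        # using the tag written by TagImplicitQDqPass.
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and "quantize_per_tensor" in str(node.target):
                kind = (
                    "implicit"
                    if TagImplicitQDqPass.is_tagged_as_implicit_q_dq(node)
                    else "explicit"
                )
                print(f"{node.name}: {kind} q/dq node")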