FuseDequantLinearPass to convert dq -> linear into weight_int8packed_mm
Differential Revision: D60945766

Pull Request resolved: pytorch#4708
nathanaelsee committed Aug 15, 2024
1 parent 54f8932 commit 938748b
Showing 2 changed files with 92 additions and 0 deletions.
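For context, a minimal numeric sketch (not part of this commit) of the equivalence the fusion relies on: per-channel dequantization of an int8 weight followed by linear matches an int8 matmul scaled per output channel plus a bias add, which is the work that the fused _weight_int8pack_mm -> add sequence performs. All tensor names and shapes below are illustrative only.

import torch

x = torch.randn(4, 8)                                          # activations [M, K]
w_int8 = torch.randint(-128, 128, (16, 8), dtype=torch.int8)   # quantized weight [N, K]
scale = torch.rand(16)                                         # per-output-channel scales [N]
bias = torch.randn(16)

# dq(weight) -> linear(activation, dq, bias): the pattern this pass matches
w_dq = w_int8.to(torch.float32) * scale.unsqueeze(1)
ref = torch.nn.functional.linear(x, w_dq, bias)

# _weight_int8pack_mm(activation, quant_weight, scale) -> add: the fused form
fused = (x @ w_int8.to(torch.float32).T) * scale + bias

torch.testing.assert_close(ref, fused)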
15 changes: 15 additions & 0 deletions backends/transforms/TARGETS
@@ -73,6 +73,21 @@ runtime.python_library(
    ],
)

runtime.python_library(
    name = "fuse_dequant_linear",
    srcs = ["fuse_dequant_linear.py"],
    visibility = [
        "//executorch/backends/...",
    ],
    deps = [
        ":utils",
        "//caffe2:torch",
        "//executorch/exir:pass_base",
        "//executorch/exir:sym_util",
        "//executorch/exir/dialects:lib",
    ],
)

runtime.python_library(
    name = "fuse_view_copy",
    srcs = ["fuse_view_copy.py"],
77 changes: 77 additions & 0 deletions backends/transforms/fuse_dequant_linear.py
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseDequantLinearPass(ExportPass):
    """
    Fuses weight dequantize_per_channel nodes with linear nodes into
    weight_int8pack_mm nodes, for 8-bit weight-only quantization.

    Replaces dq(weight) -> linear(activation, dq) with weight_int8pack_mm.
    Replaces dq(weight) -> linear(activation, dq, bias) with weight_int8pack_mm -> add.
    """

    def fuse_dequant_with_linear(
        self,
        graph_module: torch.fx.GraphModule,
        dequant_node: torch.fx.Node,
        linear_node: torch.fx.Node,
    ) -> None:
        activations = linear_node.args[0]
        bias = None
        if len(linear_node.args) > 2:
            bias = linear_node.args[2]
        quant_weight = dequant_node.args[0]
        scale = dequant_node.args[1]

        with graph_module.graph.inserting_before(linear_node):
            weight_int8pack_mm_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten._weight_int8pack_mm.default,
                (activations, quant_weight, scale),
            )
            if bias:
                add_node = graph_module.graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.add.Tensor,
                    (weight_int8pack_mm_node, bias),
                )
                linear_node.replace_all_uses_with(add_node)
            else:
                linear_node.replace_all_uses_with(weight_int8pack_mm_node)
            graph_module.graph.erase_node(linear_node)
            graph_module.graph.erase_node(dequant_node)

    def is_node_target(
        self, node: torch.fx.Node, target: torch._ops.OperatorBase
    ) -> bool:
        return node.op == "call_function" and node.target == target

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if self.is_node_target(node, exir_ops.edge.aten.linear.default):
                weight_node = node.args[1]
                if self.is_node_target(
                    weight_node,
                    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
                ):
                    # only fuse if weight tensor is int8 packed
                    quant_weight = weight_node.args[0]
                    if quant_weight.meta["val"].dtype != torch.int8:
                        continue
                    self.fuse_dequant_with_linear(graph_module, weight_node, node)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)

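A minimal usage sketch (not part of this commit) of how the pass could be run on an edge-dialect graph module. The graph_module variable is assumed to come from an existing to_edge() lowering, the import path is inferred from the file location above, and only FuseDequantLinearPass and PassResult are defined in this diff.

import torch
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass

def fuse_dequant_linear(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # call() rewrites matching dq -> linear chains and returns a PassResult
    result = FuseDequantLinearPass().call(graph_module)
    return result.graph_module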