Add lift scalar to constant tensor pass #8313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 1 commit
backends/qualcomm/_passes/lift_constant_scalar_operands.py (90 additions, 0 deletions)
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
from torch import fx
from torch._subclasses import FakeTensor
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

COMPARE_SCALAR_OPS = {
    torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
    torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
    torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
    torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
    torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
    torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
}


def _not_float_tensor(node: fx.Node) -> bool:
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.float32


def _not_bool_tensor(node: fx.Node) -> bool:
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.bool


def lift_constant_scalar_operands(gm: torch.fx.GraphModule) -> None:
    # If the node is add(input, constant) and constant is a scalar, we can lift the
    # constant, and the annotation for quantization will insert q/dq nodes around
    # the lifted constant.
    for n in gm.graph.nodes:
        if n.op != "call_function" or n.target not in (
            torch.ops.aten.add.Tensor,
            torch.ops.aten.sub.Tensor,
            torch.ops.aten.mul.Tensor,
            torch.ops.aten.div.Tensor,
            torch.ops.aten.rsub.Scalar,
            torch.ops.aten.add_.Scalar,
        ) + tuple(COMPARE_SCALAR_OPS.keys()):
            continue
        # Split the node's args into the tensor operand and the scalar operand.
        const_arg = None
        non_const_arg = None
        for arg in n.args:
            if isinstance(arg, torch.fx.Node):
                non_const_arg = arg
            else:
                const_arg = arg
        if non_const_arg is None or const_arg is None:
            continue

        if _not_float_tensor(n) and _not_bool_tensor(n):
            continue

        # Bool outputs come from the comparison ops; for those, build the constant
        # with the dtype of the tensor input rather than the bool output.
        if not _not_float_tensor(n):
            tensor_constant = torch.tensor(
                [const_arg],
                dtype=n.meta["val"].dtype,
                device=n.meta["val"].device,
            )
        else:
            tensor_constant = torch.tensor(
                [const_arg],
                dtype=n.args[0].meta["val"].dtype,
                device=n.meta["val"].device,
            )
        tensor_constant_name = get_new_attr_name_with_prefix("_tensor_constant_")(gm)
        gm.register_buffer(tensor_constant_name, tensor_constant)

        fake_mode = n.meta["val"].fake_mode
        with gm.graph.inserting_before(n):
            get_attr_node = gm.graph.get_attr(tensor_constant_name)
            get_attr_node.meta["val"] = fake_mode.from_tensor(tensor_constant)

        # rsub(input, scalar) computes scalar - input, so the lifted constant becomes
        # the first operand and the node is retargeted to sub.Tensor.
        if n.target == torch.ops.aten.rsub.Scalar:
            n.args = (get_attr_node, non_const_arg) + n.args[2:]
            n.target = torch.ops.aten.sub.Tensor
        else:
            n.args = (non_const_arg, get_attr_node) + n.args[2:]

        if n.target == torch.ops.aten.add_.Scalar:
            n.args = (non_const_arg, get_attr_node) + n.args[2:]
            n.target = torch.ops.aten.add.Tensor

        if n.target in tuple(COMPARE_SCALAR_OPS.keys()):
            n.args = (non_const_arg, get_attr_node) + n.args[2:]
            n.target = COMPARE_SCALAR_OPS[n.target]

    gm.recompile()
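
For reference, a minimal sketch of how the pass could be exercised on its own; the toy module and the torch.export capture call below are illustrative assumptions, not part of this PR:

import torch

from executorch.backends.qualcomm._passes.lift_constant_scalar_operands import (
    lift_constant_scalar_operands,
)


class AddConst(torch.nn.Module):
    def forward(self, x):
        # 2.0 arrives in the graph as a scalar operand of aten.add.Tensor.
        return x + 2.0


gm = torch.export.export(AddConst(), (torch.randn(2, 3),)).module()
lift_constant_scalar_operands(gm)  # mutates gm in place
# The add node now reads its former scalar from a registered buffer through a
# get_attr node (e.g. _tensor_constant_0), which annotation can then observe.
print(gm.graph)
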
backends/qualcomm/quantizer/quantizer.py (4 additions, 0 deletions)
@@ -10,6 +10,9 @@
import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum
from executorch.backends.qualcomm._passes.decompose_silu import DecomposeSilu
from executorch.backends.qualcomm._passes.lift_constant_scalar_operands import (
    lift_constant_scalar_operands,
)
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
    RecomposePixelUnshuffle,
)
@@ -224,6 +227,7 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
        model = DecomposeSilu()(model).graph_module
        model = DecomposeEinsum()(model).graph_module
        model = ReplaceInfBuffer()(model).graph_module
        # Turn scalars into tensors so that they can be annotated for quantization.
        lift_constant_scalar_operands(model)
        return model

    def validate(self, model: GraphModule) -> None:
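
With the quantizer change above, the pass runs automatically during prepare_pt2e, which calls transform_for_annotation before annotation. A rough sketch of that flow, with the capture call and calibration input as illustrative assumptions:

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class AddConst(torch.nn.Module):
    def forward(self, x):
        return x + 2.0


example_inputs = (torch.randn(2, 3),)
m = torch.export.export(AddConst(), example_inputs).module()

quantizer = QnnQuantizer()
# prepare_pt2e invokes quantizer.transform_for_annotation(m), which now lifts
# the scalar 2.0 into a buffer so observers (and later q/dq) can wrap it.
prepared = prepare_pt2e(m, quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)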