Add missing quantized_matmul meta kernel (#4343)
Summary:

As titled. The quantized_matmul kernel will be used by the quantizer, but it had no meta kernel; this change adds one.

Differential Revision: D60070844
mcremon-meta authored and facebook-github-bot committed Jul 22, 2024
1 parent f0364e8 commit 31ad1f9
Showing 1 changed file with 49 additions and 2 deletions.
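
For background: a meta kernel computes only output metadata (shape and dtype), which lets export and quantization flows trace an op on meta tensors without doing any real computation. Below is a minimal, self-contained sketch of the registration pattern this diff follows, using a hypothetical demo::my_op rather than the real cadence op:

import torch
from torch.library import Library, impl

demo_lib = Library("demo", "DEF")
demo_lib.define("my_op(Tensor X) -> Tensor")

demo_meta = Library("demo", "IMPL", "Meta")

@impl(demo_meta, "my_op")
def my_op_meta(X: torch.Tensor) -> torch.Tensor:
    # Produce only metadata: same shape and dtype, no real storage,
    # since tensors on the "meta" device carry no data.
    return X.new_empty(X.size())

X = torch.empty(4, 8, device="meta")
out = torch.ops.demo.my_op(X)
assert out.shape == (4, 8) and out.device.type == "meta"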
backends/cadence/aot/ops_registrations.py
@@ -31,7 +31,6 @@
 lib.define(
     "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
 )
-
 lib.define(
     "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
 )
@@ -44,7 +43,6 @@
 )
 
 lib.define("quantized_relu(Tensor X, Tensor X_zero_point) -> (Tensor Y)")
-
 lib.define(
     "quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
 )
@@ -56,6 +54,13 @@
     "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 m = Library("cadence", "IMPL", "Meta")
 
 
@@ -164,3 +169,45 @@ def quantized_relu_meta(
     X_zero_point: torch.Tensor,
 ):
     return X.new_empty(X.size(), dtype=torch.uint8)
+
+
+@impl(m, "quantized_matmul")
+def quantized_matmul_meta(
+    X: torch.Tensor,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_zero_point: int,
+    bias: Optional[torch.Tensor],
+    out_multiplier: int,
+    out_shift: int,
+    out_zero_point: int,
+    transposed: bool = False,
+) -> torch.Tensor:
+    X_size = list(X.size())
+    Y_size = list(Y.size())
+
+    assert len(X_size) == len(
+        Y_size
+    ), "quantized matmul not supported for tensors of different dimensions"
+
+    if len(X_size) == 3:
+        assert (
+            X_size[0] == Y_size[0]
+        ), "quantized matmul only supported for batch dimension of same size"
+        if transposed:
+            assert X_size[2] == Y_size[2], "matrices cannot be multiplied"
+            out_size = X_size[:2] + [Y_size[1]]
+        else:
+            assert X_size[2] == Y_size[1], "matrices cannot be multiplied"
+            out_size = X_size[:2] + [Y_size[2]]
+    elif len(X_size) == 2:
+        if transposed:
+            assert X_size[1] == Y_size[1], "matrices cannot be multiplied"
+            out_size = [X_size[0], Y_size[0]]
+        else:
+            assert X_size[1] == Y_size[0], "matrices cannot be multiplied"
+            out_size = [X_size[0], Y_size[1]]
+    else:
+        raise AssertionError("quantized matmul only supported for 2D or 3D tensors")
+
+    return X.new_empty(out_size, dtype=X.dtype)
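
A quick way to sanity-check the new meta kernel is to call the op on meta tensors and inspect the inferred shapes. A hedged usage sketch, assuming this module has been imported so the cadence library and its Meta implementations are registered (the scalar quantization parameters are dummy values; they do not affect shape inference):

import torch

X = torch.empty(2, 3, 4, device="meta")
Y = torch.empty(2, 4, 5, device="meta")
# (batch, M, K) x (batch, K, N) -> (batch, M, N)
Z = torch.ops.cadence.quantized_matmul(X, 0, Y, 0, None, 1, 0, 0)
assert Z.shape == (2, 3, 5)

# transposed=True treats Y as (batch, N, K), so the shared dimension
# is checked as X_size[2] == Y_size[2].
Yt = torch.empty(2, 5, 4, device="meta")
Zt = torch.ops.cadence.quantized_matmul(X, 0, Yt, 0, None, 1, 0, 0, transposed=True)
assert Zt.shape == (2, 3, 5)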
