Add missing quantized_matmul meta kernel #4343
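Registers the quantized_matmul and quantized_matmul.out schemas in the cadence op library and adds a Meta-kernel implementation that infers output shapes for 2D and batched 3D inputs, including the transposed layout.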

Closed · wants to merge 1 commit into from
51 changes: 49 additions & 2 deletions backends/cadence/aot/ops_registrations.py
@@ -31,7 +31,6 @@
lib.define(
    "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
)

lib.define(
    "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
@@ -44,7 +43,6 @@
)

lib.define("quantized_relu(Tensor X, Tensor X_zero_point) -> (Tensor Y)")

lib.define(
    "quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
@@ -56,6 +54,13 @@
"quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
)
lib.define(
    "quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
)

m = Library("cadence", "IMPL", "Meta")


@@ -164,3 +169,45 @@ def quantized_relu_meta(
    X_zero_point: torch.Tensor,
):
    return X.new_empty(X.size(), dtype=torch.uint8)


@impl(m, "quantized_matmul")
def quantized_matmul_meta(
X: torch.Tensor,
X_zero_point: int,
Y: torch.Tensor,
Y_zero_point: int,
bias: Optional[torch.Tensor],
out_multiplier: int,
out_shift: int,
out_zero_point: int,
transposed: bool = False,
) -> torch.Tensor:
X_size = list(X.size())
Y_size = list(Y.size())

assert len(X_size) == len(
Y_size
), "quantized matmul not supported for tensors of different dimensions"

if len(X_size) == 3:
assert (
X_size[0] == Y_size[0]
), "quantized matmul only supported for batch dimension of same size"
if transposed:
assert X_size[2] == Y_size[2], "matrices cannot be multiplied"
out_size = X_size[:2] + [Y_size[1]]
else:
assert X_size[2] == Y_size[1], "matrices cannot be multiplied"
out_size = X_size[:2] + [Y_size[2]]
elif len(X_size) == 2:
if transposed:
assert X_size[1] == Y_size[1], "matrices cannot be multiplied"
out_size = [X_size[0], Y_size[0]]
else:
assert X_size[1] == Y_size[0], "matrices cannot be multiplied"
out_size = [X_size[0], Y_size[1]]
else:
raise AssertionError("quantized matmul only supported for 2D or 3D tensors")

return X.new_empty(out_size, dtype=X.dtype)
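
For context (not part of the diff): a minimal sketch of how this meta kernel gets exercised. Once the module above is imported, the cadence::quantized_matmul op can propagate shapes through meta-device tensors, which is what fake-tensor tracing at export time relies on. The zero points, out_multiplier, and out_shift values below are placeholders, since only shapes matter on the meta device.

import torch

# Assumes backends/cadence/aot/ops_registrations.py (above) has been imported,
# registering the "cadence" namespace and its Meta kernels.
X = torch.empty(2, 4, 8, device="meta")   # (B, M, K)
Y = torch.empty(2, 8, 16, device="meta")  # (B, K, N)

# Placeholder quantization params; the meta kernel only consumes shapes.
Z = torch.ops.cadence.quantized_matmul(X, 0, Y, 0, None, 1, 0, 0, False)
assert Z.shape == (2, 4, 16)  # (B, M, N)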