From 31ad1f9357c5d09fb9c4ad0cd0c816b619daad9d Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Mon, 22 Jul 2024 14:43:59 -0700
Subject: [PATCH] Add missing quantized_matmul meta kernel (#4343)

Summary:
As titled. The quantized_matmul kernel will be used by the quantizer, but is
missing a meta kernel.

Differential Revision: D60070844
---
 backends/cadence/aot/ops_registrations.py | 51 ++++++++++++++++++++++-
 1 file changed, 49 insertions(+), 2 deletions(-)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 44c9885fa0..c877a7149d 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -31,7 +31,6 @@
 lib.define(
     "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
 )
-
 lib.define(
     "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
 )
@@ -44,7 +43,6 @@
 )
 
 lib.define("quantized_relu(Tensor X, Tensor X_zero_point) -> (Tensor Y)")
-
 lib.define(
     "quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
 )
@@ -56,6 +54,13 @@
     "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
 m = Library("cadence", "IMPL", "Meta")
 
 
@@ -164,3 +169,45 @@ def quantized_relu_meta(
     X_zero_point: torch.Tensor,
 ):
     return X.new_empty(X.size(), dtype=torch.uint8)
+
+
+@impl(m, "quantized_matmul")
+def quantized_matmul_meta(
+    X: torch.Tensor,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_zero_point: int,
+    bias: Optional[torch.Tensor],
+    out_multiplier: int,
+    out_shift: int,
+    out_zero_point: int,
+    transposed: bool = False,
+) -> torch.Tensor:
+    X_size = list(X.size())
+    Y_size = list(Y.size())
+
+    assert len(X_size) == len(
+        Y_size
+    ), "quantized matmul not supported for tensors of different dimensions"
+
+    if len(X_size) == 3:
+        assert (
+            X_size[0] == Y_size[0]
+        ), "quantized matmul only supported for batch dimension of same size"
+        if transposed:
+            assert X_size[2] == Y_size[2], "matrices cannot be multiplied"
+            out_size = X_size[:2] + [Y_size[1]]
+        else:
+            assert X_size[2] == Y_size[1], "matrices cannot be multiplied"
+            out_size = X_size[:2] + [Y_size[2]]
+    elif len(X_size) == 2:
+        if transposed:
+            assert X_size[1] == Y_size[1], "matrices cannot be multiplied"
+            out_size = [X_size[0], Y_size[0]]
+        else:
+            assert X_size[1] == Y_size[0], "matrices cannot be multiplied"
+            out_size = [X_size[0], Y_size[1]]
+    else:
+        raise AssertionError("quantized matmul only supported for 2D or 3D tensors")
+
+    return X.new_empty(out_size, dtype=X.dtype)
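
Usage sketch (not part of the applied diff): once ops_registrations.py has
been imported, the meta kernel above can be exercised with tensors on the
"meta" device, which carry only shape and dtype, so the calls below resolve
output shapes without running any real matmul. The import path and all
numeric arguments (zero points, multiplier, shift, sizes) are illustrative
assumptions, not values taken from the patch.

    import torch

    # Assumed package path for this file inside the ExecuTorch tree;
    # importing it executes the lib.define(...) and @impl registrations.
    from executorch.backends.cadence.aot import ops_registrations  # noqa: F401

    # 3D (batched) case: X is (B, M, K), Y is (B, K, N) -> Z is (B, M, N).
    X = torch.empty(2, 8, 16, dtype=torch.uint8, device="meta")
    Y = torch.empty(2, 16, 4, dtype=torch.uint8, device="meta")
    Z = torch.ops.cadence.quantized_matmul(X, 0, Y, 0, None, 1, 0, 0, False)
    print(Z.shape)  # torch.Size([2, 8, 4])

    # transposed=True treats Y as (B, N, K), i.e. Z = X @ Y^T.
    Yt = torch.empty(2, 4, 16, dtype=torch.uint8, device="meta")
    Zt = torch.ops.cadence.quantized_matmul(X, 0, Yt, 0, None, 1, 0, 0, True)
    print(Zt.shape)  # torch.Size([2, 8, 4])

Note that the meta kernel only calls X.new_empty(out_size, dtype=X.dtype), so
the quantization arguments are ignored during shape propagation; they are
consumed by the actual Cadence kernel at runtime.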