Skip to content

Commit

Permalink
Support qmatmul with different dims tensors (#4438)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #4438

MobileBERT exposes an issue in our kernel, where tensors have compatible (for PyTorch) but different batch dimensions.

This diff changes the meta kernel to support that (the kernel can already do it).

Reviewed By: dulinriley

Differential Revision: D60314979

fbshipit-source-id: a0cde9d328098992787c353611ece64223d6c739
  • Loading branch information
mcremon-meta authored and facebook-github-bot committed Jul 29, 2024
1 parent e087ac8 commit f695f8e
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions backends/cadence/aot/ops_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from math import prod
from typing import Optional, Tuple

import torch
Expand Down Expand Up @@ -186,28 +187,29 @@ def quantized_matmul_meta(
X_size = list(X.size())
Y_size = list(Y.size())

assert len(X_size) == len(
Y_size
), "quantized matmul not supported for tensors of different dimensions"

if len(X_size) == 3:
assert (
X_size[0] == Y_size[0]
), "quantized matmul only supported for batch dimension of same size"
if transposed:
assert X_size[2] == Y_size[2], "matrices cannot be multiplied"
out_size = X_size[:2] + [Y_size[1]]
else:
assert X_size[2] == Y_size[1], "matrices cannot be multiplied"
out_size = X_size[:2] + [Y_size[2]]
elif len(X_size) == 2:
if transposed:
assert X_size[1] == Y_size[1], "matrices cannot be multiplied"
out_size = [X_size[0], Y_size[0]]
else:
assert X_size[1] == Y_size[0], "matrices cannot be multiplied"
out_size = [X_size[0], Y_size[1]]
# Get the batch dimensions for both tensors
X_batch_dims = X_size[:-2]
Y_batch_dims = Y_size[:-2]

# If they don't match, check that they're compatible
if X_batch_dims != Y_batch_dims:
assert prod(X_batch_dims) == prod(
Y_batch_dims
), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}"

# Get the matmul output size
if transposed:
assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied"
mat_size = [X_size[-2], Y_size[-2]]
else:
raise AssertionError("quantized matmul only supported for 2D or 3D tensors")
assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied"
mat_size = [X_size[-2], Y_size[-1]]

# Combine the larger batch dimensions with the matmul output size
out_size = (
X_batch_dims + mat_size
if len(X_batch_dims) > len(Y_batch_dims)
else Y_batch_dims + mat_size
)

return X.new_empty(out_size, dtype=X.dtype)

0 comments on commit f695f8e

Please sign in to comment.