Commit 97dfea8

Decompose scaled_int_mm for CPU
1 parent 9f366a9 commit 97dfea8

File tree

2 files changed: +15 -29 lines changed

torchao/kernel/intmm.py

+8 -29
@@ -112,31 +112,6 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     return safe_int_mm(a, b)


-# If users install cpu-only pytorch, triton won't be available by default
-# This pass adds definition of int_scaled_matmul for this case
-if intmm_triton:
-    lib = intmm_triton.lib
-else:
-    lib = torch.library.Library("torchao", "FRAGMENT")
-    lib.define("int_scaled_matmul(Tensor a, Tensor b, Tensor scales1) -> Tensor")
-
-
-@torch.library.impl(lib, "int_scaled_matmul", "Meta")
-def int_scaled_matmul_meta(a, b, scales1):
-    M, K = a.shape
-    K, N = b.shape
-    return torch.empty((M, N), device=a.device, dtype=scales1.dtype)
-
-
-@torch.library.impl(lib, "int_scaled_matmul", "CPU")
-def int_scaled_matmul_cpu(a, b, scales1):
-    if TORCH_VERSION_AT_LEAST_2_6:
-        c = torch._int_mm(a, b)
-        return c.to(scales1.dtype) * scales1
-    else:
-        return safe_int_mm(a, b) * scales1
-
-
 def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
     """
     Performs scaled integer matrix multiplication.
@@ -159,10 +134,14 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
     assert scales1.is_contiguous()
     scales1 = scales1.expand((M, N))
     assert scales1.dim() == 2
-    if (
-        (intmm_triton is not None and AUTOTUNER_ENABLE)
-        or scales1.device.type == "cpu"
-    ):
+
+    if scales1.device.type == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+        # CPU prefers decomposed version of int_scaled_matmul
+        # to leverage the fusion capability of Inductor
+        c = torch._int_mm(a, b)
+        return c.to(scales1.dtype) * scales1
+
+    if intmm_triton is not None and AUTOTUNER_ENABLE:
         return torch.ops.torchao.int_scaled_matmul(a, b, scales1)

     c = safe_int_mm(a, b)
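For context, a minimal sketch (not part of this commit) of how the decomposed CPU path is exercised; the shapes, dtypes, and use of torch.compile here are illustrative assumptions, not taken from the diff:

import torch

from torchao.kernel.intmm import int_scaled_matmul

# int8 operands with a per-row float32 scale; shapes are arbitrary examples
a = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
b = torch.randint(-128, 127, (32, 32), dtype=torch.int8)
scales1 = torch.rand(64, 1, dtype=torch.float32)

# On CPU with PyTorch >= 2.6, int_scaled_matmul now decomposes into
# torch._int_mm plus a cast and multiply, so Inductor can fuse the
# cast/multiply epilogue when the call is compiled instead of hitting
# an opaque custom op.
compiled_fn = torch.compile(int_scaled_matmul)
out = compiled_fn(a, b, scales1)  # float32, shape (64, 32)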

torchao/kernel/intmm_triton.py

+7
@@ -334,6 +334,13 @@ def int_matmul_cuda(a, b):
     return int_matmul_kernel(a, b, c, best_config)


+@torch.library.impl(lib, "int_scaled_matmul", "Meta")
+def int_scaled_matmul_meta(a, b, scales1):
+    M, K = a.shape
+    K, N = b.shape
+    return torch.empty((M, N), device=a.device, dtype=scales1.dtype)
+
+
 @torch.library.impl(lib, "int_scaled_matmul", "CUDA")
 def int_scaled_matmul_cuda(a, b, scales1):
     # Check constraints.
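A small sketch (again, not from this commit) of what the relocated Meta registration provides: shape and dtype inference for the custom op without launching a kernel, which is what FakeTensor tracing under torch.compile relies on. It assumes a build where triton is available, so that importing torchao.kernel.intmm_triton registers the op:

import torch

import torchao.kernel.intmm_triton  # noqa: F401  (registers the custom op; needs triton)

# Meta tensors carry only shape/dtype metadata, no storage.
a = torch.empty(64, 32, dtype=torch.int8, device="meta")
b = torch.empty(32, 32, dtype=torch.int8, device="meta")
scales1 = torch.empty(64, 32, dtype=torch.float32, device="meta")

# Dispatches to int_scaled_matmul_meta: an empty (M, N) tensor with
# scales1's dtype comes back, and no CUDA kernel runs.
out = torch.ops.torchao.int_scaled_matmul(a, b, scales1)
print(out.shape, out.dtype)  # torch.Size([64, 32]) torch.float32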
