@@ -112,31 +112,6 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
112
112
return safe_int_mm (a , b )
113
113
114
114
115
- # If users install cpu-only pytorch, triton won't be available by default
116
- # This pass adds definition of int_scaled_matmul for this case
117
- if intmm_triton :
118
- lib = intmm_triton .lib
119
- else :
120
- lib = torch .library .Library ("torchao" , "FRAGMENT" )
121
- lib .define ("int_scaled_matmul(Tensor a, Tensor b, Tensor scales1) -> Tensor" )
122
-
123
-
124
- @torch .library .impl (lib , "int_scaled_matmul" , "Meta" )
125
- def int_scaled_matmul_meta (a , b , scales1 ):
126
- M , K = a .shape
127
- K , N = b .shape
128
- return torch .empty ((M , N ), device = a .device , dtype = scales1 .dtype )
129
-
130
-
131
- @torch .library .impl (lib , "int_scaled_matmul" , "CPU" )
132
- def int_scaled_matmul_cpu (a , b , scales1 ):
133
- if TORCH_VERSION_AT_LEAST_2_6 :
134
- c = torch ._int_mm (a , b )
135
- return c .to (scales1 .dtype ) * scales1
136
- else :
137
- return safe_int_mm (a , b ) * scales1
138
-
139
-
140
115
def int_scaled_matmul (a : torch .Tensor , b : torch .Tensor , scales1 : torch .Tensor ) -> torch .Tensor :
141
116
"""
142
117
Performs scaled integer matrix multiplication.
@@ -159,10 +134,14 @@ def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -
159
134
assert scales1 .is_contiguous ()
160
135
scales1 = scales1 .expand ((M , N ))
161
136
assert scales1 .dim () == 2
162
- if (
163
- (intmm_triton is not None and AUTOTUNER_ENABLE )
164
- or scales1 .device .type == "cpu"
165
- ):
137
+
138
+ if scales1 .device .type == "cpu" and TORCH_VERSION_AT_LEAST_2_6 :
139
+ # CPU prefers decomposed version of int_scaled_matmul
140
+ # to leverage the fusion capability of Inductor
141
+ c = torch ._int_mm (a , b )
142
+ return c .to (scales1 .dtype ) * scales1
143
+
144
+ if intmm_triton is not None and AUTOTUNER_ENABLE :
166
145
return torch .ops .torchao .int_scaled_matmul (a , b , scales1 )
167
146
168
147
c = safe_int_mm (a , b )
0 commit comments