@@ -87,9 +87,11 @@ def get_weights(self, weights: "Weights", prefix: str):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -113,9 +115,16 @@ def get_weights_col_packed(
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_packed_sharded(
-                f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            ).reshape(-1)
+            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            if scale.numel() > 1:
+                scale = weights.get_packed_sharded(
+                    f"{prefix}.weight_scale",
+                    dim=0,
+                    block_sizes=block_sizes,
+                    to_dtype=False,
+                )
+            scale = scale.reshape(-1).expand(w.shape[0])
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -132,16 +141,19 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
         w = [
             weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
         ]
+        shapes = [x.shape for x in w]
+
         # Concat then send to the device
         w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
             scale = [
-                weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
-                for p in prefixes
+                _load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
+                for p, shape in zip(prefixes, shapes)
             ]
             scale = torch.cat(scale, dim=0).reshape(-1)
+
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -157,9 +169,11 @@ def get_weights_row(self, weights: "Weights", prefix: str):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_tensor(
-                f"{prefix}.weight_scale", to_dtype=False
-            ).reshape(-1)
+            scale = (
+                weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+                .reshape(-1)
+                .expand(w.shape[0])
+            )
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -182,6 +196,9 @@ class Fp8Weight(Weight):
     def get_linear(self, bias: torch.Tensor):
         if self.weight_scale is None:
             return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
+        # This is not checked by the fbgemm kernels, but they require contiguous
+        # memory. Can be non-contiguous when we e.g. expand from scalars.
+        self.weight_scale = self.weight_scale.contiguous()
         return get_fp8_linear().from_fp8(
             self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
         )
@@ -222,6 +239,9 @@ def from_unquant(cls, weight, bias, dtype):
 
     @classmethod
     def from_fp8(cls, weight, scale, input_scale, bias, dtype):
+        if FBGEMM_DYN_AVAILABLE:
+            # fbgemm needs float32 scales.
+            scale = scale.float()
         return cls(
             qweight=weight,
             scale=scale,
@@ -256,3 +276,10 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             bias=self.bias,
         )
         return output
+
+
+def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
+    scale = weights.get_tensor(prefix, to_dtype=False)
+    if scale.numel() > 1:
+        scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
+    return scale.reshape(-1).expand(shape[0])
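For reference, the scale handling introduced above can be sketched outside the loader: a scalar (per-tensor) FP8 weight scale is expanded to one entry per output row, and because expand() returns a non-contiguous view, the scale is made contiguous (and float32 for fbgemm) before use. A minimal sketch, assuming plain tensors in place of the Weights helper; broadcast_scale is a hypothetical name, not part of the diff:

import torch

def broadcast_scale(scale: torch.Tensor, out_features: int) -> torch.Tensor:
    # A scalar (per-tensor) scale is broadcast to one value per output row;
    # a per-channel scale is kept as-is after flattening.
    scale = scale.reshape(-1).expand(out_features)
    # expand() yields a non-contiguous view; fbgemm expects contiguous
    # float32 scales, hence the explicit contiguous()/float() in the diff.
    return scale.contiguous().float()

print(broadcast_scale(torch.tensor(0.02), 4).shape)  # torch.Size([4])
print(broadcast_scale(torch.rand(4), 4).shape)       # torch.Size([4])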