from torch.autograd import Function
from torch.nn import functional as F

-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import accelerator, logger

from .utility import quant_tensor

@@ -174,9 +174,9 @@ def __init__(

    def pack(self, int_weight, scale, zp, bias, g_idx=None):
        if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()
        int_weight = int_weight.to(self.device)
        if self.use_optimum_format and zp is None:
            # to avoid overflow
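This hunk and the ones below replace the in-place `t_()` transpose with the non-mutating `T` view, presumably so repeated calls no longer mutate the stored buffers as a side effect. A minimal sketch of the difference (standalone; `a` is a made-up tensor, not from this file):

    import torch

    a = torch.arange(6).reshape(2, 3)
    b = a.T.contiguous()  # fresh contiguous copy; `a` keeps shape (2, 3)
    a.t_()                # in-place: `a` itself now has shape (3, 2)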
@@ -197,124 +197,111 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
        assert scale.shape == self.scales.shape, f"{scale.shape} != {self.scales.shape} Scale shape is mismatched."
        self.scales = scale.type(self.float_type).to(self.device)
        if not self.use_optimum_format and self.compression_dim == 0:
-            int_weight = int_weight.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
+            int_weight = int_weight.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
        origin_shape = int_weight.shape
        target_shape = self.qweight.shape
        assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)

        # pack weight
-        for j in range(target_shape[1]):
-            start = self.n_pack * j
-            end = self.n_pack * (j + 1)
-            tmp = int_weight[:, start:end].type(self.compression_dtype)
-            for e in range(tmp.shape[1]):
-                tmp[:, e] &= mask
-                tmp[:, e] = tmp[:, e] << (self.bits * e)
-                self.qweight[:, j] |= tmp[:, e]
+        self.qweight.copy_(self.pack_tensor(int_weight))
        if not self.use_optimum_format and self.compression_dim == 0:
-            self.qweight = self.qweight.t_().contiguous()
+            self.qweight = self.qweight.T.contiguous()

        if zp is not None:
            zp = zp.to(self.device)
            if self.use_optimum_format:
                zp -= 1
            if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                self.qzeros = self.qzeros.t_().contiguous()
+                zp = zp.T.contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
            assert hasattr(self, "qzeros"), "zp is not set when initializing."
-            target_shape = self.qzeros.shape
-            for j in range(target_shape[1]):
-                start = self.n_pack * j
-                end = self.n_pack * (j + 1)
-                tmp = zp[:, start:end].type(self.compression_dtype)
-                for e in range(tmp.shape[1]):
-                    tmp[:, e] &= mask
-                    tmp[:, e] = tmp[:, e] << (self.bits * e)
-                    self.qzeros[:, j] |= tmp[:, e]
+            self.qzeros.copy_(self.pack_tensor(zp))
            if self.use_optimum_format or self.compression_dim == 0:
-                self.qzeros = self.qzeros.t_().contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
        if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()

    def recover(self):
        logger.debug(f"Recovering {self} weight")
-        scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
-        qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight
+        scales = self.scales.T.contiguous() if self.use_optimum_format else self.scales
+        qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight

        device = scales.device
        fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
        if self.g_idx is None:
            # used for recovering fp32_weight
            self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device)
-        if hasattr(self, "qzeros"):
-            weight_dtype = torch.uint8
-        else:
-            weight_dtype = torch.int8
        # unpack weight
-        weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
        if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
-            qweight = qweight.t_().contiguous()
-        origin_shape = weight.shape
-        target_shape = qweight.shape
-        for j in range(target_shape[1]):
-            for e in range(self.n_pack):
-                index = j * self.n_pack + e
-                if index >= origin_shape[1]:
-                    continue
-                tmp = qweight[:, j]
-                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                tmp = tmp >> self.compress_bits - self.bits
-                if weight_dtype == torch.uint8:
-                    tmp &= mask  # remove sign bit
-                weight[:, index] = tmp.type(weight_dtype)
+            qweight = qweight.T.contiguous()
+        weight = self.unpack_tensor(qweight)
        if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
+            weight = weight.T.contiguous()
+        weight = weight[: self.out_features, : self.in_features]  # avoid oversize
        if "int" not in self.dtype:
            new_weight = torch.zeros(self.out_features, self.in_features).to(device)
            for k, v in self.int2float_mapping.items():
                new_weight += torch.where(weight == k, v, 0)
            weight = new_weight
        # unpack zero_point
        if hasattr(self, "qzeros"):
-            zp_dtype = self.compression_dtype  # to avoid overflow when weight-zp
-            zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
+            qzeros = self.qzeros.T.contiguous() if self.use_optimum_format else self.qzeros
            if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                qzeros = qzeros.t_().contiguous()
-            origin_shape = zp.shape
-            target_shape = qzeros.shape
-            for j in range(target_shape[1]):
-                for e in range(self.n_pack):
-                    index = j * self.n_pack + e
-                    if index >= origin_shape[1]:
-                        continue
-                    tmp = qzeros[:, j]
-                    tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                    tmp = tmp >> self.compress_bits - self.bits
-                    tmp &= mask
-                    zp[:, index] = tmp.type(zp_dtype)
+                qzeros = qzeros.T.contiguous()
+            zp = self.unpack_tensor(qzeros)
            if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
+                zp = zp.T.contiguous()
+            zp = zp[: scales.shape[0], : scales.shape[1]]  # avoid oversize
            if self.use_optimum_format:
                # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
                zp += 1
                zp = torch.where(zp > (2**self.bits - 1), 0, zp)
            # recover fp32 weight with int_weight, scale, and zero_point
            for idx in range(self.in_features):
-                fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]]
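+                # for int dtypes, weight and zp are unpacked as uint8, so the subtraction
+                # wraps modulo 256; casting to int8 reinterprets it as the signed difference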
+                fp32_weight[:, idx] = (torch.subtract(weight[:, idx], zp[:, self.g_idx[idx]]).to(torch.int8)) * scales[
+                    :, self.g_idx[idx]
+                ]
        else:
            # recover fp32 weight with int_weight, scale
            for idx in range(self.in_features):
                fp32_weight[:, idx] = weight[:, idx] * scales[:, self.g_idx[idx]]
        return fp32_weight

+    def pack_tensor(self, raw_tensor):
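+        # packs each group of n_pack values along dim 1 into one compression_dtype element,
+        # lowest field first: field e occupies bits [bits*e, bits*(e+1))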
+        target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
+        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            start = self.n_pack * j
+            end = self.n_pack * (j + 1)
+            tmp = raw_tensor[:, start:end].type(self.compression_dtype)
+            tmp &= mask
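+            # shift column e into its bit slot, then OR it into the packed word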
+            for e in range(tmp.shape[1]):
+                tmp[:, e] = tmp[:, e] << (self.bits * e)
+                packed_tensor[:, j] |= tmp[:, e]
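+        # wait for any asynchronous device work to finish before the packed tensor is read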
+        accelerator.synchronize()
+        return packed_tensor
+
+    def unpack_tensor(self, packed_tensor):
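+        # inverse of pack_tensor: recovers the n_pack bit fields of every packed element
+        # with a shift-left/shift-right pair (plus a mask for unsigned targets)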
+        target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
+        target_len = packed_tensor.shape[1] * self.n_pack
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            for e in range(self.n_pack):
+                index = j * self.n_pack + e
+                tmp = packed_tensor[:, j]
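+                # the left shift drops the higher fields; the arithmetic right shift then
+                # moves field e to the low bits, sign-extending it on the way down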
+                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
+                tmp = tmp >> (self.compress_bits - self.bits)
+                if target_dtype == torch.uint8:
+                    tmp &= mask  # remove sign bit
+                unpacked_tensor[:, index].copy_(tmp.type(target_dtype))
+        accelerator.synchronize()
+        return unpacked_tensor
+
    def forward(self, input):
        if not hasattr(self, "weight"):
            weight = self.recover()
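For reference, the shift-and-mask scheme that `pack_tensor` and `unpack_tensor` implement can be exercised in isolation. A standalone sketch under assumed settings (`bits=4`, `compression_dtype=torch.int32`, hence `n_pack = 32 // 4 = 8`; every name below is illustrative, not part of the module):

    import torch

    bits, compress_bits = 4, 32
    n_pack = compress_bits // bits
    raw = torch.randint(0, 2**bits, (2, n_pack), dtype=torch.int32)  # values in [0, 15]

    # pack: shift each 4-bit field into its slot and OR it into one int32 word
    mask = torch.tensor(2**bits - 1, dtype=torch.int32)
    packed = torch.zeros(2, 1, dtype=torch.int32)
    tmp = raw & mask
    for e in range(n_pack):
        packed[:, 0] |= tmp[:, e] << (bits * e)

    # unpack: the shift pair isolates one field; the mask drops the sign extension
    unpacked = torch.zeros_like(raw)
    for e in range(n_pack):
        field = packed[:, 0] << (compress_bits - bits * (e + 1))
        field = field >> (compress_bits - bits)
        unpacked[:, e] = field & mask

    assert torch.equal(unpacked, raw)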