@@ -124,8 +124,8 @@ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128)
         .to(torch.int32)
         .reshape_as(w)
     )
-
-    return w_int32
+    w_uint8 = (w_int32[::, ::2] << 4 | w_int32[::, 1::2]).to(torch.uint8)
+    return w_uint8
 
 
 def group_quantize_tensor(w, n_bit=4, groupsize=128):
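A minimal sketch (not part of the patch) of the nibble packing the hunk above introduces: even columns of the quantized int32 tensor go to the high nibble and odd columns to the low nibble of each byte. The `unpack_uint8_to_int4` helper is hypothetical, shown only to verify the round trip.

```python
import torch

def pack_int4_to_uint8(w_int32: torch.Tensor) -> torch.Tensor:
    # Two 4-bit values per byte: even column -> high nibble, odd column -> low nibble.
    return (w_int32[::, ::2] << 4 | w_int32[::, 1::2]).to(torch.uint8)

def unpack_uint8_to_int4(w_uint8: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse: split each byte back into its two nibbles.
    high = (w_uint8 >> 4).to(torch.int32)
    low = (w_uint8 & 0xF).to(torch.int32)
    return torch.stack([high, low], dim=-1).reshape(w_uint8.shape[0], -1)

w = torch.randint(0, 16, (4, 8), dtype=torch.int32)  # 4-bit values in [0, 15]
assert torch.equal(unpack_uint8_to_int4(pack_int4_to_uint8(w)), w)
```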
@@ -357,10 +357,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 ##### weight only int4 per channel groupwise quantized code ######
 
 def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
-    weight_int32, scales_and_zeros = group_quantize_tensor(
+    weight_int4pack, scales_and_zeros = group_quantize_tensor(
         weight_bf16, n_bit=4, groupsize=groupsize
     )
-    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
 
 
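With the hunk above, `group_quantize_tensor` already returns the nibble-packed uint8 tensor, so the `torch.ops.aten._convert_weight_to_int4pack` repack step is dropped (leaving `inner_k_tiles` unused in this function). A hypothetical smoke test, assuming the patched file is importable as `quantize`:

```python
import torch
from quantize import prepare_int4_weight_and_scales_and_zeros  # assumed module name

weight_bf16 = torch.randn(256, 512, dtype=torch.bfloat16)
weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
    weight_bf16, groupsize=128, inner_k_tiles=8
)
assert weight_int4pack.dtype == torch.uint8
assert weight_int4pack.shape == (256, 512 // 2)  # two 4-bit weights per byte
```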
@@ -404,7 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
 
     @torch.no_grad()
     def create_quantized_state_dict(self, use_cuda=True):
-        if use_cuda:
+        if use_cuda and torch.cuda.is_available():
             device = "cuda"
         else:
             device = "cpu"
@@ -507,7 +506,7 @@ def __init__(
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            torch.empty((out_features, in_features // 2), dtype=torch.uint8)
         )
         self.register_buffer(
             "scales_and_zeros",