
Commit 1ff170b

Update int4 weight with serialized format
1 parent 091515a commit 1ff170b

File tree: 2 files changed (+12, −7 lines)


generate.py

Lines changed: 7 additions & 1 deletion
@@ -77,7 +77,6 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, ...
         callback(new_tokens[-1])
         new_probs.append(next_prob.clone())
         cur_token = next_token.view(1, -1)
-
     return new_tokens, new_probs

@@ -241,6 +240,13 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         apply_tp(model)

     model = model.to(device=device, dtype=precision)
+    if "int4" in str(checkpoint_path):
+        from quantize import WeightOnlyInt4Linear
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, WeightOnlyInt4Linear):
+                weight = mod.weight.data
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
+                mod.weight = weight_int4pack
     return model.eval()

 def _get_model_size(model):
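
Note: the checkpoint now stores the plain nibble-packed uint8 weights, and _load_model rebuilds the kernel-specific tiled layout after loading, which keeps the on-disk format independent of the kernel's tiling. A minimal standalone sketch of that hook (the helper name is illustrative; it assumes the model's WeightOnlyInt4Linear modules hold the serialized uint8 weights):

    import torch
    from quantize import WeightOnlyInt4Linear

    def repack_int4_weights(model: torch.nn.Module) -> torch.nn.Module:
        # Rebuild the kernel-specific packed layout from the serialized
        # uint8 weights, mirroring the load-time hook in _load_model above.
        for _, mod in model.named_modules():
            if isinstance(mod, WeightOnlyInt4Linear):
                mod.weight = torch.ops.aten._convert_weight_to_int4pack(
                    mod.weight.data, mod.inner_k_tiles
                )
        return model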

quantize.py

Lines changed: 5 additions & 6 deletions
@@ -124,8 +124,8 @@ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
         .to(torch.int32)
         .reshape_as(w)
     )
-
-    return w_int32
+    w_uint8 = (w_int32[::,::2] << 4 | w_int32[::,1::2]).to(torch.uint8)
+    return w_uint8


 def group_quantize_tensor(w, n_bit=4, groupsize=128):
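
The new packing line stores two 4-bit values per byte: even columns go to the high nibble, odd columns to the low nibble. A quick round-trip sketch of that arithmetic (standalone; the random tensor stands in for the quantized weights):

    import torch

    # Values in [0, 15], as produced by 4-bit group quantization.
    w_int32 = torch.randint(0, 16, (4, 8), dtype=torch.int32)

    # Pack: even columns -> high nibble, odd columns -> low nibble.
    w_uint8 = (w_int32[:, ::2] << 4 | w_int32[:, 1::2]).to(torch.uint8)

    # Unpack: shift/mask recovers the original tensor exactly.
    unpacked = torch.empty_like(w_int32)
    unpacked[:, ::2] = (w_uint8 >> 4).to(torch.int32)
    unpacked[:, 1::2] = (w_uint8 & 0xF).to(torch.int32)
    assert torch.equal(unpacked, w_int32)
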
@@ -357,10 +357,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 ##### weight only int4 per channel groupwise quantized code ######

 def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
-    weight_int32, scales_and_zeros = group_quantize_tensor(
+    weight_int4pack, scales_and_zeros = group_quantize_tensor(
         weight_bf16, n_bit=4, groupsize=groupsize
     )
-    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
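
With the _convert_weight_to_int4pack call dropped from the save path, the quantized state dict carries the serialized uint8 tensors directly. A sketch of what this produces, assuming the repo's quantize.py at this commit and an example 4096x4096 layer:

    import torch
    from quantize import prepare_int4_weight_and_scales_and_zeros

    weight_bf16 = torch.randn(4096, 4096, dtype=torch.bfloat16)
    weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
        weight_bf16, groupsize=128, inner_k_tiles=8
    )
    # Serialized form: two int4 weights per byte, independent of inner_k_tiles.
    assert weight_int4pack.dtype == torch.uint8
    assert weight_int4pack.shape == (4096, 4096 // 2)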

@@ -404,7 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):

     @torch.no_grad()
     def create_quantized_state_dict(self, use_cuda = True):
-        if use_cuda:
+        if use_cuda and torch.cuda.is_available():
             device="cuda"
         else:
             device="cpu"

@@ -507,7 +506,7 @@ def __init__(
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
         self.register_buffer(
             "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            torch.empty((out_features, in_features // 2), dtype=torch.uint8)
         )
         self.register_buffer(
             "scales_and_zeros",
