
Commit ee600ba

add repack_awq_to_optimum_format function (#1998)
Signed-off-by: changwangss <chang1.wang@intel.com>
1 parent 4ee6861 commit ee600ba

5 files changed: 313 additions, 25 deletions


neural_compressor/torch/algorithms/weight_only/utility.py

Lines changed: 219 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Weight-Only utility."""
+import numpy as np
 import torch
 
 from neural_compressor.torch.utils import accelerator, device_synchronize, logger
@@ -1228,3 +1229,221 @@ def convert_dtype_str2torch(str_dtype):
         return torch.bfloat16
     else:
         assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype)
+
+
+# ref reverse reorder from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L491
+def awq_reverse_reorder_int_tensor(int_tensor, bits: int):
+    """Awq tensor convert tool.
+
+    Reverse_reorder_int_tensor
+    """
+    assert bits == 4
+
+    int_tensor = int_tensor.T.contiguous()
+    compress_ratio = 32 // bits
+    assert int_tensor.shape[-1] % compress_ratio == 0
+
+    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
+    order_tensor = torch.tensor(order_map, dtype=torch.int32, device=int_tensor.device).reshape(1, -1)
+    order_tensor = order_tensor.repeat(int_tensor.shape[1] // compress_ratio, 1)
+    order_tensor = order_tensor + torch.arange(
+        0,
+        int_tensor.shape[1],
+        compress_ratio,
+        dtype=torch.int32,
+        device=int_tensor.device,
+    ).reshape(-1, 1)
+    order_tensor = order_tensor.reshape(-1)
+
+    reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor]
+    reverse_order_tensor = reverse_order_tensor[order_tensor]
+    int_tensor = int_tensor[:, reverse_order_tensor]
+    return int_tensor
+
+
+# ref weight unpack from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516
+def unpack_awq(
+    awq_qweight: torch.Tensor,
+    awq_qzeros: torch.Tensor,
+    awq_scales: torch.Tensor,
+    bits: int,
+    group_size: int,
+):
+    """Unpack awq format to actual values.
+
+    Args:
+        awq_qweight (`torch.LongTensor`):
+            Expected shape: (in_features, out_features // (32 // bits))
+        awq_qzeros (`torch.LongTensor`):
+            Expected shape: (in_features // group_size, out_features // (32 // bits))
+        awq_scales (`torch.LongTensor`):
+            Expected shape: (in_features // group_size, out_features)
+
+    Returns:
+        fp16_weight (`torch.LongTensor`):
+            With shape (in_features, out_features).
+        zeros (`torch.LongTensor`):
+            With shape (in_features // group_size, out_features).
+    """
+    assert bits == 4
+
+    qzeros = awq_qzeros
+    qweight = awq_qweight
+    qweight = qweight.T.contiguous()
+
+    infeatures = awq_qweight.shape[0]
+
+    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qzeros.device).unsqueeze(0)
+    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2), wf.unsqueeze(0)).to(
+        torch.int16 if bits == 8 else torch.int8
+    )
+
+    # zeros = zeros + 1
+
+    torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)
+
+    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
+
+    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1), wf.unsqueeze(-1)).to(
+        torch.int16 if bits == 8 else torch.int8
+    )
+    torch.bitwise_and(weight, (2**bits) - 1, out=weight)
+    weight = weight.reshape(-1, group_size, weight.shape[2])
+
+    weight = weight.view(-1, weight.shape[-1])
+    zeros = zeros.view(-1, zeros.shape[-1])
+
+    zeros = zeros.T.contiguous()
+    zeros = awq_reverse_reorder_int_tensor(zeros, bits)
+    weight = awq_reverse_reorder_int_tensor(weight, bits)
+
+    # Dequantize weights.
+    scales = awq_scales
+    zeros = zeros.contiguous()
+    scale_zeros = zeros * scales
+
+    g_idx = torch.tensor([i // group_size for i in range(infeatures)], dtype=torch.int32)
+    scale_mat = scales[g_idx]
+    scale_zeros_mat = scale_zeros[g_idx].half()
+
+    qdq_weight_T = weight * scale_mat - scale_zeros_mat.half()
+
+    fp16_weight = qdq_weight_T.T
+
+    return fp16_weight, zeros
+
+
+# ref weight unpack from AutoAWQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516
+def pack_from_tensors(
+    unpacked_qweight: torch.Tensor,
+    unpacked_qzeros: torch.Tensor,
+    awq_scales: torch.Tensor,
+    bits: int,
+    group_size: int,
+):
+    """Pack the tensor to optimum format.
+
+    Args:
+        unpacked_qweight (`torch.LongTensor`):
+            Expected shape: (in_features, out_features)
+        unpacked_qzeros (`torch.LongTensor`):
+            Expected shape: (in_features // group_size, out_features)
+        awq_scales (`torch.LongTensor`):
+            Expected shape: (in_features // group_size, out_features)
+
+    Returns:
+        qweight (`torch.LongTensor`):
+            With shape (in_features // (32 // bits), out_features)
+        qzeros (`torch.LongTensor`):
+            With shape (in_features // group_size, out_features // (32 // bits))
+    """
+    assert bits == 4
+    W = unpacked_qweight.clone().cpu()
+
+    # TODO: This should be checked somehow.
+    # if isinstance(linear, nn.Conv2d):
+    #     W = W.flatten(1)
+    # if isinstance(linear, transformers.pytorch_utils.Conv1D):
+    #     W = W.t()
+
+    awq_scales = awq_scales.t().contiguous()
+    unpacked_qzeros = unpacked_qzeros.contiguous()
+    unpacked_qzeros = unpacked_qzeros.cpu()
+
+    awq_scales = awq_scales.cpu()
+    scale_zeros = unpacked_qzeros.t() * awq_scales
+    scales = awq_scales.clone()
+
+    infeatures = unpacked_qweight.shape[1]
+
+    intweight = []
+    for idx in range(infeatures):
+        g_idx = idx // group_size
+        intweight.append(torch.round((W[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to(torch.int)[:, None])
+    intweight = torch.cat(intweight, dim=1)
+    intweight = intweight.t().contiguous()
+    intweight = intweight.numpy().astype(np.uint32)
+
+    i = 0
+    row = 0
+    qweight = np.zeros((intweight.shape[0] // 32 * bits, intweight.shape[1]), dtype=np.uint32)
+    while row < qweight.shape[0]:
+        for j in range(i, i + (32 // bits)):
+            qweight[row] |= intweight[j] << (bits * (j - i))
+        i += 32 // bits
+        row += 1
+
+    qweight = qweight.astype(np.int32)
+    qweight = torch.from_numpy(qweight)
+
+    unpacked_qzeros = unpacked_qzeros - 1
+    torch.bitwise_and(unpacked_qzeros, (2**bits) - 1, out=unpacked_qzeros)
+
+    unpacked_qzeros = unpacked_qzeros.numpy().astype(np.uint32)
+    qzeros = np.zeros(
+        (unpacked_qzeros.shape[0], unpacked_qzeros.shape[1] // 32 * bits),
+        dtype=np.uint32,
+    )
+    i = 0
+    col = 0
+    while col < qzeros.shape[1]:
+        for j in range(i, i + (32 // bits)):
+            qzeros[:, col] |= unpacked_qzeros[:, j] << (bits * (j - i))
+        i += 32 // bits
+        col += 1
+
+    qzeros = qzeros.astype(np.int32)
+    qzeros = torch.from_numpy(qzeros)
+
+    return qweight, qzeros
+
+
+def repack_awq_to_optimum_format(
+    awq_qweight: torch.Tensor,
+    awq_qzeros: torch.Tensor,
+    awq_scales: torch.Tensor,
+    bits: int,
+    group_size: int,
+):
+    """The function to repack_awq_to_optimum_format.
+
+    Args:
+        awq_qweight (`torch.LongTensor`):
+            Expected shape: (in_features, out_features // (32 // bits))
+        awq_qzeros (`torch.LongTensor`):
+            Expected shape: (in_features // group_size, out_features // (32 // bits))
+        awq_scales (`torch.LongTensor`):
+            Expected shape: (in_features // group_size, out_features)
+
+    Returns:
+        qweight (`torch.LongTensor`):
+            With shape (in_features // (32 // bits), out_features)
+        qzeros (`torch.LongTensor`):
+            With shape (in_features // group_size, out_features // (32 // bits))
+        scales (`torch.LongTensor`):
+            Expected shape: (in_features // group_size, out_features)
+    """
+    unpack_qweight, unpack_qzeros = unpack_awq(awq_qweight, awq_qzeros, awq_scales, bits, group_size)
+    qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales, bits, group_size)
+    return qweight, qzeros, awq_scales
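
The new helper can be smoke-tested in isolation. The snippet below is not part of the commit: it feeds randomly packed 4-bit tensors with illustrative sizes (in_features=128, out_features=64, group_size=32) and checks the repacked shapes against the docstrings.

# Not part of the commit: minimal smoke test for repack_awq_to_optimum_format with made-up sizes.
import torch

from neural_compressor.torch.algorithms.weight_only.utility import repack_awq_to_optimum_format

bits, group_size = 4, 32             # only bits == 4 is supported by the helper
in_features, out_features = 128, 64  # illustrative; both are multiples of group_size and of 32 // bits
pack_ratio = 32 // bits              # 8 int4 values packed into each int32

awq_qweight = torch.randint(0, 2**31 - 1, (in_features, out_features // pack_ratio), dtype=torch.int32)
awq_qzeros = torch.randint(0, 2**31 - 1, (in_features // group_size, out_features // pack_ratio), dtype=torch.int32)
awq_scales = (torch.rand(in_features // group_size, out_features) + 0.5).half()  # offset avoids near-zero divisors

qweight, qzeros, scales = repack_awq_to_optimum_format(awq_qweight, awq_qzeros, awq_scales, bits, group_size)
assert qweight.shape == (in_features // pack_ratio, out_features)               # (16, 64), optimum/GPTQ layout
assert qzeros.shape == (in_features // group_size, out_features // pack_ratio)  # (4, 8)
assert scales.shape == (in_features // group_size, out_features)                # scales pass through unchanged
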

neural_compressor/transformers/models/modeling_auto.py

Lines changed: 41 additions & 25 deletions
@@ -47,7 +47,13 @@
 from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
 from neural_compressor.torch.utils import set_module
 
-from ..quantization.utils import convert_dtype_torch2str, convert_to_quantized_model, replace_linear, save_low_bit
+from ..quantization.utils import (
+    convert_dtype_torch2str,
+    convert_to_quantized_model,
+    repack_awq_and_load_state_dict,
+    replace_linear,
+    save_low_bit,
+)
 from ..utils import AutoRoundConfig, AwqConfig, GPTQConfig, RtnConfig, TeqConfig
 
@@ -179,6 +185,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             ) and model.config.model_type == "chatglm":
                 model = model.float()
             model = convert_to_quantized_model(model, quantization_config, device=device_map)
+            if isinstance(quantization_config, AwqConfig):
+                quantization_config.backend = "inc"
             quantization_config.remove_redundant_parameters()
             model.config.quantization_config = quantization_config
         else:
@@ -295,6 +303,7 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 quantization_config = GPTQConfig.from_dict(quantization_config)
             elif quantization_config["quant_method"] == "autoround":
                 quantization_config = AutoRoundConfig.from_dict(quantization_config)
+
         assert quantization_config is not None, "Detect this model is not a low-bit model."
 
         if commit_hash is None:
@@ -613,41 +622,48 @@ def load_low_bit(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
         with ContextManagers(init_contexts):
             model = model_class(config, *model_args, **kwargs)
-
+        if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
+            if quantization_config.modules_to_not_convert is None:
+                quantization_config.modules_to_not_convert = ["lm_head", "transformer.output_layer", "embed_out"]
+            else:
+                quantization_config.modules_to_not_convert += ["lm_head", "transformer.output_layer", "embed_out"]
         model = build_woq_model(model, quantization_config)
 
         if is_sharded:
            loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
         else:
-            # Time to load the checkpoint
            state_dict = load_state_dict(resolved_archive_file)
            loaded_state_dict_keys = list(state_dict.keys())
-
         # restore default dtype
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
 
-        (
-            model,
-            missing_keys,
-            unexpected_keys,
-            mismatched_keys,
-            offload_index,
-            error_msgs,
-        ) = model_class._load_pretrained_model(
-            model,
-            None,
-            loaded_state_dict_keys,  # XXX: rename?
-            resolved_archive_file,
-            pretrained_model_name_or_path,
-            sharded_metadata=sharded_metadata,
-            _fast_init=_fast_init,
-            low_cpu_mem_usage=True,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
-            dtype=torch_dtype,
-            keep_in_fp32_modules=[],
-        )
+        if quantization_config.quant_method.value == "awq" and quantization_config.backend != "inc":
+            model = repack_awq_and_load_state_dict(
+                model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
+            )
+        else:
+            (
+                model,
+                missing_keys,
+                unexpected_keys,
+                mismatched_keys,
+                offload_index,
+                error_msgs,
+            ) = model_class._load_pretrained_model(
+                model,
+                None,
+                loaded_state_dict_keys,  # XXX: rename?
+                resolved_archive_file,
+                pretrained_model_name_or_path,
+                sharded_metadata=sharded_metadata,
+                _fast_init=_fast_init,
+                low_cpu_mem_usage=True,
+                offload_folder=offload_folder,
+                offload_state_dict=offload_state_dict,
+                dtype=torch_dtype,
+                keep_in_fp32_modules=[],
+            )
 
         # make sure token embedding weights are still tied if needed
         model.tie_weights()
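
For context, a hypothetical end-to-end use of this path, assuming the INC transformers entry point routes already-quantized AWQ checkpoints through load_low_bit (the model id is purely illustrative, not from the commit): a checkpoint produced by AutoAWQ or another external tool carries no backend == "inc" marker, so lm_head-style modules stay unquantized and the weights are repacked into the optimum layout before assignment, while a model saved by INC itself skips the repack.

# Hypothetical usage sketch, not part of the commit; model id is illustrative.
from neural_compressor.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-Chat-AWQ")  # assumed external AWQ checkpoint
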

neural_compressor/transformers/quantization/utils.py

Lines changed: 38 additions & 0 deletions
@@ -23,6 +23,7 @@
 
 from neural_compressor.common.utils import LazyImport, logger
 from neural_compressor.torch.algorithms.weight_only.modules import INCWeightOnlyLinear
+from neural_compressor.torch.algorithms.weight_only.utility import repack_awq_to_optimum_format
 from neural_compressor.torch.quantization import (
     AutoRoundConfig,
     AWQConfig,
@@ -654,3 +655,40 @@ def save_low_bit(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
         token=kwargs.get("token"),
     )
     self.quantization_config.save_pretrained(save_directory, **kwargs)
+
+
+def repack_awq_and_load_state_dict(
+    model, resolved_archive_file, loaded_state_dict_keys, quantization_config, is_sharded
+):
+    from transformers.modeling_utils import load_state_dict
+
+    bits = quantization_config.bits
+    group_size = quantization_config.group_size
+
+    state_dict = {}
+    if isinstance(resolved_archive_file, str):
+        resolved_archive_file = [resolved_archive_file]
+    assert isinstance(resolved_archive_file, list), "Please check if the loading weight is shared."
+    for shard_file in resolved_archive_file:
+        assert shard_file.endswith("safetensors"), "Please check the loading weight saved format."
+        state_dict.update(load_state_dict(shard_file))
+    assert len(state_dict.keys()) > 0, "Please check the state_dict loading."
+    for name, module in model.named_modules():
+        if isinstance(module, INCWeightOnlyLinear):
+            assert name + ".qweight" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.qweight'}"
+            assert name + ".qzeros" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.qzeros'}"
+            assert name + ".scales" in loaded_state_dict_keys, f"Please check the state_dict key {name + '.scales'}"
+            if name + ".scales" in loaded_state_dict_keys:
+                awq_qweight = state_dict[name + ".qweight"]
+                awq_qzeros = state_dict[name + ".qzeros"]
+                awq_scales = state_dict[name + ".scales"]
+                qweight, qzeros, awq_scales = repack_awq_to_optimum_format(
+                    awq_qweight, awq_qzeros, awq_scales, bits, group_size
+                )
+                state_dict[name + ".qweight"] = qweight
+                state_dict[name + ".qzeros"] = qzeros
+                state_dict[name + ".scales"] = awq_scales
+
+    model.load_state_dict(state_dict, strict=False, assign=True)
+
+    return model
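
A condensed view of what this helper does per module, using an in-memory state_dict and made-up sizes instead of safetensors shards (layer name and shapes are illustrative, not from the commit): each INCWeightOnlyLinear entry has its qweight/qzeros/scales rewritten in place before the single load_state_dict(..., strict=False, assign=True) call.

# Not part of the commit: the per-module rewrite step, shown on an in-memory state_dict.
import torch

from neural_compressor.torch.algorithms.weight_only.utility import repack_awq_to_optimum_format

bits, group_size = 4, 32
name = "model.layers.0.self_attn.q_proj"  # hypothetical module name
state_dict = {
    name + ".qweight": torch.randint(0, 2**31 - 1, (128, 8), dtype=torch.int32),  # (in, out // 8)
    name + ".qzeros": torch.randint(0, 2**31 - 1, (4, 8), dtype=torch.int32),     # (in // group_size, out // 8)
    name + ".scales": (torch.rand(4, 64) + 0.5).half(),                           # (in // group_size, out)
}

# Same rewrite repack_awq_and_load_state_dict applies to every INCWeightOnlyLinear entry
# before handing the dict to model.load_state_dict(state_dict, strict=False, assign=True).
qweight, qzeros, scales = repack_awq_to_optimum_format(
    state_dict[name + ".qweight"], state_dict[name + ".qzeros"], state_dict[name + ".scales"], bits, group_size
)
state_dict[name + ".qweight"], state_dict[name + ".qzeros"], state_dict[name + ".scales"] = qweight, qzeros, scales
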

neural_compressor/transformers/utils/quantization_config.py

Lines changed: 2 additions & 0 deletions
@@ -409,6 +409,7 @@ def __init__(
         zero_point: bool = True,
         absorb_layer_dict: dict = {},
         quant_lm_head: bool = False,
+        backend: str = None,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.AWQ
@@ -427,6 +428,7 @@
         self.seq_len = seq_len
         self.absorb_layer_dict = absorb_layer_dict
         self.quant_lm_head = quant_lm_head
+        self.backend = backend
         self.modules_to_not_convert = kwargs.get(
             "modules_to_not_convert", ["lm_head", "transformer.output_layer", "embed_out"]
         )
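
How the new field is meant to be read, in a sketch that is not part of the commit (it assumes bits and group_size are accepted keyword arguments of AwqConfig, as the attributes referenced elsewhere in this commit suggest): backend defaults to None, which load_low_bit interprets as an externally packed AWQ checkpoint that needs repacking, while "inc" marks a model quantized and saved by INC.

# Not part of the commit: illustrative use of the new backend field.
from neural_compressor.transformers.utils.quantization_config import AwqConfig

cfg = AwqConfig(bits=4, group_size=128)  # assumed keyword arguments
print(cfg.backend)   # None -> checkpoint treated as AutoAWQ-packed and repacked at load time
cfg.backend = "inc"  # what from_pretrained sets after convert_to_quantized_model, so repacking is skipped
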

0 commit comments
