 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Weight-Only utility."""
+import numpy as np
 import torch

 from neural_compressor.torch.utils import accelerator, device_synchronize, logger
@@ -1228,3 +1229,221 @@ def convert_dtype_str2torch(str_dtype):
         return torch.bfloat16
     else:
         assert False, "Unsupported str dtype {} to torch dtype".format(str_dtype)
+
+
+# Ref: reverse reorder, adapted from AutoGPTQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L491
+def awq_reverse_reorder_int_tensor(int_tensor, bits: int):
+    """Undo the interleaved column ordering used by AWQ packing.
+
+    Note: the input is transposed internally and returned in that transposed
+    orientation; callers rely on this.
+    """
+    assert bits == 4
+
+    int_tensor = int_tensor.T.contiguous()
+    compress_ratio = 32 // bits
+    assert int_tensor.shape[-1] % compress_ratio == 0
+
+    # AWQ packs eight 4-bit values per int32 in this interleaved order.
+    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
+    order_tensor = torch.tensor(order_map, dtype=torch.int32, device=int_tensor.device).reshape(1, -1)
+    order_tensor = order_tensor.repeat(int_tensor.shape[1] // compress_ratio, 1)
+    order_tensor = order_tensor + torch.arange(
+        0,
+        int_tensor.shape[1],
+        compress_ratio,
+        dtype=torch.int32,
+        device=int_tensor.device,
+    ).reshape(-1, 1)
+    order_tensor = order_tensor.reshape(-1)
+
+    # Indexing with order_tensor twice composes the permutation with itself,
+    # which for this particular order map equals its inverse.
+    reverse_order_tensor = torch.arange(order_tensor.shape[0])[order_tensor]
+    reverse_order_tensor = reverse_order_tensor[order_tensor]
+    int_tensor = int_tensor[:, reverse_order_tensor]
+    return int_tensor
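
Why the double indexing above works: the AWQ order map composed with itself
equals its own inverse, so indexing with `order_tensor` twice yields the
reverse ordering. A minimal check (illustrative only, not part of the patch):

    import torch

    order = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7])
    twice = order[order]                               # permutation composed with itself
    assert torch.equal(twice[order], torch.arange(8))  # twice is order's inverse
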
+
+
+# Ref: weight unpack, adapted from AutoGPTQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516
+def unpack_awq(
+    awq_qweight: torch.Tensor,
+    awq_qzeros: torch.Tensor,
+    awq_scales: torch.Tensor,
+    bits: int,
+    group_size: int,
+):
+    """Unpack the AWQ packed format into dequantized fp16 weights and int zero points.
+
+    Args:
+        awq_qweight (`torch.IntTensor`):
+            Expected shape: (in_features, out_features // (32 // bits))
+        awq_qzeros (`torch.IntTensor`):
+            Expected shape: (in_features // group_size, out_features // (32 // bits))
+        awq_scales (`torch.HalfTensor`):
+            Expected shape: (in_features // group_size, out_features)
+        bits (`int`): number of bits per value; only 4 is supported.
+        group_size (`int`): number of input channels sharing one scale.
+
+    Returns:
+        fp16_weight (`torch.HalfTensor`):
+            With shape (out_features, in_features), i.e. the nn.Linear weight layout.
+        zeros (`torch.CharTensor`):
+            With shape (in_features // group_size, out_features).
+    """
+    assert bits == 4
+
+    qzeros = awq_qzeros
+    qweight = awq_qweight
+    qweight = qweight.T.contiguous()
+
+    infeatures = awq_qweight.shape[0]
+
+    # Bit offsets of the eight 4-bit fields inside each int32: 0, 4, ..., 28.
+    wf = torch.tensor(list(range(0, 32, bits)), dtype=torch.int32, device=qzeros.device).unsqueeze(0)
+    zeros = torch.bitwise_right_shift(torch.unsqueeze(qzeros, 2), wf.unsqueeze(0)).to(
+        torch.int16 if bits == 8 else torch.int8
+    )
+
+    # AWQ zero points are stored without the GPTQ -1 offset, so no correction here:
+    # zeros = zeros + 1
+
+    torch.bitwise_and(zeros, (2**bits) - 1, out=zeros)
+
+    zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
+
+    weight = torch.bitwise_right_shift(torch.unsqueeze(qweight, 1), wf.unsqueeze(-1)).to(
+        torch.int16 if bits == 8 else torch.int8
+    )
+    torch.bitwise_and(weight, (2**bits) - 1, out=weight)
+    weight = weight.reshape(-1, group_size, weight.shape[2])
+
+    weight = weight.view(-1, weight.shape[-1])
+    zeros = zeros.view(-1, zeros.shape[-1])
+
+    zeros = zeros.T.contiguous()
+    zeros = awq_reverse_reorder_int_tensor(zeros, bits)
+    weight = awq_reverse_reorder_int_tensor(weight, bits)
+
+    # Dequantize weights: w = (q - z) * s, computed as q*s - z*s.
+    scales = awq_scales
+    zeros = zeros.contiguous()
+    scale_zeros = zeros * scales
+
+    g_idx = torch.tensor([i // group_size for i in range(infeatures)], dtype=torch.int32)
+    scale_mat = scales[g_idx]
+    scale_zeros_mat = scale_zeros[g_idx].half()
+
+    qdq_weight_T = weight * scale_mat - scale_zeros_mat.half()
+
+    fp16_weight = qdq_weight_T.T
+
+    return fp16_weight, zeros
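
The shift-and-mask pattern above extracts all eight 4-bit fields of each int32
in one vectorized step. A self-contained sketch of the same idiom (the example
value is arbitrary, chosen so nibble k holds the value k):

    import torch

    packed = torch.tensor([[0x76543210]], dtype=torch.int32)  # nibbles 0..7
    wf = torch.arange(0, 32, 4, dtype=torch.int32)            # bit offsets
    nibbles = (packed.unsqueeze(-1) >> wf) & 0xF
    assert nibbles.flatten().tolist() == [0, 1, 2, 3, 4, 5, 6, 7]
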
+
+
+# Ref: weight pack, adapted from AutoGPTQ https://github.com/AutoGPTQ/AutoGPTQ/blob/v0.7.1/auto_gptq/modeling/_utils.py#L516
+def pack_from_tensors(
+    unpacked_qweight: torch.Tensor,
+    unpacked_qzeros: torch.Tensor,
+    awq_scales: torch.Tensor,
+    bits: int,
+    group_size: int,
+):
+    """Pack unpacked weight and zero-point tensors into the optimum (GPTQ) format.
+
+    Args:
+        unpacked_qweight (`torch.HalfTensor`):
+            Expected shape: (out_features, in_features), as produced by unpack_awq
+        unpacked_qzeros (`torch.CharTensor`):
+            Expected shape: (in_features // group_size, out_features)
+        awq_scales (`torch.HalfTensor`):
+            Expected shape: (in_features // group_size, out_features)
+        bits (`int`): number of bits per value; only 4 is supported.
+        group_size (`int`): number of input channels sharing one scale.
+
+    Returns:
+        qweight (`torch.IntTensor`):
+            With shape (in_features // (32 // bits), out_features)
+        qzeros (`torch.IntTensor`):
+            With shape (in_features // group_size, out_features // (32 // bits))
+    """
+    assert bits == 4
+    W = unpacked_qweight.clone().cpu()
+
+    # TODO: This should be checked somehow.
+    # if isinstance(linear, nn.Conv2d):
+    #     W = W.flatten(1)
+    # if isinstance(linear, transformers.pytorch_utils.Conv1D):
+    #     W = W.t()
+
+    awq_scales = awq_scales.t().contiguous()
+    unpacked_qzeros = unpacked_qzeros.contiguous()
+    unpacked_qzeros = unpacked_qzeros.cpu()
+
+    awq_scales = awq_scales.cpu()
+    scale_zeros = unpacked_qzeros.t() * awq_scales
+    scales = awq_scales.clone()
+
+    infeatures = unpacked_qweight.shape[1]
+
+    # Re-quantize each input channel against its group's scale and zero point.
+    intweight = []
+    for idx in range(infeatures):
+        g_idx = idx // group_size
+        intweight.append(torch.round((W[:, idx] + scale_zeros[:, g_idx]) / scales[:, g_idx]).to(torch.int)[:, None])
+    intweight = torch.cat(intweight, dim=1)
+    intweight = intweight.t().contiguous()
+    intweight = intweight.numpy().astype(np.uint32)
+
+    # Pack eight 4-bit values into each int32, lowest bits first.
+    i = 0
+    row = 0
+    qweight = np.zeros((intweight.shape[0] // 32 * bits, intweight.shape[1]), dtype=np.uint32)
+    while row < qweight.shape[0]:
+        for j in range(i, i + (32 // bits)):
+            qweight[row] |= intweight[j] << (bits * (j - i))
+        i += 32 // bits
+        row += 1
+
+    qweight = qweight.astype(np.int32)
+    qweight = torch.from_numpy(qweight)
+
+    # The optimum/GPTQ format stores zero points offset by -1.
+    unpacked_qzeros = unpacked_qzeros - 1
+    torch.bitwise_and(unpacked_qzeros, (2**bits) - 1, out=unpacked_qzeros)
+
+    unpacked_qzeros = unpacked_qzeros.numpy().astype(np.uint32)
+    qzeros = np.zeros(
+        (unpacked_qzeros.shape[0], unpacked_qzeros.shape[1] // 32 * bits),
+        dtype=np.uint32,
+    )
+    i = 0
+    col = 0
+    while col < qzeros.shape[1]:
+        for j in range(i, i + (32 // bits)):
+            qzeros[:, col] |= unpacked_qzeros[:, j] << (bits * (j - i))
+        i += 32 // bits
+        col += 1
+
+    qzeros = qzeros.astype(np.int32)
+    qzeros = torch.from_numpy(qzeros)
+
+    return qweight, qzeros
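
The packing loops above are the inverse of the shift-and-mask unpacking: eight
4-bit values are OR-ed into a single 32-bit word, lowest bits first. In
miniature (plain numpy, values arbitrary and illustrative):

    import numpy as np

    vals = np.arange(8, dtype=np.uint32)        # eight 4-bit values: 0..7
    packed = np.uint32(0)
    for j, v in enumerate(vals):
        packed |= v << np.uint32(4 * j)         # value j lands at bits 4j..4j+3
    assert packed == np.uint32(0x76543210)
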
+
+
+def repack_awq_to_optimum_format(
+    awq_qweight: torch.Tensor,
+    awq_qzeros: torch.Tensor,
+    awq_scales: torch.Tensor,
+    bits: int,
+    group_size: int,
+):
+    """Repack AWQ-format tensors into the optimum (GPTQ) format.
+
+    Args:
+        awq_qweight (`torch.IntTensor`):
+            Expected shape: (in_features, out_features // (32 // bits))
+        awq_qzeros (`torch.IntTensor`):
+            Expected shape: (in_features // group_size, out_features // (32 // bits))
+        awq_scales (`torch.HalfTensor`):
+            Expected shape: (in_features // group_size, out_features)
+        bits (`int`): number of bits per value; only 4 is supported.
+        group_size (`int`): number of input channels sharing one scale.
+
+    Returns:
+        qweight (`torch.IntTensor`):
+            With shape (in_features // (32 // bits), out_features)
+        qzeros (`torch.IntTensor`):
+            With shape (in_features // group_size, out_features // (32 // bits))
+        scales (`torch.HalfTensor`):
+            With shape (in_features // group_size, out_features)
+    """
+    unpack_qweight, unpack_qzeros = unpack_awq(awq_qweight, awq_qzeros, awq_scales, bits, group_size)
+    qweight, qzeros = pack_from_tensors(unpack_qweight, unpack_qzeros, awq_scales, bits, group_size)
+    return qweight, qzeros, awq_scales
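
A shape-level usage sketch of the repack entry point (tensor contents are
random and purely illustrative, the dimensions are arbitrary choices, and only
output shapes are checked; assumes the functions above are in scope):

    import torch

    in_features, out_features, group_size, bits = 128, 64, 32, 4
    pack_ratio = 32 // bits

    awq_qweight = torch.randint(
        -(2**31), 2**31 - 1, (in_features, out_features // pack_ratio), dtype=torch.int32
    )
    awq_qzeros = torch.randint(
        -(2**31), 2**31 - 1, (in_features // group_size, out_features // pack_ratio), dtype=torch.int32
    )
    # Keep scales away from zero so the requantization division stays bounded.
    awq_scales = torch.rand(in_features // group_size, out_features, dtype=torch.float16) + 0.5

    qweight, qzeros, scales = repack_awq_to_optimum_format(
        awq_qweight, awq_qzeros, awq_scales, bits, group_size
    )
    assert qweight.shape == (in_features // pack_ratio, out_features)
    assert qzeros.shape == (in_features // group_size, out_features // pack_ratio)
    assert scales.shape == (in_features // group_size, out_features)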