Commit 8ce1b8b

Enable Int4WeightOnlyGPTQQuantizer on Intel GPU.
1 parent 66eb801
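For context, a minimal sketch of how the quantizer might be driven once this lands. The quantize(model, inputs) entry point and the _MultiInput calibration wrapper are assumptions based on torchao's unified Quantizer interface, not part of this commit:

# Hypothetical usage sketch, not part of this commit.
# Assumes quantizer.quantize(model, inputs) from torchao's unified Quantizer
# interface; model and inputs (_MultiInput calibration batches) are prepared elsewhere.
import torch
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao.quantization.quant_primitives import ZeroPointDomain

quantizer = Int4WeightOnlyGPTQQuantizer(
    groupsize=128,
    inner_k_tiles=8,
    zero_point_domain=ZeroPointDomain.INT,  # new knob: integer zero points for the XPU kernel
    device=torch.device("xpu"),
)
model = quantizer.quantize(model, inputs)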

2 files changed: 180 additions, 40 deletions


torchao/quantization/GPTQ.py (154 additions, 36 deletions)
@@ -27,6 +27,7 @@
 from .quant_primitives import (
     MappingType,
     dequantize_affine,
+    ZeroPointDomain,
 )
 from .unified import Quantizer
 from .utils import (
@@ -38,6 +39,7 @@
     groupwise_affine_quantize_tensor,
     groupwise_affine_quantize_tensor_from_qparams,
     pack_tinygemm_scales_and_zeros,
+    align_tinygemm_scales_and_zeros,
     per_token_dynamic_quant,
 )
 
@@ -75,18 +77,19 @@ def __init__(
         percdamp=0.01,
         groupsize=128,
     ):
+        self.device = self.get_device(model, inputs)
         self.id_to_name = {
             id(value): name for name, value in dict(model.named_parameters()).items()
         }
 
         # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]  # pyre-ignore[16]
+        one_input = [multi.values[0] for multi in inputs]  # pyre-ignore[16]
         # needed for GPTQ on the torchao llama model
         import torchao
 
         torchao._models.llama.model.use_index_put_for_kv_cache = True
         exported_model = torch._dynamo.export(
-            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
+            model, aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
         super().__init__(exported_model.graph_module)
 
@@ -100,6 +103,19 @@ def __init__(
         self.inputs = inputs
         self.gptq_done = False
         self.debug = False
+
+
+    def get_device(self, model, inputs: _MultiInput):
+        for name, param in model.named_parameters():
+            if isinstance(param, torch.Tensor):
+                return param.device
+
+        for multi in inputs:
+            if isinstance(multi.values[0], torch.Tensor):
+                return multi.values[0].device
+
+        return torch.device("cpu")
+
 
     def configure_quantization_mode(
         self,
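The new get_device helper infers the working device from the first model parameter, falls back to the first calibration tensor, and finally to CPU; with it, the model no longer has to be forced through .cpu() before export. A toy illustration of that precedence (standalone, not from the commit):

# Toy illustration of get_device's fallback order; not code from this commit.
import torch

m = torch.nn.Linear(4, 4)           # parameters exist, so their device wins
print(next(m.parameters()).device)  # cpu (or xpu/cuda if the model was moved there)

m_no_params = torch.nn.Module()     # no parameters: fall back to the first tensor
                                    # in the inputs, else to torch.device("cpu")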
@@ -163,16 +179,16 @@ def get_quantized_state_dict(self):
         return quantized_state_dict
 
     def call_function(self, target, args, kwargs, already_quantized=False):  # noqa: C901
-        def tensors_to_cuda(args):
+        def tensors_to_device(args):
             new_args = []
             for x in args:
-                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
+                new_args.append(x.to(self.device) if isinstance(x, torch.Tensor) else x)
             return new_args
 
         # flatten args and kwargs together
         flat_args, spec = tree_flatten((args, kwargs))
         # move all single tensors to cuda, will move _MultiInputs to cuda one at a time
-        flat_args = tensors_to_cuda(flat_args)
+        flat_args = tensors_to_device(flat_args)
 
         has_multi_input = _MultiInput in [type(x) for x in flat_args]
         if has_multi_input:
@@ -212,7 +228,7 @@ def tensors_to_cuda(args):
             total_batches = 0
 
             for inp in transposed_args:
-                inp = tensors_to_cuda(inp)
+                inp = tensors_to_device(inp)
                 cur_args, cur_kwargs = tree_unflatten(inp, spec)
 
                 if quantize_linear:  # calculate H instead of output (will run the linear eventually with updated weight)
@@ -283,7 +299,7 @@ def SQNR(x, y):
                 "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
             )  # matches
             print(
-                "SQNR for weight (can be low)", SQNR(W, DQ.cuda())
+                "SQNR for weight (can be low)", SQNR(W, DQ.to(self.device))
             )  # fine to not match
             print(
                 "SQNR for output with GPTQ (hopefully 35+)",
@@ -385,7 +401,12 @@ def faster_quant(self, H, W):
 
             W[:, i2:] -= Err1.to(Hinv.dtype).matmul(Hinv[i1:i2, i2:])
 
-        torch.cuda.synchronize()
+        if 'cuda' in self.device.type:
+            torch.cuda.synchronize()
+        elif 'xpu' in self.device.type:
+            torch.xpu.synchronize()
+        else:
+            pass
 
         if all_qparams == []:
             all_qparams.append(cur_qparams)
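Restated outside the diff, the dispatch above amounts to a small device-agnostic barrier; a sketch of the same pattern (CPU work is already synchronous, so the fall-through is a no-op):

# Sketch of the device-agnostic barrier used above; not part of the commit.
import torch

def synchronize(device: torch.device) -> None:
    # Block the host until pending kernels on the accelerator finish.
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "xpu":
        torch.xpu.synchronize()
    # CPU: nothing to wait for.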
@@ -561,6 +582,30 @@ def linear_forward_int4(
     return c
 
 
+def linear_forward_int4_zero_domain(
+    x: torch.Tensor,
+    weight_int4pack: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+    out_features: int,
+    groupsize: int,
+    precision: torch.dtype = torch.bfloat16,
+    scales_precision: torch.dtype = torch.bfloat16,
+):
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
+        x.contiguous().to(precision),
+        weight_int4pack,
+        groupsize,
+        scales.to(scales_precision),
+        zeros.to(torch.int8),
+    ).to(dtype=x.dtype)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
+
+
 class WeightOnlyInt4Linear(torch.nn.Module):
     __constants__ = ["in_features", "out_features"]
     in_features: int
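The new forward path consumes int8 zero points directly instead of tinygemm's packed float scales-and-zeros. For reference, the two conventions differ roughly as sketched below; the exact mid-point handling is my reading of torchao's quant_primitives, so treat it as an assumption rather than commit code:

# Sketch of the two zero-point conventions; my reading of torchao's
# quant_primitives, not code from this commit.
import torch

def dequant_int_domain(q, scale, zero_int):
    # ZeroPointDomain.INT: the zero point lives in the integer domain.
    return (q.float() - zero_int.float()) * scale

def dequant_float_domain(q, scale, zero_float, n_bit=4):
    # ZeroPointDomain.FLOAT (tinygemm): the zero point lives in the float
    # domain and is added after recentering around the int4 mid-point.
    mid_point = 2 ** (n_bit - 1)  # 8 for int4
    return (q.float() - mid_point) * scale + zero_float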
@@ -579,6 +624,7 @@ def __init__(
         inner_k_tiles: int = 8,
         precision: torch.dtype = torch.bfloat16,
         scales_precision: torch.dtype = torch.bfloat16,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
     ) -> None:
         super().__init__()
         self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
@@ -594,6 +640,7 @@ def __init__(
         self.inner_k_tiles = inner_k_tiles
         self.precision = precision
         self.scales_precision = scales_precision
+        self.zero_point_domain = zero_point_domain
 
         if dtype is not None:
             raise ValueError("Please specify 'precision' instead of 'dtype'")
@@ -614,6 +661,18 @@ def __init__(
                     device=device,
                 ),
             )
+        elif is_device(device.type, "xpu"):
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features,
+                        in_features // 8,
+                    ),
+                    dtype=torch.int32,
+                    device=device,
+                ),
+            )
         else:
             self.register_buffer(
                 "weight",
@@ -629,27 +688,59 @@ def __init__(
                 ),
             )
         self.dtype = dtype
-        self.register_buffer(
-            "scales_and_zeros",
-            torch.zeros(
-                (in_features // groupsize, out_features, 2),
-                dtype=self.scales_precision,
-                device=device,
-            ),
-        )
+        if self.zero_point_domain == ZeroPointDomain.INT:
+            self.register_buffer(
+                "scales",
+                torch.zeros(
+                    (in_features // groupsize, out_features),
+                    dtype=self.scales_precision,
+                    device=device,
+                ),
+            )
+
+            self.register_buffer(
+                "zeros",
+                torch.zeros(
+                    (in_features // groupsize, out_features),
+                    dtype=torch.int8,
+                    device=device,
+                ),
+            )
+        else:
+            self.register_buffer(
+                "scales_and_zeros",
+                torch.zeros(
+                    (in_features // groupsize, out_features, 2),
+                    dtype=self.scales_precision,
+                    device=device,
+                ),
+            )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
-        return linear_forward_int4(
-            input,
-            self.weight,
-            self.scales_and_zeros,
-            self.out_features,
-            self.groupsize,
-            self.precision,
-            self.scales_precision,
-        )
+
+        if self.zero_point_domain != ZeroPointDomain.INT:
+            return linear_forward_int4(
+                input,
+                self.weight,
+                self.scales_and_zeros,
+                self.out_features,
+                self.groupsize,
+                self.precision,
+                self.scales_precision,
+            )
+        else:
+            return linear_forward_int4_zero_domain(
+                input,
+                self.weight,
+                self.scales,
+                self.zeros,
+                self.out_features,
+                self.groupsize,
+                self.precision,
+                self.scales_precision,
+            )
 
 
 def _replace_linear_int4(
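A small sketch of the resulting buffer layout; the positional (in_features, out_features) arguments are an assumption, since the hunk above does not show the full __init__ signature, and an XPU build of PyTorch is required:

# Sketch only: positional (in_features, out_features) args are assumed.
import torch
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
from torchao.quantization.quant_primitives import ZeroPointDomain

lin = WeightOnlyInt4Linear(
    256, 512,
    device=torch.device("xpu"),
    groupsize=128,
    zero_point_domain=ZeroPointDomain.INT,
)
print(lin.weight.shape)  # (512, 32): out_features x in_features // 8, int32-packed
print(lin.scales.shape)  # (2, 512): in_features // groupsize x out_features
print(lin.zeros.dtype)   # torch.int8, fed to the XPU int4 mm kernel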
@@ -662,6 +753,7 @@ def _replace_linear_int4(
     scales_precision: torch.dtype = torch.bfloat16,
     linear_class: Type[torch.nn.Module] = WeightOnlyInt4Linear,
     copy_weights: bool = False,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
 ):
     for name, child in module.named_children():
         # TODO: support linear bias
@@ -683,6 +775,7 @@ def _replace_linear_int4(
                 inner_k_tiles=inner_k_tiles,
                 precision=precision,
                 scales_precision=scales_precision,
+                zero_point_domain = zero_point_domain,
             )
             # TODO: merge with 8da4w?
             # In distributed training, the model may be instantiated
@@ -702,11 +795,17 @@ def _replace_linear_int4(
                 scales_precision,
                 linear_class,
                 copy_weights,
+                zero_point_domain = zero_point_domain,
             )
 
 
 def replace_linear_int4(
-    module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func=None
+    module,
+    groupsize,
+    inner_k_tiles,
+    padding_allowed,
+    skip_layer_func=None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
 ):
     _replace_linear_int4(
         module,
@@ -715,6 +814,7 @@ def replace_linear_int4(
         padding_allowed,
         skip_layer_func,
         linear_class=WeightOnlyInt4Linear,
+        zero_point_domain = zero_point_domain,
     )
 
 
@@ -830,22 +930,24 @@ def __init__(
         groupsize=64,
         inner_k_tiles=8,
         padding_allowed=True,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
         device: torch.device = torch.device("cuda"),
     ):
         self.blocksize = blocksize
         self.percdamp = percdamp
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
+        self.zero_point_domain = zero_point_domain
         self.device = device
         self.act_fake_quant_func = None
         n_bit = 4
         self.get_qparams_func = lambda w: get_groupwise_affine_qparams(
-            w, n_bit, groupsize
+            w, n_bit, groupsize, zero_point_domain=self.zero_point_domain,
         )
         self.quantize_func = (
             lambda w, qparams: groupwise_affine_quantize_tensor_from_qparams(
-                w, qparams[0], qparams[1], n_bit, groupsize
+                w, qparams[0], qparams[1], n_bit, groupsize, zero_point_domain=self.zero_point_domain,
             )
         )
         self.dequantize_func = (
@@ -855,6 +957,7 @@ def __init__(
                 qparams[1],
                 n_bit,
                 groupsize,
+                zero_point_domain = self.zero_point_domain,
             )
         )
         self.combine_qparams_list_func = lambda qparams_list: [
@@ -886,14 +989,28 @@ def make_names_and_values_dict_func(q, qparams):
                 F.pad(q, pad=(0, delta_k)), inner_k_tiles
             )
             scales = qparams[0].to(torch.bfloat16).to(self.device)
-            zeros = qparams[1].to(torch.bfloat16).to(self.device)
-            scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-            # how many new groups we need for padded weight
-            delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
-            final_s_and_z = F.pad(
-                scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
-            )
-            return {"weight": final_q, "scales_and_zeros": final_s_and_z}
+            if zero_point_domain == ZeroPointDomain.FLOAT:
+                zeros = qparams[1].to(torch.bfloat16).to(self.device)
+                scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
+                # how many new groups we need for padded weight
+                delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
+                final_s_and_z = F.pad(
+                    scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
+                )
+                return {"weight": final_q, "scales_and_zeros": final_s_and_z}
+            if zero_point_domain == ZeroPointDomain.INT:
+                zeros = qparams[1].to(torch.int8).to(self.device)
+                scales, zeros = align_tinygemm_scales_and_zeros(scales, zeros)
+                # how many new groups we need for padded weight
+                delta_groups = new_k // groupsize - scales.shape[0]
+                final_s = F.pad(
+                    scales, pad=(0, 0, 0, delta_groups), value=1
+                )
+                final_z = F.pad(
+                    zeros, pad=(0, 0, 0, delta_groups), value=1
+                )
+                return {"weight": final_q, "scales": final_s, "zeros": final_z}
+
 
         self.make_names_and_values_dict_func = make_names_and_values_dict_func
         super().__init__()
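To make the padding arithmetic concrete (a worked example, not commit code): F.pad's pad tuple runs from the last dimension backward, so pad=(0, 0, 0, delta_groups) leaves the out_features axis untouched and appends delta_groups rows of ones along the group axis, matching the groups gained when the weight is padded up to new_k:

# Worked example of the group-padding step in the INT zero-point branch.
import torch
import torch.nn.functional as F

groupsize, out_features = 128, 512
new_k = 1024                                    # in_features padded to a multiple of groupsize
scales = torch.rand(7, out_features)            # 7 groups before padding
delta_groups = new_k // groupsize - scales.shape[0]            # 8 - 7 = 1
final_s = F.pad(scales, pad=(0, 0, 0, delta_groups), value=1)  # shape (8, 512)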
@@ -905,6 +1022,7 @@ def _convert_for_runtime(self, model):
             self.inner_k_tiles,
             self.padding_allowed,
             skip_layer_func=self.skip_layer_func,
+            zero_point_domain = self.zero_point_domain,
         )
         return model
 