 from .quant_primitives import (
     MappingType,
     dequantize_affine,
+    ZeroPointDomain,
 )
 from .unified import Quantizer
 from .utils import (
     groupwise_affine_quantize_tensor,
     groupwise_affine_quantize_tensor_from_qparams,
     pack_tinygemm_scales_and_zeros,
+    align_tinygemm_scales_and_zeros,
     per_token_dynamic_quant,
 )
 
@@ -75,18 +77,19 @@ def __init__(
         percdamp=0.01,
         groupsize=128,
     ):
+        self.device = self.get_device(model, inputs)
         self.id_to_name = {
             id(value): name for name, value in dict(model.named_parameters()).items()
         }
 
         # trace model for one input
-        one_input = [multi.values[0].cpu() for multi in inputs]  # pyre-ignore[16]
+        one_input = [multi.values[0] for multi in inputs]  # pyre-ignore[16]
         # needed for GPTQ on the torchao llama model
         import torchao
 
         torchao._models.llama.model.use_index_put_for_kv_cache = True
         exported_model = torch._dynamo.export(
-            model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
+            model, aten_graph=True, pre_dispatch=True, tracing_mode="fake"
         )(*one_input)
         super().__init__(exported_model.graph_module)
 
@@ -100,6 +103,19 @@ def __init__(
         self.inputs = inputs
         self.gptq_done = False
         self.debug = False
+
+
+    def get_device(self, model, inputs: _MultiInput):
+        for name, param in model.named_parameters():
+            if isinstance(param, torch.Tensor):
+                return param.device
+
+        for multi in inputs:
+            if isinstance(multi.values[0], torch.Tensor):
+                return multi.values[0].device
+
+        return torch.device("cpu")
+
 
     def configure_quantization_mode(
         self,
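# Illustrative sketch (not from the patch): get_device above resolves the working
# device in this order -- first parameter's device, else the first tensor input's
# device, else CPU. resolve_device is a hypothetical standalone helper showing the
# same order.
import torch

def resolve_device(model: torch.nn.Module, example_tensors) -> torch.device:
    for _, param in model.named_parameters():
        if isinstance(param, torch.Tensor):
            return param.device
    for t in example_tensors:
        if isinstance(t, torch.Tensor):
            return t.device
    return torch.device("cpu")

assert resolve_device(torch.nn.Linear(8, 8), []).type == "cpu"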
@@ -163,16 +179,16 @@ def get_quantized_state_dict(self):
         return quantized_state_dict
 
     def call_function(self, target, args, kwargs, already_quantized=False):  # noqa: C901
-        def tensors_to_cuda(args):
+        def tensors_to_device(args):
             new_args = []
             for x in args:
-                new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
+                new_args.append(x.to(self.device) if isinstance(x, torch.Tensor) else x)
             return new_args
 
         # flatten args and kwargs together
         flat_args, spec = tree_flatten((args, kwargs))
         # move all single tensors to cuda, will move _MultiInputs to cuda one at a time
-        flat_args = tensors_to_cuda(flat_args)
+        flat_args = tensors_to_device(flat_args)
 
         has_multi_input = _MultiInput in [type(x) for x in flat_args]
         if has_multi_input:
@@ -212,7 +228,7 @@ def tensors_to_cuda(args):
             total_batches = 0
 
             for inp in transposed_args:
-                inp = tensors_to_cuda(inp)
+                inp = tensors_to_device(inp)
                 cur_args, cur_kwargs = tree_unflatten(inp, spec)
 
                 if quantize_linear:  # calculate H instead of output (will run the linear eventually with updated weight)
@@ -283,7 +299,7 @@ def SQNR(x, y):
                     "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
                 )  # matches
                 print(
-                    "SQNR for weight (can be low)", SQNR(W, DQ.cuda())
+                    "SQNR for weight (can be low)", SQNR(W, DQ.to(self.device))
                 )  # fine to not match
                 print(
                     "SQNR for output with GPTQ (hopefully 35+)",
@@ -385,7 +401,12 @@ def faster_quant(self, H, W):
 
             W[:, i2:] -= Err1.to(Hinv.dtype).matmul(Hinv[i1:i2, i2:])
 
-        torch.cuda.synchronize()
+        if 'cuda' in self.device.type:
+            torch.cuda.synchronize()
+        elif 'xpu' in self.device.type:
+            torch.xpu.synchronize()
+        else:
+            pass
 
         if all_qparams == []:
             all_qparams.append(cur_qparams)
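# Illustrative sketch (not from the patch): the hunk above swaps the hard-coded
# torch.cuda.synchronize() for a dispatch on self.device.type, so the GPTQ loop
# can wait on CUDA or XPU work; on CPU there is nothing to wait for. A standalone
# helper with the same behaviour:
import torch

def maybe_synchronize(device: torch.device) -> None:
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.synchronize()
    # any other device type (e.g. CPU): no-op

maybe_synchronize(torch.device("cpu"))  # safe on any build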
@@ -561,6 +582,30 @@ def linear_forward_int4(
     return c
 
 
+def linear_forward_int4_zero_domain(
+    x: torch.Tensor,
+    weight_int4pack: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+    out_features: int,
+    groupsize: int,
+    precision: torch.dtype = torch.bfloat16,
+    scales_precision: torch.dtype = torch.bfloat16,
+):
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
+        x.contiguous().to(precision),
+        weight_int4pack,
+        groupsize,
+        scales.to(scales_precision),
+        zeros.to(torch.int8),
+    ).to(dtype=x.dtype)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
+
+
 class WeightOnlyInt4Linear(torch.nn.Module):
     __constants__ = ["in_features", "out_features"]
     in_features: int
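# Illustrative sketch (not from the patch): a rough pure-PyTorch reference for the
# arithmetic the fused XPU op above is expected to perform with integer zero
# points: per group, w = (q - zero) * scale, then a regular matmul. The unpacked
# int4 layout and the group-major scales/zeros shapes here are assumptions chosen
# to match the (in_features // groupsize, out_features) buffers registered below.
import torch

def int4_int_zero_reference(x, q, scales, zeros, groupsize):
    # q: unpacked int4 values in [0, 15], shape (out_features, in_features)
    n_out, n_in = q.shape
    qg = q.reshape(n_out, n_in // groupsize, groupsize).to(scales.dtype)
    scale = scales.t().unsqueeze(-1)                 # (out, groups, 1)
    zero = zeros.t().unsqueeze(-1).to(scales.dtype)  # (out, groups, 1)
    w = ((qg - zero) * scale).reshape(n_out, n_in)
    return x @ w.t()

x = torch.randn(4, 64)
q = torch.randint(0, 16, (32, 64))
scales = torch.rand(64 // 32, 32)                       # groupsize = 32
zeros = torch.randint(0, 16, (64 // 32, 32), dtype=torch.int8)
assert int4_int_zero_reference(x, q, scales, zeros, 32).shape == (4, 32)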
@@ -579,6 +624,7 @@ def __init__(
         inner_k_tiles: int = 8,
         precision: torch.dtype = torch.bfloat16,
         scales_precision: torch.dtype = torch.bfloat16,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
     ) -> None:
         super().__init__()
         self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
@@ -594,6 +640,7 @@ def __init__(
         self.inner_k_tiles = inner_k_tiles
         self.precision = precision
         self.scales_precision = scales_precision
+        self.zero_point_domain = zero_point_domain
 
         if dtype is not None:
             raise ValueError("Please specify 'precision' instead of 'dtype'")
@@ -614,6 +661,18 @@ def __init__(
                     device=device,
                 ),
             )
+        elif is_device(device.type, "xpu"):
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features,
+                        in_features // 8,
+                    ),
+                    dtype=torch.int32,
+                    device=device,
+                ),
+            )
         else:
             self.register_buffer(
                 "weight",
@@ -629,27 +688,59 @@ def __init__(
                 ),
             )
         self.dtype = dtype
-        self.register_buffer(
-            "scales_and_zeros",
-            torch.zeros(
-                (in_features // groupsize, out_features, 2),
-                dtype=self.scales_precision,
-                device=device,
-            ),
-        )
+        if self.zero_point_domain == ZeroPointDomain.INT:
+            self.register_buffer(
+                "scales",
+                torch.zeros(
+                    (in_features // groupsize, out_features),
+                    dtype=self.scales_precision,
+                    device=device,
+                ),
+            )
+
+            self.register_buffer(
+                "zeros",
+                torch.zeros(
+                    (in_features // groupsize, out_features),
+                    dtype=torch.int8,
+                    device=device,
+                ),
+            )
+        else:
+            self.register_buffer(
+                "scales_and_zeros",
+                torch.zeros(
+                    (in_features // groupsize, out_features, 2),
+                    dtype=self.scales_precision,
+                    device=device,
+                ),
+            )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.padding:
             input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
-        return linear_forward_int4(
-            input,
-            self.weight,
-            self.scales_and_zeros,
-            self.out_features,
-            self.groupsize,
-            self.precision,
-            self.scales_precision,
-        )
+
+        if self.zero_point_domain != ZeroPointDomain.INT:
+            return linear_forward_int4(
+                input,
+                self.weight,
+                self.scales_and_zeros,
+                self.out_features,
+                self.groupsize,
+                self.precision,
+                self.scales_precision,
+            )
+        else:
+            return linear_forward_int4_zero_domain(
+                input,
+                self.weight,
+                self.scales,
+                self.zeros,
+                self.out_features,
+                self.groupsize,
+                self.precision,
+                self.scales_precision,
+            )
 
 
 def _replace_linear_int4(
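# Illustrative sketch (not from the patch): with the branches above,
# ZeroPointDomain.INT gives the module separate `scales` and int8 `zeros` buffers,
# while any other domain keeps the packed `scales_and_zeros` buffer, and forward()
# dispatches to the matching kernel wrapper. The XPU device below is an assumption;
# the keyword arguments mirror this constructor's defaults.
import torch

layer = WeightOnlyInt4Linear(
    in_features=4096,
    out_features=4096,
    device=torch.device("xpu"),
    groupsize=128,
    inner_k_tiles=8,
    zero_point_domain=ZeroPointDomain.INT,
)
assert hasattr(layer, "scales") and hasattr(layer, "zeros")
assert not hasattr(layer, "scales_and_zeros")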
@@ -662,6 +753,7 @@ def _replace_linear_int4(
     scales_precision: torch.dtype = torch.bfloat16,
     linear_class: Type[torch.nn.Module] = WeightOnlyInt4Linear,
     copy_weights: bool = False,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
 ):
     for name, child in module.named_children():
         # TODO: support linear bias
@@ -683,6 +775,7 @@ def _replace_linear_int4(
                 inner_k_tiles=inner_k_tiles,
                 precision=precision,
                 scales_precision=scales_precision,
+                zero_point_domain=zero_point_domain,
             )
             # TODO: merge with 8da4w?
             # In distributed training, the model may be instantiated
@@ -702,11 +795,17 @@ def _replace_linear_int4(
                 scales_precision,
                 linear_class,
                 copy_weights,
+                zero_point_domain=zero_point_domain,
             )
 
 
 def replace_linear_int4(
-    module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func=None
+    module,
+    groupsize,
+    inner_k_tiles,
+    padding_allowed,
+    skip_layer_func=None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
 ):
     _replace_linear_int4(
         module,
@@ -715,6 +814,7 @@ def replace_linear_int4(
         padding_allowed,
         skip_layer_func,
         linear_class=WeightOnlyInt4Linear,
+        zero_point_domain=zero_point_domain,
     )
 
 
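# Illustrative sketch (not from the patch): the new keyword threads through
# replace_linear_int4 to every swapped-in layer. The toy model and hyperparameters
# are assumptions for illustration.
import torch

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))
replace_linear_int4(
    model,
    groupsize=128,
    inner_k_tiles=8,
    padding_allowed=True,
    zero_point_domain=ZeroPointDomain.INT,
)
assert isinstance(model[0], WeightOnlyInt4Linear)
assert model[0].zero_point_domain == ZeroPointDomain.INT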
@@ -830,22 +930,24 @@ def __init__(
         groupsize=64,
         inner_k_tiles=8,
         padding_allowed=True,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
         device: torch.device = torch.device("cuda"),
     ):
         self.blocksize = blocksize
         self.percdamp = percdamp
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
+        self.zero_point_domain = zero_point_domain
         self.device = device
         self.act_fake_quant_func = None
         n_bit = 4
         self.get_qparams_func = lambda w: get_groupwise_affine_qparams(
-            w, n_bit, groupsize
+            w, n_bit, groupsize, zero_point_domain=self.zero_point_domain,
         )
         self.quantize_func = (
             lambda w, qparams: groupwise_affine_quantize_tensor_from_qparams(
-                w, qparams[0], qparams[1], n_bit, groupsize
+                w, qparams[0], qparams[1], n_bit, groupsize, zero_point_domain=self.zero_point_domain,
             )
         )
         self.dequantize_func = (
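# Illustrative sketch (not from the patch): the lambdas above pass
# zero_point_domain through the groupwise helpers so the qparams and quantized
# values live in the integer zero-point domain end to end. The weight shape and
# groupsize below are assumptions for illustration.
import torch

w = torch.randn(32, 256, dtype=torch.bfloat16)
scales, zeros = get_groupwise_affine_qparams(
    w, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.INT
)
q = groupwise_affine_quantize_tensor_from_qparams(
    w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.INT
)
assert scales.shape == zeros.shape  # one (scale, zero point) pair per row and group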
@@ -855,6 +957,7 @@ def __init__(
                 qparams[1],
                 n_bit,
                 groupsize,
+                zero_point_domain=self.zero_point_domain,
             )
         )
         self.combine_qparams_list_func = lambda qparams_list: [
@@ -886,14 +989,28 @@ def make_names_and_values_dict_func(q, qparams):
                 F.pad(q, pad=(0, delta_k)), inner_k_tiles
             )
             scales = qparams[0].to(torch.bfloat16).to(self.device)
-            zeros = qparams[1].to(torch.bfloat16).to(self.device)
-            scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-            # how many new groups we need for padded weight
-            delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
-            final_s_and_z = F.pad(
-                scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
-            )
-            return {"weight": final_q, "scales_and_zeros": final_s_and_z}
+            if zero_point_domain == ZeroPointDomain.FLOAT:
+                zeros = qparams[1].to(torch.bfloat16).to(self.device)
+                scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
+                # how many new groups we need for padded weight
+                delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
+                final_s_and_z = F.pad(
+                    scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
+                )
+                return {"weight": final_q, "scales_and_zeros": final_s_and_z}
+            if zero_point_domain == ZeroPointDomain.INT:
+                zeros = qparams[1].to(torch.int8).to(self.device)
+                scales, zeros = align_tinygemm_scales_and_zeros(scales, zeros)
+                # how many new groups we need for padded weight
+                delta_groups = new_k // groupsize - scales.shape[0]
+                final_s = F.pad(
+                    scales, pad=(0, 0, 0, delta_groups), value=1
+                )
+                final_z = F.pad(
+                    zeros, pad=(0, 0, 0, delta_groups), value=1
+                )
+                return {"weight": final_q, "scales": final_s, "zeros": final_z}
+
 
         self.make_names_and_values_dict_func = make_names_and_values_dict_func
         super().__init__()
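# Illustrative sketch (not from the patch): the two branches above determine which
# keys the GPTQ state dict carries for each layer, matching the buffers
# WeightOnlyInt4Linear registers for each domain.
expected_keys = {
    ZeroPointDomain.FLOAT: {"weight", "scales_and_zeros"},
    ZeroPointDomain.INT: {"weight", "scales", "zeros"},
}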
@@ -905,6 +1022,7 @@ def _convert_for_runtime(self, model):
             self.inner_k_tiles,
             self.padding_allowed,
             skip_layer_func=self.skip_layer_func,
+            zero_point_domain=self.zero_point_domain,
         )
         return model
 