File tree Expand file tree Collapse file tree 1 file changed +5
-11
lines changed
gptqmodel/nn_modules/qlinear Expand file tree Collapse file tree 1 file changed +5
-11
lines changed Original file line number Diff line number Diff line change @@ -72,10 +72,12 @@ def __init__(
7272 ** kwargs ,
7373 )
7474 if register_buffers :
75- qweight_shape = self .awq_qweight_shape ()
75+ # AWQ packs each input row into pack_factor-wide columns for int4 lanes.
76+ pack_cols = max (1 , self .out_features // self .pack_factor )
77+ qweight_shape = (self .in_features , pack_cols )
7678 group_size = max (int (self .group_size ), 1 )
77- group_rows = self . awq_group_count ()
78- pack_cols = qweight_shape [ 1 ]
79+ # Each group holds group_size input rows; ceil ensures trailing rows are captured.
80+ group_rows = max ( 1 , math . ceil ( self . in_features / group_size ))
7981
8082 self .register_buffer (
8183 "qweight" ,
@@ -96,14 +98,6 @@ def __init__(
9698 else :
9799 self .bias = None
98100
99- def awq_qweight_shape (self ):
100- pack_cols = max (1 , self .out_features // self .pack_factor )
101- return self .in_features , pack_cols
102-
103- def awq_group_count (self ):
104- group_size = max (int (self .group_size ), 1 )
105- return max (1 , math .ceil (self .in_features / group_size ))
106-
107101 def transform_cpu_awq (self , dtype ):
108102 src_scales = self .scales
109103 if src_scales .dtype != torch .float16 :
You can’t perform that action at this time.
0 commit comments