Skip to content

Commit 70e4b04

Browse files
committed
inline methods
1 parent a4f2a69 commit 70e4b04

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

gptqmodel/nn_modules/qlinear/torch_fused_awq.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,12 @@ def __init__(
7272
**kwargs,
7373
)
7474
if register_buffers:
75-
qweight_shape = self.awq_qweight_shape()
75+
# AWQ packs each input row into pack_factor-wide columns for int4 lanes.
76+
pack_cols = max(1, self.out_features // self.pack_factor)
77+
qweight_shape = (self.in_features, pack_cols)
7678
group_size = max(int(self.group_size), 1)
77-
group_rows = self.awq_group_count()
78-
pack_cols = qweight_shape[1]
79+
# Each group holds group_size input rows; ceil ensures trailing rows are captured.
80+
group_rows = max(1, math.ceil(self.in_features / group_size))
7981

8082
self.register_buffer(
8183
"qweight",
@@ -96,14 +98,6 @@ def __init__(
9698
else:
9799
self.bias = None
98100

99-
def awq_qweight_shape(self):
100-
pack_cols = max(1, self.out_features // self.pack_factor)
101-
return self.in_features, pack_cols
102-
103-
def awq_group_count(self):
104-
group_size = max(int(self.group_size), 1)
105-
return max(1, math.ceil(self.in_features / group_size))
106-
107101
def transform_cpu_awq(self, dtype):
108102
src_scales = self.scales
109103
if src_scales.dtype != torch.float16:

0 commit comments

Comments
 (0)