Commit a4f2a69

remove unused
1 parent 8cc7364 commit a4f2a69

1 file changed: +5, -36 lines changed

gptqmodel/nn_modules/qlinear/torch_fused_awq.py

Lines changed: 5 additions & 36 deletions
@@ -72,9 +72,9 @@ def __init__(
             **kwargs,
         )
         if register_buffers:
-            qweight_shape = self._awq_qweight_shape()
+            qweight_shape = self.awq_qweight_shape()
             group_size = max(int(self.group_size), 1)
-            group_rows = self._awq_group_count()
+            group_rows = self.awq_group_count()
             pack_cols = qweight_shape[1]

             self.register_buffer(
@@ -96,52 +96,21 @@ def __init__(
         else:
             self.bias = None

-    def _awq_qweight_shape(self):
+    def awq_qweight_shape(self):
         pack_cols = max(1, self.out_features // self.pack_factor)
         return self.in_features, pack_cols

-    def _awq_group_count(self):
+    def awq_group_count(self):
         group_size = max(int(self.group_size), 1)
         return max(1, math.ceil(self.in_features / group_size))

-    # def _load_from_state_dict(
-    #     self,
-    #     state_dict,
-    #     prefix,
-    #     local_metadata,
-    #     strict,
-    #     missing_keys,
-    #     unexpected_keys,
-    #     error_msgs,
-    # ):
-    #     self.register_awq_buffers()
-    #     super()._load_from_state_dict(
-    #         state_dict,
-    #         prefix,
-    #         local_metadata,
-    #         strict,
-    #         missing_keys,
-    #         unexpected_keys,
-    #         error_msgs,
-    #     )
-    #     qweight = getattr(self, "qweight", None)
-    #     if torch.is_tensor(qweight):
-    #         expected_shape = self._awq_qweight_shape()
-    #         if tuple(qweight.shape) != expected_shape:
-    #             raise ValueError(
-    #                 f"{self.__class__.__name__} only loads AWQ-formatted qweight tensors with "
-    #                 f"shape {expected_shape}, but received {tuple(qweight.shape)}."
-    #             )
-    #         if qweight.dtype != self.pack_dtype:
-    #             self.qweight = qweight.to(dtype=self.pack_dtype).contiguous()
-
     def transform_cpu_awq(self, dtype):
         src_scales = self.scales
         if src_scales.dtype != torch.float16:
             src_scales = src_scales.to(torch.float16)
         src_scales = src_scales.contiguous()

-        # Cache unpacked AWQ tensors
+        # Unpack AWQ tensors
         iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits)
         iweight, izeros = reverse_awq_order(iweight, izeros, self.bits)
         max_val = (1 << self.bits) - 1
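For context, the two renamed helpers are pure shape bookkeeping for the AWQ layout, and the arithmetic can be checked standalone. The sketch below is illustrative only, not part of the commit; it assumes pack_factor = 32 // bits (quantized columns packed into 32-bit words), which matches the usual AWQ convention but is defined elsewhere in this module.

import math

def awq_qweight_shape(in_features: int, out_features: int, bits: int):
    # AWQ packs along the output dimension: each 32-bit word holds
    # 32 // bits quantized values, so qweight has shape
    # (in_features, out_features // pack_factor).
    pack_factor = 32 // bits  # assumption: int32 pack_dtype
    pack_cols = max(1, out_features // pack_factor)
    return in_features, pack_cols

def awq_group_count(in_features: int, group_size: int):
    # One row of scales/qzeros per quantization group along in_features.
    group_size = max(int(group_size), 1)
    return max(1, math.ceil(in_features / group_size))

# Example: a 4-bit layer with group_size=128.
print(awq_qweight_shape(4096, 11008, bits=4))  # (4096, 1376)
print(awq_group_count(4096, group_size=128))   # 32

These are the group_rows and pack_cols values that __init__ computes before registering buffers.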
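The tail of the hunk sets up the usual AWQ dequantization path: unpack the bit-packed integers, undo AWQ's column interleave via reverse_awq_order, and mask values to `bits` bits using max_val. Below is a toy shift-and-mask sketch of the unpack step, not the module's actual unpack_awq (whose interleave handling is more involved); the closing comment states the standard AWQ dequantization formula, which the remainder of transform_cpu_awq is assumed to apply.

import torch

bits = 4
max_val = (1 << bits) - 1   # 0b1111, mirrors the diff above
pack_factor = 32 // bits    # 8 nibbles per int32 word

# Pack eight 4-bit integers into one int32 word, lowest nibble first.
vals = torch.tensor([3, 7, 0, 15, 1, 2, 9, 4], dtype=torch.int32)
word = torch.zeros(1, dtype=torch.int32)
for i, v in enumerate(vals):
    word |= v << (bits * i)

# Unpack: right-shift each nibble into place, then mask to `bits` bits.
shifts = torch.arange(0, 32, bits)
unpacked = (word[:, None] >> shifts[None, :]) & max_val
assert torch.equal(unpacked.flatten(), vals)

# Standard AWQ dequantization once weights and zero-points are unpacked:
#     weight = (iweight - izeros).to(dtype) * scales   # per quantization group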
