@@ -72,9 +72,9 @@ def __init__(
7272 ** kwargs ,
7373 )
7474 if register_buffers :
75- qweight_shape = self ._awq_qweight_shape ()
75+ qweight_shape = self .awq_qweight_shape ()
7676 group_size = max (int (self .group_size ), 1 )
77- group_rows = self ._awq_group_count ()
77+ group_rows = self .awq_group_count ()
7878 pack_cols = qweight_shape [1 ]
7979
8080 self .register_buffer (
@@ -96,52 +96,21 @@ def __init__(
9696 else :
9797 self .bias = None
9898
99- def _awq_qweight_shape (self ):
99+ def awq_qweight_shape (self ):
100100 pack_cols = max (1 , self .out_features // self .pack_factor )
101101 return self .in_features , pack_cols
102102
103- def _awq_group_count (self ):
103+ def awq_group_count (self ):
104104 group_size = max (int (self .group_size ), 1 )
105105 return max (1 , math .ceil (self .in_features / group_size ))
106106
107- # def _load_from_state_dict(
108- # self,
109- # state_dict,
110- # prefix,
111- # local_metadata,
112- # strict,
113- # missing_keys,
114- # unexpected_keys,
115- # error_msgs,
116- # ):
117- # self.register_awq_buffers()
118- # super()._load_from_state_dict(
119- # state_dict,
120- # prefix,
121- # local_metadata,
122- # strict,
123- # missing_keys,
124- # unexpected_keys,
125- # error_msgs,
126- # )
127- # qweight = getattr(self, "qweight", None)
128- # if torch.is_tensor(qweight):
129- # expected_shape = self._awq_qweight_shape()
130- # if tuple(qweight.shape) != expected_shape:
131- # raise ValueError(
132- # f"{self.__class__.__name__} only loads AWQ-formatted qweight tensors with "
133- # f"shape {expected_shape}, but received {tuple(qweight.shape)}."
134- # )
135- # if qweight.dtype != self.pack_dtype:
136- # self.qweight = qweight.to(dtype=self.pack_dtype).contiguous()
137-
138107 def transform_cpu_awq (self , dtype ):
139108 src_scales = self .scales
140109 if src_scales .dtype != torch .float16 :
141110 src_scales = src_scales .to (torch .float16 )
142111 src_scales = src_scales .contiguous ()
143112
144- # Cache unpacked AWQ tensors
113+ # Unpack AWQ tensors
145114 iweight , izeros = unpack_awq (self .qweight , self .qzeros , self .bits )
146115 iweight , izeros = reverse_awq_order (iweight , izeros , self .bits )
147116 max_val = (1 << self .bits ) - 1
0 commit comments