
Commit 12c4984

xin3he and wenhuach21 authored
add self attribution and fix avg_bits error (#956)
* add self attribution and fix avg_bits error

Signed-off-by: He, Xin3 <xin3.he@intel.com>
Co-authored-by: Wenhua Cheng <wenhua.cheng@intel.com>
1 parent 90c2fb4 commit 12c4984

4 files changed, +13 -7 lines changed


auto_round/auto_scheme/gen_auto_scheme.py

Lines changed: 2 additions & 1 deletion
@@ -157,4 +157,5 @@ def compute_avg_bit_range(self) -> tuple[float, float]:
             )[0]
             for option in self.auto_scheme.options
         ]
-        return min(avg_bits), max(avg_bits)
+        self.min_avg_bit, self.max_avg_bit = min(avg_bits), max(avg_bits)
+        return self.min_avg_bit, self.max_avg_bit
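The new attributes make the computed range available to later checks (for example, validating that a requested avg_bits target falls between the cheapest and most expensive option) without recomputing it. A minimal sketch of the cache-and-return pattern, with hypothetical option names and per-option averages standing in for GenScheme's real computation:

# Sketch only: the option names and per-option averages below are
# hypothetical stand-ins for what GenScheme computes from the model.
class SchemeRange:
    def __init__(self, avg_bits_per_option: dict[str, float]):
        self.avg_bits_per_option = avg_bits_per_option

    def compute_avg_bit_range(self) -> tuple[float, float]:
        avg_bits = list(self.avg_bits_per_option.values())
        # Cache the range on the instance so later checks can reuse it
        # without recomputing.
        self.min_avg_bit, self.max_avg_bit = min(avg_bits), max(avg_bits)
        return self.min_avg_bit, self.max_avg_bit

ranges = SchemeRange({"MXFP4": 4.25, "W8A16": 8.19})
print(ranges.compute_avg_bit_range())  # (4.25, 8.19)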

auto_round/auto_scheme/utils.py

Lines changed: 8 additions & 3 deletions
@@ -160,6 +160,8 @@ def compute_layer_bits(
     n_param = weight.numel()
     weight_bits = getattr(layer, "bits", 16)
     group_size = getattr(layer, "group_size", 128)
+    data_type = getattr(layer, "data_type", "int")
+    is_sym = getattr(layer, "sym", False)
     super_group_size = getattr(layer, "super_group_size", None)
     super_weight_bits = getattr(layer, "super_bits", None)
 
@@ -175,7 +177,7 @@ def compute_layer_bits(
 
     # Determine number of groups based on group size
     if group_size > 0:
-        n_group = out_features * (in_features + group_size - 1) // group_size
+        n_group = out_features * ((in_features + group_size - 1) // group_size)
     elif group_size == 0:
         n_group = 1
     elif group_size == -1:
@@ -185,9 +187,12 @@ def compute_layer_bits(
 
     # Compute auxiliary bits (scales, zero-points, or double quantization)
     aux_total_bits = 0
-    if not super_group_size:
+    if "mx_fp" in data_type or "nv_fp" in data_type or "fp4" in data_type:
+        scale_bits = 8
+    else:
         scale_bits = 16
-        zp_bits = weight_bits
+    zp_bits = weight_bits if not is_sym or "int" in data_type else 0
+    if not super_group_size:
         aux_total_bits = n_group * (scale_bits + zp_bits)
     else:
         aux_total_bits += n_group * super_weight_bits * 2
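Three things change in this accounting: the ceiling division for per-row groups is parenthesized before multiplying by out_features (the old expression multiplied first, inflating n_group), MX/NV FP4-family data types are charged 8-bit shared scales instead of 16, and symmetric non-integer formats no longer pay for a zero point. Below is a self-contained sketch of the corrected arithmetic; it leaves out the double-quantization (super_group_size) branch, and the function name, defaults, and the per-channel reading of group_size == -1 are my assumptions, not taken from auto-round.

def estimate_layer_bits(
    out_features: int,
    in_features: int,
    weight_bits: int = 4,
    group_size: int = 128,
    data_type: str = "int",
    is_sym: bool = False,
) -> tuple[int, float]:
    """Rough per-layer bit count mirroring the corrected formula (sketch)."""
    n_param = out_features * in_features

    # Ceiling-divide the row length by the group size first, then multiply
    # by the number of rows; multiplying first overcounts groups.
    if group_size > 0:
        n_group = out_features * ((in_features + group_size - 1) // group_size)
    elif group_size == 0:
        n_group = 1  # per-tensor
    else:
        n_group = out_features  # assumed per-channel for group_size == -1

    # MX/NV FP4-style formats carry 8-bit shared scales; 16-bit scales are
    # assumed for everything else.
    if "mx_fp" in data_type or "nv_fp" in data_type or "fp4" in data_type:
        scale_bits = 8
    else:
        scale_bits = 16

    # Symmetric non-integer formats store no zero point.
    zp_bits = weight_bits if not is_sym or "int" in data_type else 0

    total_bits = n_param * weight_bits + n_group * (scale_bits + zp_bits)
    return total_bits, total_bits / n_param

# Hypothetical 768x768 layer quantized MXFP4-style (group size 32, symmetric):
print(estimate_layer_bits(768, 768, weight_bits=4, group_size=32,
                          data_type="mx_fp4", is_sym=True)[1])  # 4.25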

auto_round/compressors/base.py

Lines changed: 2 additions & 2 deletions
@@ -433,7 +433,7 @@ def _gen_auto_scheme(
 
         if not self.enable_torch_compile and self.super_bits is None and not scheme.low_gpu_mem_usage:
             logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM")
-        gen_scheme = GenScheme(
+        self.scheme_generator = GenScheme(
             scheme,
             self.model,
             quant_layer_names,
@@ -443,7 +443,7 @@ def _gen_auto_scheme(
             tokenizer=self.tokenizer,
             enable_torch_compile=self.enable_torch_compile,
         )
-        layer_config = gen_scheme.get_layer_config()
+        layer_config = self.scheme_generator.get_layer_config()
         return layer_config
 
     def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None:

test/test_cuda/test_auto_scheme.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ def test_min_target_bits(self):
 
     #
     def test_max_target_bits(self):
        model_name = "/models/opt-125m"
-        target_bits = 8.211
+        target_bits = 8.025
         scheme = AutoScheme(avg_bits=target_bits, options=("MXFP4", "W8A16"))
         ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1)
         model, layer_config = ar.quantize()
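The expected value changes because the corrected accounting above changes the average bits computed for each option; the exact figure depends on opt-125m's layer shapes and which layers are quantized. As a rough, assumption-laden illustration of the per-weight cost of the two options under the corrected formula (128-wide groups with a 16-bit scale and an 8-bit zero point assumed for W8A16, 32-wide groups with an 8-bit shared scale and no zero point assumed for MXFP4):

# Back-of-the-envelope per-weight cost; group sizes and scale widths here
# are assumptions, not values taken from the test or the library.
w8a16_avg = 8 + (16 + 8) / 128   # ~8.19 bits per weight
mxfp4_avg = 4 + 8 / 32           # 4.25 bits per weight
print(w8a16_avg, mxfp4_avg)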
