Skip to content

Commit 58d31fc

Browse files
authored
fix parse layer config bug (#856)
1 parent 9d2ee69 commit 58d31fc

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

auto_round/compressors/base.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -412,29 +412,26 @@ def _set_device(self, device_map):
412412
def _parse_layer_config(self, layer_config: dict[str, Union[str, dict, QuantizationScheme]]) -> None:
413413
"""Parse and set the layer-wise quantization configuration."""
414414
# Some other quantization configs
415-
self.layer_config = {} if layer_config is None else layer_config
416-
scheme_keys = [f.name for f in fields(QuantizationScheme)]
415+
self.layer_config = copy.deepcopy(layer_config) if layer_config is not None else {}
416+
scheme_keys = {f.name for f in fields(QuantizationScheme)}
417+
417418
for key, item in self.layer_config.items():
418419
if isinstance(item, str):
419-
item = asdict(preset_name_to_scheme(item.upper()))
420-
self.layer_config[key] = item
421-
422-
if isinstance(item, QuantizationScheme):
420+
config = asdict(preset_name_to_scheme(item.upper()))
421+
elif isinstance(item, QuantizationScheme):
423422
config = asdict(item)
424-
tmp_keys = copy.deepcopy(list(config.keys()))
425-
for tmp_key in tmp_keys: ## Pop None value to be overridden
426-
if config[tmp_key] is None:
427-
config.pop(tmp_key)
428-
self.layer_config[key] = config
429423
elif isinstance(item, dict):
430-
item_keys = item.keys()
431-
if item_keys not in scheme_keys:
432-
for item_key in item_keys:
433-
if item_key not in scheme_keys:
434-
raise ValueError(
435-
f"the key {item_key} in layer_config for layer {key} is invalid,"
436-
f" only {scheme_keys} are supported"
437-
)
424+
invalid_keys = set(item) - scheme_keys
425+
if invalid_keys:
426+
raise ValueError(
427+
f"Invalid keys {invalid_keys} in layer_config for layer '{key}', "
428+
f"only {scheme_keys} are supported"
429+
)
430+
config = dict(item)
431+
432+
# Drop None values
433+
config = {k: v for k, v in config.items() if v is not None}
434+
self.layer_config[key] = config
438435

439436
if not self.quant_lm_head or (isinstance(self.scheme, str) and self.scheme.lower().startswith("gguf")):
440437
return

0 commit comments

Comments
 (0)