auto_round/compressors/base.py (83 changes: 18 additions & 65 deletions)

@@ -1217,82 +1217,35 @@ def get_imatrix_hook(module, input, output):
 hooks = register_act_hook(model)

 try:
-    # Move model to target device
-    if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
-        dispatch_model(self.model, self.model.hf_device_map)
-    else:
-        model = model.to(self.device)
-    if _is_fp8_model(self.model):
-        convert_fp8_model_to_16b_model(self.model, self.amp_dtype)
-    cnt = 0
-
-    # Run forward pass to accumulate imatrix
-    for data in self.dataloader:
-        cnt += data["input_ids"].shape[0]
-        data = to_device(data, self.device)
-        model(**data)
-        if cnt >= self.nsamples:
-            break
-
-    # Remove hooks after data collection
-    for hook in hooks:
-        hook.remove()
-
-    # Normalize imatrix by count
-    for _, module in model.named_modules():
-        if hasattr(module, "imatrix"):
-            module.imatrix /= module.imatrix_cnt
-    if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
-        import accelerate
-
-        accelerate.hooks.remove_hook_from_submodules(model)
-    # Perform quantization using RTN
-    pbar = tqdm(all_to_quantized_module_names)
-    block_names_cnt = len(flatten_list(get_block_names(self.model, True)))
-    clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt
-    if clear_mem_freq == 0:
-        clear_mem_freq = 1
-    cnt = 1
-    for name in pbar:
-        pbar.set_description(f"Quantizing {name}")
-        self._quantize_layer_via_rtn(name)
-        if cnt % clear_mem_freq == 0:
-            clear_memory()
-            cnt = 1
-        cnt += 1
+    model = model.to("cpu")
+    clear_memory()
+    self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
 except RuntimeError as e:
     cuda_error_msg = traceback.format_exc()
     try:
         logger.error(cuda_error_msg)
-        # Fallback: out-of-memory → try CPU blockwise quantization
-        logger.warning("Out of VRAM, falling back to blockwise quantization. Accuracy may degrade.")
-        model = model.to("cpu")
-        clear_memory()
-        if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
-            import accelerate
-
-            accelerate.hooks.remove_hook_from_submodules(model)
-        self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
-    except RuntimeError as e:
-        cuda_error_msg = traceback.format_exc()
-        try:
-            logger.error(cuda_error_msg)
-            # Final fallback: warn and use CPU-only quantization
-            logger.warning(
-                "Fallback to CPU. "
-                "Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`."
-            )
-            model = model.to("cpu")
-            clear_memory()
-            orig_device = self.device
-            self.device = "cpu"
-            self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
-            self.device = orig_device
-        except Exception as e:
-            raise
+        # Final fallback: warn and use CPU-only quantization
+        logger.warning(
+            "Fallback to CPU. "
+            "Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`."
+        )
+        model = model.to("cpu")
+        clear_memory()
+        if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
+            import accelerate
+
+            accelerate.hooks.remove_hook_from_submodules(model)
+
+        orig_device = self.device
+        self.device = "cpu"
+        self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
+        self.device = orig_device
+    except Exception as e:
+        raise
 finally:
     # Always remove hooks
     for hook in hooks:
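
Note: the change collapses the inline imatrix calibration and per-layer RTN loop into a single call to self._quantize_via_rtn_blockwise, and flattens the old two-level out-of-memory fallback into one CPU retry. A minimal sketch of the resulting control flow, assuming a quantizer object with model and device attributes and a _quantize_via_rtn_blockwise method (names taken from the diff; the wrapper function itself is hypothetical, not part of the patch):

    import logging
    import traceback

    logger = logging.getLogger(__name__)

    def quantize_with_cpu_fallback(quantizer, module_names):
        # Hypothetical wrapper illustrating the patch's control flow.
        try:
            quantizer.model = quantizer.model.to("cpu")
            quantizer._quantize_via_rtn_blockwise(module_names)
        except RuntimeError:
            logger.error(traceback.format_exc())
            logger.warning("Fallback to CPU. Consider enabling `low_gpu_mem_usage`.")
            orig_device = quantizer.device
            quantizer.device = "cpu"
            try:
                quantizer._quantize_via_rtn_blockwise(module_names)
            finally:
                quantizer.device = orig_device  # restore even if the retry fails

One design note: the sketch restores the device in a finally block, whereas the patch restores self.device only after a successful retry, so a failing CPU retry leaves self.device set to "cpu".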
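The deleted lines also document the imatrix calibration pattern: register_act_hook attaches forward hooks that accumulate a statistic into module.imatrix along with a counter module.imatrix_cnt, and the statistic is divided by the count after the calibration pass. A sketch under stated assumptions, since the hunk does not show the body of get_imatrix_hook: here the statistic is assumed to be the per-channel sum of squared inputs, and hooks are assumed to attach to Linear layers only.

    import torch

    def imatrix_hook(module, inputs, output):
        # Accumulate a per-input-channel statistic (sum of squares is an assumption).
        x = inputs[0].detach().float()
        sq = x.pow(2).sum(dim=tuple(range(x.dim() - 1)))
        if not hasattr(module, "imatrix"):
            module.imatrix = sq
            module.imatrix_cnt = 1
        else:
            module.imatrix += sq
            module.imatrix_cnt += 1

    def register_act_hook(model):
        # Mirrors the diff's register_act_hook: one forward hook per Linear layer.
        return [
            m.register_forward_hook(imatrix_hook)
            for m in model.modules()
            if isinstance(m, torch.nn.Linear)
        ]

Whatever the real statistic is, dividing module.imatrix by module.imatrix_cnt (as the deleted lines did) yields its running mean, and every handle returned by register_act_hook must be removed afterwards, which the surviving finally block still does.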