
Commit 28f8bca

working with larger num_calibration_samples
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 5cb055c · commit 28f8bca

File tree

1 file changed (+26 −20 lines)
  • src/llmcompressor/modifiers/awq


src/llmcompressor/modifiers/awq/base.py (+26 −20)
@@ -2,7 +2,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from accelerate.utils import align_module_device
+from compressed_tensors.utils import align_module_device, update_offload_parameter
 from loguru import logger
 from pydantic import ConfigDict
 from torch.nn import Module
@@ -158,11 +158,14 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         calibration_dataloader = state.data.calib
 
-        self._set_module_kwargs(state.model, calibration_dataloader)
-        self._setup_scale_hooks()
-        self._calibrate(state.model, calibration_dataloader)
-        self._concat_collected_activations()
-        self._apply_smoothing(state.model)
+        # TODO is it ok to wrap the whole model in this context?
+        # I don't think we ever want gradients or to use kv cache
+        with calibration_forward_context(state.model):
+            self._set_module_kwargs(state.model, calibration_dataloader)
+            self._setup_scale_hooks()
+            self._calibrate(state.model, calibration_dataloader)
+            self._concat_collected_activations()
+            self._apply_smoothing(state.model)
 
         return True
 
@@ -272,13 +275,13 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
                 " CompressionSession to run the AWQ modifier"
             )
 
-        with calibration_forward_context(model):
-            run_calibration_forward(
-                model,
-                calibration_dataloader,
-                self.num_calibration_steps,
-                self.calibration_function,
-            )
+        # with calibration_forward_context(model):
+        run_calibration_forward(
+            model,
+            calibration_dataloader,
+            self.num_calibration_steps,
+            self.calibration_function,
+        )
 
         # remove the hooks now that we are done calibrating
         self.remove_hooks()
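
Note: this hunk drops the per-call calibration_forward_context wrapper because the whole pipeline is now entered under that context in on_initialize (see the hunk above). The helper itself is not shown in this diff; below is a minimal sketch of what such a context is assumed to do (disable autograd and, where present, the HF kv cache). The name sketch_calibration_forward_context and its exact behavior are assumptions for illustration, not llmcompressor's implementation.

import contextlib

import torch


@contextlib.contextmanager
def sketch_calibration_forward_context(model: torch.nn.Module):
    # ASSUMPTION: illustrative stand-in, not the llmcompressor helper.
    # Disable kv-cache accumulation on HF-style models, if the attribute exists.
    prior_use_cache = getattr(getattr(model, "config", None), "use_cache", None)
    if prior_use_cache is not None:
        model.config.use_cache = False
    try:
        # Calibration only needs forward passes, so build no autograd graph.
        with torch.no_grad():
            yield
    finally:
        if prior_use_cache is not None:
            model.config.use_cache = prior_use_cache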
@@ -356,10 +359,9 @@ def _apply_smoothing(self, model: Module):
             x_mean = (x_sum / num_elements).to(inp.dtype)
 
             # [STEP 3]: Compute output of module
-            with torch.no_grad():
-                fp16_output = self._forward_input_with_kwargs(
-                    module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
-                )
+            fp16_output = self._forward_input_with_kwargs(
+                module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
+            )
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
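
Note: dropping the explicit torch.no_grad() here is safe on the assumption that the outer calibration_forward_context already disables gradients for every forward pass in the pipeline. A quick stand-alone check of that generic PyTorch behavior (not the modifier's code):

import torch

layer = torch.nn.Linear(4, 4)
with torch.no_grad():
    out = layer(torch.randn(2, 4))
# No autograd graph is recorded, so an inner no_grad() wrapper adds nothing.
assert out.requires_grad is False and out.grad_fn is None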
@@ -459,14 +461,16 @@ def _compute_best_scale(
             for fc in linears2scale:
                 with align_module_device(fc):
                     fc.weight.mul_(scales_view)
-                    fc.weight.data = (
+                    update_offload_parameter(
+                        fc,
+                        "weight",
                         pseudo_quantize_tensor(
                             w=fc.weight.data,
                             symmetric=self.symmetric,
                             bit_width=self.bits,
                             group_size=self.group_size,
                         )[0]
-                        / scales_view
+                        / scales_view,
                     )
 
             # W * X
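
Note: replacing the direct fc.weight.data assignment with update_offload_parameter writes the scaled-and-quantized weight back in a way that also keeps any accelerate-offloaded copy of the parameter in sync. A small sketch of that pattern on a plain, non-offloaded linear layer, assuming update_offload_parameter behaves like a straightforward parameter update when no offloading is configured (layer size and scale values are made up for illustration):

import torch
from compressed_tensors.utils import align_module_device, update_offload_parameter

fc = torch.nn.Linear(8, 8)
scales_view = torch.full((1, 8), 2.0)

# In-place edits to a Parameter need gradients disabled, as they are during calibration.
with torch.no_grad(), align_module_device(fc):
    fc.weight.mul_(scales_view)  # scale in place on the execution device
    # Write the result back through the offload-aware helper so any
    # offloaded copy of the weight stays consistent with the edit.
    update_offload_parameter(fc, "weight", fc.weight.data / scales_view)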
@@ -488,7 +492,9 @@
             logger.debug(history)
             raise Exception
 
-        assert torch.isnan(best_scales).sum() == 0, best_scales
+        assert (
+            torch.isnan(best_scales).sum() == 0
+        ), f"Nan found in scales: {best_scales}"
 
         return best_scales.detach().cpu()
