from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
-from accelerate.utils import align_module_device
+from compressed_tensors.utils import align_module_device, update_offload_parameter
from loguru import logger
from pydantic import ConfigDict
from torch.nn import Module
@@ -158,11 +158,14 @@ def on_initialize(self, state: State, **kwargs) -> bool:
        calibration_dataloader = state.data.calib

-        self._set_module_kwargs(state.model, calibration_dataloader)
-        self._setup_scale_hooks()
-        self._calibrate(state.model, calibration_dataloader)
-        self._concat_collected_activations()
-        self._apply_smoothing(state.model)
+        # TODO is it ok to wrap the whole model in this context?
+        # I don't think we ever want gradients or to use kv cache
+        with calibration_forward_context(state.model):
+            self._set_module_kwargs(state.model, calibration_dataloader)
+            self._setup_scale_hooks()
+            self._calibrate(state.model, calibration_dataloader)
+            self._concat_collected_activations()
+            self._apply_smoothing(state.model)

        return True
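# The TODO above asks whether it is safe to wrap the whole pipeline in
# calibration_forward_context. As a rough, assumed sketch of what such a
# context is typically expected to do (the actual llmcompressor helper may
# also disable the KV cache and handle more state), something like:
import contextlib

import torch
from torch.nn import Module


@contextlib.contextmanager
def calibration_forward_context_sketch(model: Module):
    """Hypothetical stand-in: run calibration in eval mode with autograd off."""
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            yield
    finally:
        model.train(was_training)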
@@ -272,13 +275,13 @@ def _calibrate(self, model: Module, calibration_dataloader: List):
                " CompressionSession to run the AWQ modifier"
            )

-        with calibration_forward_context(model):
-            run_calibration_forward(
-                model,
-                calibration_dataloader,
-                self.num_calibration_steps,
-                self.calibration_function,
-            )
+        # with calibration_forward_context(model):
+        run_calibration_forward(
+            model,
+            calibration_dataloader,
+            self.num_calibration_steps,
+            self.calibration_function,
+        )

        # remove the hooks now that we are done calibrating
        self.remove_hooks()
@@ -356,10 +359,9 @@ def _apply_smoothing(self, model: Module):
            x_mean = (x_sum / num_elements).to(inp.dtype)

            # [STEP 3]: Compute output of module
-            with torch.no_grad():
-                fp16_output = self._forward_input_with_kwargs(
-                    module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
-                )
+            fp16_output = self._forward_input_with_kwargs(
+                module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_
+            )

            # [STEP 4]: Compute loss
            best_scales = self._compute_best_scale(
@@ -459,14 +461,16 @@ def _compute_best_scale(
            for fc in linears2scale:
                with align_module_device(fc):
                    fc.weight.mul_(scales_view)
-                    fc.weight.data = (
+                    update_offload_parameter(
+                        fc,
+                        "weight",
                        pseudo_quantize_tensor(
                            w=fc.weight.data,
                            symmetric=self.symmetric,
                            bit_width=self.bits,
                            group_size=self.group_size,
                        )[0]
-                        / scales_view
+                        / scales_view,
                    )

            # W * X
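# Assigning fc.weight.data inside align_module_device only mutates the
# temporarily onloaded tensor; update_offload_parameter (imported above from
# compressed_tensors.utils) is meant to also push the new values back to the
# offloaded copy. A minimal usage sketch, with `fake_quantize` and `scales`
# as placeholder names for this example:
import torch
from compressed_tensors.utils import align_module_device, update_offload_parameter


def rescale_and_quantize(fc: torch.nn.Linear, scales: torch.Tensor, fake_quantize):
    with torch.no_grad(), align_module_device(fc):  # onload the (possibly offloaded) weight
        fc.weight.mul_(scales)  # W * s
        update_offload_parameter(
            fc,
            "weight",
            fake_quantize(fc.weight.data) / scales,  # Q(W * s) / s
        )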
@@ -488,7 +492,9 @@ def _compute_best_scale(
            logger.debug(history)
            raise Exception

-        assert torch.isnan(best_scales).sum() == 0, best_scales
+        assert (
+            torch.isnan(best_scales).sum() == 0
+        ), f"Nan found in scales: {best_scales}"

        return best_scales.detach().cpu()