
Commit 29df357

use low_gpu_mem_usage to cache best params
Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent 0055d30 commit 29df357

File tree: 4 files changed (+16 -10 lines changed)

auto_round/compressors/base.py
auto_round/compressors/utils.py
auto_round/utils/device.py
auto_round/wrapper.py


auto_round/compressors/base.py

Lines changed: 9 additions & 6 deletions
@@ -2321,10 +2321,10 @@ def _quantize_layer(
             if total_loss < best_loss:
                 best_loss = total_loss
                 if not self.not_use_best_mse:
-                    best_params = collect_best_params(wrapper_linear)
+                    best_params = collect_best_params(wrapper_linear, self.low_gpu_mem_usage)
                     last_best_iter = i
             if self.not_use_best_mse and i == self.iters - 1:
-                best_params = collect_best_params(wrapper_linear)
+                best_params = collect_best_params(wrapper_linear, self.low_gpu_mem_usage)

             if not self.not_use_best_mse:
                 if 0 < self.dynamic_max_gap <= i - last_best_iter:
@@ -2603,21 +2603,23 @@ def _quantize_block(
                 )

                 total_loss += loss.item() / num_elm
+                # Sometimes the cached memory is not released during training and cause OOM
+                if self.low_gpu_mem_usage:
+                    clear_memory_if_reached_threshold(threshold=0.85)
                 self._scale_loss_and_backward(scaler, loss)
-                clear_memory_if_reached_threshold(threshold=0.85)

             if i == 0:
                 init_loss = total_loss

             if total_loss < best_loss:
                 best_loss = total_loss
                 if not self.not_use_best_mse:
-                    best_params = collect_best_params(block)
+                    best_params = collect_best_params(block, self.low_gpu_mem_usage)
                     # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True)

                     last_best_iter = i
             if self.not_use_best_mse and i == self.iters - 1:
-                best_params = collect_best_params(block)
+                best_params = collect_best_params(block, self.low_gpu_mem_usage)

             if not self.not_use_best_mse:
                 if 0 < self.dynamic_max_gap <= i - last_best_iter:
@@ -2634,6 +2636,8 @@ def _quantize_block(
             f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
         )
         logger.info(dump_info)
+        if self.low_gpu_mem_usage:
+            clear_memory()  # clear cached memory during training
         if len(unquantized_layer_names) != 0:
             logger.info(f"{unquantized_layer_names} have not been quantized")
         with torch.no_grad():
@@ -2644,7 +2648,6 @@ def _quantize_block(
             set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")

         if self.enable_quanted_input:
-            clear_memory()
             q_outputs = self._get_block_outputs(
                 block,
                 input_ids,
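
To make the intent of these hunks concrete, here is a minimal sketch of the pattern they implement, not the auto_round code itself: when low_gpu_mem_usage is enabled, cached accelerator memory may be released before each backward pass, the best parameter snapshot is kept in host RAM rather than on the GPU, and the cache is cleared once the block finishes tuning. The block, data, optimizer, and helper names below are illustrative assumptions.

import torch


def tune_block(block: torch.nn.Module, x: torch.Tensor, y: torch.Tensor,
               iters: int = 8, low_gpu_mem_usage: bool = True):
    opt = torch.optim.SGD(block.parameters(), lr=1e-3)
    best_loss, best_params = float("inf"), None
    for _ in range(iters):
        loss = torch.nn.functional.mse_loss(block(x), y)
        if low_gpu_mem_usage and torch.cuda.is_available():
            torch.cuda.empty_cache()  # stand-in for clear_memory_if_reached_threshold(threshold=0.85)
        loss.backward()
        opt.step()
        opt.zero_grad()
        if loss.item() < best_loss:
            best_loss = loss.item()
            # keep the snapshot on the CPU, analogous to collect_best_params(block, self.low_gpu_mem_usage)
            best_params = {n: p.detach().cpu().clone() for n, p in block.named_parameters()}
    if low_gpu_mem_usage and torch.cuda.is_available():
        torch.cuda.empty_cache()  # stand-in for the final clear_memory() after the block
    return best_loss, best_params

Keeping the snapshot in host RAM trades a small host-to-device copy when the best parameters are restored for a lower peak GPU footprint while tuning.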

auto_round/compressors/utils.py

Lines changed: 3 additions & 1 deletion
@@ -199,13 +199,15 @@ def check_awq_gemm_compatibility(model, bits, group_size, sym, layer_configs=Non
     return True, ""


-def collect_best_params(block):
+def collect_best_params(block, low_gpu_mem_usage: bool = False):
     params = {}
     for n, m in block.named_modules():
         if hasattr(m, "orig_layer"):
             params[n] = {}
             for key in m.params.keys():
                 params[n][key] = copy.deepcopy(m.params[key].data)
+                if low_gpu_mem_usage:
+                    params[n][key] = params[n][key].cpu()
     return params

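
The offload added here is small but easy to miss: copy.deepcopy of a CUDA tensor produces another CUDA tensor, so each cached best parameter would otherwise keep occupying device memory. A standalone sketch of the effect, using a hypothetical snapshot helper rather than collect_best_params:

import copy

import torch


def snapshot(param: torch.Tensor, low_gpu_mem_usage: bool = False) -> torch.Tensor:
    snap = copy.deepcopy(param.data)  # the copy stays on the parameter's device
    if low_gpu_mem_usage:
        snap = snap.cpu()  # move it to host RAM so it does not compete with training tensors
    return snap


weight = torch.randn(4, 4, device="cuda" if torch.cuda.is_available() else "cpu")
cached = snapshot(weight, low_gpu_mem_usage=True)
print(cached.device)  # cpu whenever the flag is set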

auto_round/utils/device.py

Lines changed: 2 additions & 1 deletion
@@ -431,7 +431,7 @@ def clear_memory_if_reached_threshold(threshold=0.85):
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
         name, device_api = "XPU", torch.xpu
     else:
-        return
+        return False

     num_devices = device_api.device_count()
     for i in range(num_devices):
@@ -452,6 +452,7 @@ def clear_memory_if_reached_threshold(threshold=0.85):
                 return True
         except Exception as e:
             logger.warning_once(f"Failed to check memory for {name} device {i}: {e}")
+    return False


 def check_memory_availability(device, inputs, weight, org_seqlen, org_bs):
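
With this change clear_memory_if_reached_threshold returns a bool on every path: False when no supported accelerator is present or no device crossed the threshold, True when a cache was actually cleared. A rough sketch of that contract follows; it is an assumption-laden approximation, not the real helper, which also catches and logs per-device failures, and it assumes the XPU backend mirrors the CUDA memory APIs in recent PyTorch builds.

import torch


def clear_cache_if_reached(threshold: float = 0.85) -> bool:
    """Free cached accelerator memory once reserved memory crosses `threshold`."""
    if torch.cuda.is_available():
        device_api = torch.cuda
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device_api = torch.xpu  # assumes torch.xpu exposes the same memory API as torch.cuda
    else:
        return False  # no supported accelerator: nothing to clear
    for i in range(device_api.device_count()):
        total = device_api.get_device_properties(i).total_memory
        if device_api.memory_reserved(i) / total >= threshold:
            device_api.empty_cache()
            return True  # cleared at least one device
    return False  # no device crossed the threshold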

auto_round/wrapper.py

Lines changed: 2 additions & 2 deletions
@@ -550,7 +550,7 @@ def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"):
     def unwrapper(self, best_params):
         if best_params is None:
             return self.orig_layer
-        v = best_params["v"]
+        v = best_params["v"].to(self.device)
         weight_q, _, _ = self.quant_func(
             self.orig_layer.weight, self.bits, self.group_size, v, q_scale_thresh=self.q_scale_thresh
         )
@@ -601,7 +601,7 @@ def __init__(self, orig_layer, bit=4, group_size=-1, device="cpu"):
     def unwrapper(self, best_params):
         if best_params is None:
             return self.orig_layer
-        v = best_params["v"]
+        v = best_params["v"].to(self.device)
         weight_q, _, _ = self.quant_func(
             self.orig_layer.weight, self.bits, self.group_size, v, q_scale_thresh=self.q_scale_thresh
         )
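
The unwrapper change is the restore half of the round trip: best_params may now hold CPU tensors, so v is moved back to the layer's device before it is used. A tiny sketch of that step, where apply_cached_v and the weight + v placeholder are illustrative and the real code passes v to quant_func:

from typing import Optional

import torch


def apply_cached_v(weight: torch.Tensor, best_params: Optional[dict]) -> torch.Tensor:
    if best_params is None:
        return weight  # nothing cached, keep the original weight
    # move the possibly CPU-cached rounding tensor back to the weight's device,
    # mirroring best_params["v"].to(self.device) in unwrapper()
    v = best_params["v"].to(weight.device)
    return weight + v  # placeholder for the real quant_func call


w = torch.randn(4, 4)
cached = {"v": torch.zeros(4, 4)}  # snapshot that may have been offloaded to CPU
print(apply_cached_v(w, cached).shape)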
