
Commit 84e9a77

Reduce peak GPU memory usage and support MoE estimation (#981)

- Reduce peak memory usage by calling clear_memory, taking the performance cost into account.
- Move best_params to CPU and make sure memory is cleared before moving them back.
- Move the loss device to the second card if card_0_in_high_risk.
- Support DeepSeek R1 W4A16 tuning with 3 CUDA cards (80GB) (--enable_torch_compile).
- Support Llama 3.3 70B W4A16 tuning with 2 Intel GPU cards (24GB) (--enable_torch_compile).
1 parent 284eecd · commit 84e9a77
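The pattern behind the first two bullets can be illustrated outside of auto_round: snapshots of the best parameters are copied to CPU instead of being kept as on-device copies, and cached GPU memory is released before anything is moved back. Below is a minimal sketch assuming plain PyTorch on CUDA; offload_state and restore_state are illustrative names, not functions from this commit.

import gc

import torch


def offload_state(module: torch.nn.Module) -> dict[str, torch.Tensor]:
    # Copy the current best tensors to CPU so only one live copy stays on the GPU.
    return {n: p.detach().to("cpu", copy=True) for n, p in module.named_parameters()}


def restore_state(module: torch.nn.Module, state: dict[str, torch.Tensor], device: str) -> None:
    # Release cached allocator blocks first, then move the best tensors back
    # to the target device.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    with torch.no_grad():
        for n, p in module.named_parameters():
            p.copy_(state[n].to(device))

The reworked collect_best_params in utils.py (last hunk below) follows the same idea by copying each tensor straight to cache_device.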

File tree

4 files changed: +221 -69 lines changed

auto_round/compressors/base.py

Lines changed: 30 additions & 15 deletions
@@ -1441,7 +1441,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
 
             if is_complex_device_mapping(self.device_map):
                 set_auto_device_map_for_block_with_tuning(
-                    block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size
+                    block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device
                 )
             # Dispatch model if needed
             if is_complex_device_mapping(self.device_map):
@@ -2332,10 +2332,10 @@ def _quantize_layer(
             if total_loss < best_loss:
                 best_loss = total_loss
                 if not self.not_use_best_mse:
-                    best_params = collect_best_params(wrapper_linear)
+                    best_params = collect_best_params(wrapper_linear, self.cache_device)
                     last_best_iter = i
             if self.not_use_best_mse and i == self.iters - 1:
-                best_params = collect_best_params(wrapper_linear)
+                best_params = collect_best_params(wrapper_linear, self.cache_device)
 
             if not self.not_use_best_mse:
                 if 0 < self.dynamic_max_gap <= i - last_best_iter:
@@ -2413,6 +2413,7 @@ def _get_current_q_output(
         input_others: dict,
         indices: list[int],
         device: str,
+        cache_device: str = "cpu",
     ) -> torch.Tensor:
         current_input_ids, current_input_others = self._sampling_inputs(
             input_ids,
@@ -2423,7 +2424,7 @@ def _get_current_q_output(
             share_cache_keys=self.shared_cache_keys,
         )
         output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
-        return output_q
+        return output_q.to(cache_device)
 
     def _get_current_num_elm(
         self,
@@ -2458,13 +2459,15 @@ def _quantize_block(
             if is_fp8_linear(m):
                 new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device)
                 set_module(block, n, new_layer)
-
-        if is_complex_device_mapping(self.device_map):
-            set_auto_device_map_for_block_with_tuning(
+        # card_0_in_high_risk indicates that card_0 memory is already in high usage (90%) w/o any weights
+        # loss_device is used to calculate loss on the second device if available and card_0_in_high_risk
+        if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)):
+            card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning(
                 block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
             )
         else:
             block = block.to(device)
+            card_0_in_high_risk, loss_device = False, device
 
         if is_complex_device_mapping(self.device_map):
             for n, m in block.named_modules():
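The two new comments describe the heuristic: when card 0 is already heavily used (around 90%) before any block weights land on it, the loss is computed on a second card. That decision is made inside set_auto_device_map_for_block_with_tuning, whose body is not part of this diff; the following is a hedged sketch of what such a check could look like, with pick_loss_device as a purely hypothetical helper.

import torch


def pick_loss_device(default_device: str = "cuda:0", high_risk_ratio: float = 0.9) -> tuple[bool, str]:
    # Hypothetical helper: treat card 0 as "in high risk" when its reserved memory
    # is already close to capacity, and route loss computation to card 1 if one exists.
    if not torch.cuda.is_available():
        return False, default_device
    reserved = torch.cuda.memory_reserved(0)
    total = torch.cuda.get_device_properties(0).total_memory
    card_0_in_high_risk = reserved / total >= high_risk_ratio
    if card_0_in_high_risk and torch.cuda.device_count() > 1:
        return True, "cuda:1"
    return card_0_in_high_risk, default_device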
@@ -2594,21 +2597,21 @@ def _quantize_block(
 
                 current_output = self._get_current_output(output, indices)
 
-                current_output = to_device(current_output, device)
+                current_output = to_device(current_output, loss_device)
 
-                output_q = self._get_current_q_output(block, input_ids, input_others, indices, device)
+                output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)
 
                 if self.attention_mask:
                     tmp_attention_mask = [self.attention_mask[i] for i in indices]
-                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(loss_device)
                     tmp_attention_mask.unsqueeze_(-1)
                     num_elm = torch.sum(tmp_attention_mask).item()
                     if num_elm == 0:
                         num_elm = 1
                 else:
                     tmp_attention_mask = 1.0
                 if self.amp:
-                    with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                    with autocast(device_type=loss_device.split(":")[0], dtype=self.amp_dtype):
                         loss = mse_loss(  # pylint: disable=not-callable
                             output_q * tmp_attention_mask, current_output * tmp_attention_mask
                         )
@@ -2619,20 +2622,29 @@ def _quantize_block(
                         )
 
                 total_loss += loss.item() / num_elm
+
+                if self.low_gpu_mem_usage and card_0_in_high_risk:
+                    # clear memory to avoid OOM due to memory fragmentation
+                    clear_memory_if_reached_threshold(threshold=0.5)
+
                 self._scale_loss_and_backward(scaler, loss)
 
+                if self.low_gpu_mem_usage and card_0_in_high_risk:
+                    # clear memory to avoid OOM due to memory fragmentation
+                    clear_memory_if_reached_threshold(threshold=0.8)
+
             if i == 0:
                 init_loss = total_loss
 
             if total_loss < best_loss:
                 best_loss = total_loss
                 if not self.not_use_best_mse:
-                    best_params = collect_best_params(block)
+                    best_params = collect_best_params(block, self.cache_device)
                     # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True)
 
                     last_best_iter = i
             if self.not_use_best_mse and i == self.iters - 1:
-                best_params = collect_best_params(block)
+                best_params = collect_best_params(block, self.cache_device)
 
             if not self.not_use_best_mse:
                 if 0 < self.dynamic_max_gap <= i - last_best_iter:
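The two thresholds mirror where memory pressure peaks inside one tuning step: a lower bar (0.5) just before the backward pass allocates gradient buffers, and a higher bar (0.8) right after it. The implementation of clear_memory_if_reached_threshold is not shown in this diff; the sketch below is an assumption about its behaviour, gating the (slow) cache release on reserved memory relative to device capacity.

import gc

import torch


def clear_if_reached(threshold: float) -> None:
    # Only pay the cost of gc + empty_cache when the caching allocator has
    # reserved at least `threshold` of some device's total memory.
    if not torch.cuda.is_available():
        return
    for idx in range(torch.cuda.device_count()):
        reserved = torch.cuda.memory_reserved(idx)
        total = torch.cuda.get_device_properties(idx).total_memory
        if reserved / total >= threshold:
            gc.collect()
            torch.cuda.empty_cache()
            break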
@@ -2649,6 +2661,8 @@ def _quantize_block(
             f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
         )
         logger.info(dump_info)
+        if self.low_gpu_mem_usage:
+            clear_memory()  # clear cached memory during training
         if len(unquantized_layer_names) != 0:
             logger.info(f"{unquantized_layer_names} have not been quantized")
         with torch.no_grad():
@@ -2659,8 +2673,6 @@ def _quantize_block(
                 set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
 
         if self.enable_quanted_input:
-            if self.low_gpu_mem_usage:
-                clear_memory()
             q_outputs = self._get_block_outputs(
                 block,
                 input_ids,
@@ -2786,13 +2798,16 @@ def _quantize_blocks(
             modules = [get_module(model, n) for n in names]
             m = WrapperMultiblock(modules)
 
+            m.config = model.config if hasattr(model, "config") else None
             q_input, input_ids = quantize_block(
                 m,
                 input_ids,
                 input_others,
                 q_input=q_input,
                 device=device,
             )
+            if hasattr(model, "config"):
+                del m.config
             if self.is_packing_immediate:
                 from auto_round.export import PACKING_LAYER_WITH_FORMAT
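WrapperMultiblock groups several decoder layers into one temporary module, and the diff attaches model.config to it for the duration of quantize_block, then deletes it again. One plausible reason, in line with the "support MoE estimation" part of the commit title, is that code inspecting the wrapped block needs model-level settings such as the number of experts; this is an assumption, not something the diff states. Below is a hypothetical, self-contained illustration of the attach/use/detach pattern; WrapperMultiblockSketch, estimate_moe_experts, and the config fields are placeholders.

from types import SimpleNamespace

import torch


class WrapperMultiblockSketch(torch.nn.Module):
    # Simplified stand-in for WrapperMultiblock: runs a list of sub-blocks in sequence.
    def __init__(self, modules: list[torch.nn.Module]) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList(modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


def estimate_moe_experts(block: torch.nn.Module) -> int:
    # Placeholder consumer: reads a model-level setting from the attached config.
    cfg = getattr(block, "config", None)
    return getattr(cfg, "num_experts", 0) if cfg is not None else 0


model_config = SimpleNamespace(num_experts=8)  # placeholder for model.config
block = WrapperMultiblockSketch([torch.nn.Linear(16, 16), torch.nn.Linear(16, 16)])
block.config = model_config          # attach before tuning, as the diff does
print(estimate_moe_experts(block))   # 8
del block.config                     # detach afterwards so the wrapper stays lightweight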

auto_round/compressors/utils.py

Lines changed: 3 additions & 2 deletions
@@ -199,13 +199,14 @@ def check_awq_gemm_compatibility(model, bits, group_size, sym, layer_configs=Non
     return True, ""
 
 
-def collect_best_params(block):
+def collect_best_params(block, cache_device="cpu"):
+    """Collect the best parameters from the block to the specified device."""
     params = {}
     for n, m in block.named_modules():
         if hasattr(m, "orig_layer"):
             params[n] = {}
             for key in m.params.keys():
-                params[n][key] = copy.deepcopy(m.params[key].data)
+                params[n][key] = m.params[key].data.to(cache_device, copy=True)
     return params
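The switch from copy.deepcopy to Tensor.to(cache_device, copy=True) is what actually moves each snapshot off the accelerator: deepcopy reproduces the tensor on whatever device it already lives on, while .to(..., copy=True) materializes the copy directly on the cache device. A small illustration (it falls back to CPU when no CUDA device is available):

import copy

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
weight = torch.randn(4, 4, device=device)

# Old behaviour: the snapshot lives on the same device as the original tensor.
snapshot_deepcopy = copy.deepcopy(weight.data)
print(snapshot_deepcopy.device)  # same as `device`

# New behaviour: the snapshot is copied straight to the cache device (CPU),
# so it no longer counts against GPU memory while tuning continues.
snapshot_cached = weight.data.to("cpu", copy=True)
print(snapshot_cached.device)  # cpu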