@@ -1441,7 +1441,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
 
             if is_complex_device_mapping(self.device_map):
                 set_auto_device_map_for_block_with_tuning(
-                    block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size
+                    block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device
                 )
             # Dispatch model if needed
             if is_complex_device_mapping(self.device_map):
@@ -2332,10 +2332,10 @@ def _quantize_layer(
             if total_loss < best_loss:
                 best_loss = total_loss
                 if not self.not_use_best_mse:
-                    best_params = collect_best_params(wrapper_linear)
+                    best_params = collect_best_params(wrapper_linear, self.cache_device)
                     last_best_iter = i
             if self.not_use_best_mse and i == self.iters - 1:
-                best_params = collect_best_params(wrapper_linear)
+                best_params = collect_best_params(wrapper_linear, self.cache_device)
 
             if not self.not_use_best_mse:
                 if 0 < self.dynamic_max_gap <= i - last_best_iter:
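# A hedged sketch (not auto_round's implementation) of what collect_best_params(module, cache_device)
# plausibly does after this change: snapshot each wrapped layer's tunable tensors onto cache_device
# (usually "cpu") so the best-so-far copy no longer occupies GPU memory. The `params` attribute and
# the dict layout below are illustrative assumptions.
import torch


def collect_best_params_sketch(module: torch.nn.Module, cache_device: str = "cpu") -> dict:
    best_params = {}
    for name, sub in module.named_modules():
        params = getattr(sub, "params", None)  # assumed dict of tunable tensors, e.g. value/min_scale/max_scale
        if isinstance(params, dict):
            # detach + clone so later optimizer updates cannot mutate the snapshot,
            # then move it off the training device
            best_params[name] = {k: v.detach().clone().to(cache_device) for k, v in params.items()}
    return best_params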
@@ -2413,6 +2413,7 @@ def _get_current_q_output(
         input_others: dict,
         indices: list[int],
         device: str,
+        cache_device: str = "cpu",
     ) -> torch.Tensor:
         current_input_ids, current_input_others = self._sampling_inputs(
             input_ids,
@@ -2423,7 +2424,7 @@ def _get_current_q_output(
             share_cache_keys=self.shared_cache_keys,
         )
         output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
-        return output_q
+        return output_q.to(cache_device)
 
     def _get_current_num_elm(
         self,
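# A hedged sketch of the pattern enabled by `return output_q.to(cache_device)`: the forward pass runs
# on the busy compute device while the result is handed to a second, less loaded device before the MSE
# loss is computed (the caller passes loss_device as cache_device). forward_fn and the device names are
# placeholders, not the library's API.
import torch
import torch.nn.functional as F


def loss_on_secondary_device(
    forward_fn, batch: torch.Tensor, reference: torch.Tensor, device: str = "cuda:0", loss_device: str = "cpu"
) -> torch.Tensor:
    output_q = forward_fn(batch.to(device))  # heavy compute stays on `device`
    output_q = output_q.to(loss_device)      # mirrors output_q.to(cache_device) above
    return F.mse_loss(output_q, reference.to(loss_device))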
@@ -2458,13 +2459,15 @@ def _quantize_block(
             if is_fp8_linear(m):
                 new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device).to(device)
                 set_module(block, n, new_layer)
-
-        if is_complex_device_mapping(self.device_map):
-            set_auto_device_map_for_block_with_tuning(
+        # card_0_in_high_risk indicates that card_0 is already at high memory usage (~90%) before any weights are loaded
+        # loss_device is used to compute the loss on a second device when one is available and card_0_in_high_risk is set
+        if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map):
+            card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning(
                 block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
             )
         else:
             block = block.to(device)
+            card_0_in_high_risk, loss_device = False, device
 
         if is_complex_device_mapping(self.device_map):
             for n, m in block.named_modules():
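# A hedged sketch of how a (card_0_in_high_risk, loss_device) pair could be derived; the real logic
# lives inside set_auto_device_map_for_block_with_tuning and may differ. The 0.9 threshold mirrors
# the "90%" figure in the comment above; the fallback to "cuda:1" is an assumption.
import torch


def pick_loss_device_sketch(device: str = "cuda:0", threshold: float = 0.9) -> tuple[bool, str]:
    if not torch.cuda.is_available():
        return False, device
    free, total = torch.cuda.mem_get_info(torch.device(device))
    card_0_in_high_risk = (total - free) / total >= threshold
    # compute the loss on a second card only when one exists and card 0 is already under pressure
    if card_0_in_high_risk and torch.cuda.device_count() > 1:
        return True, "cuda:1"
    return card_0_in_high_risk, device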
@@ -2594,21 +2597,21 @@ def _quantize_block(
 
                 current_output = self._get_current_output(output, indices)
 
-                current_output = to_device(current_output, device)
+                current_output = to_device(current_output, loss_device)
 
-                output_q = self._get_current_q_output(block, input_ids, input_others, indices, device)
+                output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)
 
                 if self.attention_mask:
                     tmp_attention_mask = [self.attention_mask[i] for i in indices]
-                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(loss_device)
                     tmp_attention_mask.unsqueeze_(-1)
                     num_elm = torch.sum(tmp_attention_mask).item()
                     if num_elm == 0:
                         num_elm = 1
                 else:
                     tmp_attention_mask = 1.0
                 if self.amp:
-                    with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                    with autocast(device_type=loss_device.split(":")[0], dtype=self.amp_dtype):
                         loss = mse_loss(  # pylint: disable=not-callable
                             output_q * tmp_attention_mask, current_output * tmp_attention_mask
                         )
@@ -2619,20 +2622,29 @@ def _quantize_block(
                     )
 
                 total_loss += loss.item() / num_elm
+
+                if self.low_gpu_mem_usage and card_0_in_high_risk:
+                    # clear memory to avoid OOM due to memory fragmentation
+                    clear_memory_if_reached_threshold(threshold=0.5)
+
                 self._scale_loss_and_backward(scaler, loss)
 
+            if self.low_gpu_mem_usage and card_0_in_high_risk:
+                # clear memory to avoid OOM due to memory fragmentation
+                clear_memory_if_reached_threshold(threshold=0.8)
+
             if i == 0:
                 init_loss = total_loss
 
             if total_loss < best_loss:
                 best_loss = total_loss
                 if not self.not_use_best_mse:
-                    best_params = collect_best_params(block)
+                    best_params = collect_best_params(block, self.cache_device)
                     # print(f"get better result at iter {i}, the loss is {total_loss}", flush=True)
 
                     last_best_iter = i
             if self.not_use_best_mse and i == self.iters - 1:
-                best_params = collect_best_params(block)
+                best_params = collect_best_params(block, self.cache_device)
 
             if not self.not_use_best_mse:
                 if 0 < self.dynamic_max_gap <= i - last_best_iter:
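# A hedged sketch of what clear_memory_if_reached_threshold(threshold=...) plausibly does: only pay
# for gc + empty_cache when a CUDA device is close to full, which matches the cheaper 0.5 check
# before the backward pass and the 0.8 check after it. Illustrative only, not the library's code.
import gc

import torch


def clear_memory_if_reached_threshold_sketch(threshold: float) -> None:
    if not torch.cuda.is_available():
        return
    for idx in range(torch.cuda.device_count()):
        free, total = torch.cuda.mem_get_info(idx)
        if (total - free) / total >= threshold:
            gc.collect()
            torch.cuda.empty_cache()  # release cached blocks to reduce fragmentation
            break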
@@ -2649,6 +2661,8 @@ def _quantize_block(
             f"layers in the block, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
         )
         logger.info(dump_info)
+        if self.low_gpu_mem_usage:
+            clear_memory()  # clear cached memory during training
         if len(unquantized_layer_names) != 0:
             logger.info(f"{unquantized_layer_names} have not been quantized")
         with torch.no_grad():
@@ -2659,8 +2673,6 @@ def _quantize_block(
                 set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
 
         if self.enable_quanted_input:
-            if self.low_gpu_mem_usage:
-                clear_memory()
             q_outputs = self._get_block_outputs(
                 block,
                 input_ids,
@@ -2786,13 +2798,16 @@ def _quantize_blocks(
             modules = [get_module(model, n) for n in names]
             m = WrapperMultiblock(modules)
 
+            m.config = model.config if hasattr(model, "config") else None
             q_input, input_ids = quantize_block(
                 m,
                 input_ids,
                 input_others,
                 q_input=q_input,
                 device=device,
             )
+            if hasattr(model, "config"):
+                del m.config
             if self.is_packing_immediate:
                 from auto_round.export import PACKING_LAYER_WITH_FORMAT
 
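# A hedged sketch of the temporary-attribute pattern used above: expose the parent model's HF-style
# config on a wrapper module only while a callee that expects `module.config` runs, then remove it so
# the wrapper does not keep a dangling reference. The context manager below is a stand-in for the
# inline assignment/del around quantize_block, not part of auto_round.
from contextlib import contextmanager

import torch


@contextmanager
def borrowed_config(wrapper: torch.nn.Module, model: torch.nn.Module):
    wrapper.config = model.config if hasattr(model, "config") else None
    try:
        yield wrapper
    finally:
        if hasattr(model, "config"):
            del wrapper.config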