@@ -746,6 +746,7 @@ def get_moe_memory_ratio(block: torch.nn.Module) -> float:
     Returns:
         float: Memory ratio (num_experts_per_tok / num_experts).
             Returns 1.0 for non-MoE models.
+        bool: True if the model is MoE, False otherwise.
 
     Examples:
         - Non-MoE model: returns 1.0
@@ -795,10 +796,10 @@ def get_moe_memory_ratio(block: torch.nn.Module) -> float:
                 f"activation memory ratio: {moe_ratio:.2f}"
             )
             logger.debug(f"Using MoE memory ratio: {moe_ratio:.4f}")
-            return moe_ratio
+            return moe_ratio, True
         break  # Only check once per block
 
-    return 1.0  # Default ratio for non-MoE models
+    return 1.0, False  # Default ratio for non-MoE models
 
 
 def estimate_tuning_block_mem(
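With this hunk, get_moe_memory_ratio returns a (ratio, is_moe) pair rather than a bare float. A minimal caller-side sketch, assuming illustrative Mixtral-style values (2 of 8 experts active, i.e. a ratio of 0.25) that are not taken from this diff:

    # Hypothetical usage sketch; `block` is any decoder block handed to the helper.
    moe_ratio, has_moe = get_moe_memory_ratio(block)
    if has_moe:
        # e.g. num_experts_per_tok=2, num_experts=8 -> moe_ratio == 0.25
        assert 0.0 < moe_ratio <= 1.0
    else:
        assert moe_ratio == 1.0  # dense models keep the full activation estimate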
@@ -829,7 +830,7 @@ def estimate_tuning_block_mem(
     seq_len = input_ids[0].shape[1] if input_ids and len(input_ids[0].shape) >= 2 else 1
     element_size = input_ids[0].element_size() if input_ids else 2  # Default to 2 bytes (fp16/bf16)
 
-    moe_ratio = get_moe_memory_ratio(block)  # Get MoE memory ratio (1.0 for non-MoE models)
+    moe_ratio, has_moe = get_moe_memory_ratio(block)  # Get MoE memory ratio (1.0 for non-MoE models)
 
     for name, module in block.named_modules():
         if check_to_quantized(module):
@@ -856,7 +857,8 @@ def estimate_tuning_block_mem(
 
             # memory * 2, because it contains grad tensor.
             # Check if this is a MoE expert layer by layer name (e.g., "mlp.experts.0.gate_proj")
-            is_moe_expert = "expert" in layer_name.lower()
+            parent_module = get_module(block, layer_name.rsplit(".", 1)[0]) if "." in layer_name else block
+            is_moe_expert = "expert" in layer_name.lower() and isinstance(parent_module, torch.nn.ModuleList)
             layer_memory_dict[layer_name] = {
                 "param_memory": param_memory_gb * 2,
                 "output_memory": output_memory_gb * 2,
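The name-only check could flag any layer whose name merely contains "expert" (for example a shared expert_gate projection); the stricter check also requires the layer's parent container to be a torch.nn.ModuleList. A self-contained sketch of the same heuristic, using torch.nn.Module.get_submodule as a stand-in for the project's get_module helper and a toy module layout that is purely illustrative:

    import torch

    toy_block = torch.nn.Module()
    toy_block.experts = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
    toy_block.expert_gate = torch.nn.Linear(8, 4)  # "expert" in the name, but not inside a ModuleList

    def looks_like_moe_expert(root: torch.nn.Module, layer_name: str) -> bool:
        # Resolve the parent module of the named layer, mirroring the diff's get_module call.
        parent = root.get_submodule(layer_name.rsplit(".", 1)[0]) if "." in layer_name else root
        return "expert" in layer_name.lower() and isinstance(parent, torch.nn.ModuleList)

    print(looks_like_moe_expert(toy_block, "experts.0"))    # True: parent "experts" is a ModuleList
    print(looks_like_moe_expert(toy_block, "expert_gate"))  # False: name matches, parent does not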
@@ -873,22 +875,28 @@ def estimate_tuning_block_mem(
     for layer_name, info in layer_memory_dict.items():
         if info.get("is_moe_expert", False):
             # MoE expert layer: only a fraction of experts are active
-            # memory * 2 records intermediate activation memory of GeLU, or etc.
-            layer_activation_memory += info["output_memory"] * moe_ratio * 2
+            layer_activation_memory += info["output_memory"] * moe_ratio
         else:
             # Non-MoE layer: use full activation memory
             layer_activation_memory += info["output_memory"]
 
     # layer_activation_memory considers other ops activation memory
     # 1GB considers norm weight, sdpa, reference_output, etc.
     additional_memory = layer_activation_memory + 1  # GB
+    if has_moe:
+        # TODO: Cannot estimate the memory usage correctly for MoE models yet.
+        # For MoE models, additional memory usage can be higher due to routing, gating,
+        # and multiple expert activations. Here we use a conservative estimate.
+        moe_additional_memory = additional_memory * 3  # GB
+        additional_memory += moe_additional_memory
     if torch.xpu.is_available():
         # https://github.com/intel/torch-xpu-ops/issues/2232
         # TODO: XPU takes more memory than expected. for llama 8B, it's about 12 GB
         xpu_additional_memory = 12  # GB
         additional_memory += xpu_additional_memory
-        logger.warning_once("XPU additional memory usage of SDPA is estimated to be 12 GB.")
-        logger.warning_once("Remove it after https://github.com/intel/torch-xpu-ops/issues/2232 is fixed.")
+        logger.warning_once(
+            "[Memory Estimation]: If there is an abnormal memory issue, please collect log with AR_LOG_LEVEL=debug and raise issue to us."
+        )
 
     return layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory
 
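For reference, the new MoE branch effectively quadruples the baseline estimate (additional_memory += additional_memory * 3), and the XPU branch then adds a flat 12 GB on top. A worked example under assumed inputs (the 2 GB activation figure is illustrative, not a value from this diff):

    layer_activation_memory = 2.0       # GB, assumed for illustration
    has_moe, on_xpu = True, False

    additional_memory = layer_activation_memory + 1.0   # 3 GB: norm weights, SDPA, reference output, ...
    if has_moe:
        additional_memory += additional_memory * 3      # conservative MoE headroom -> 12 GB
    if on_xpu:
        additional_memory += 12                         # flat XPU SDPA overhead (torch-xpu-ops issue 2232)
    print(additional_memory)                            # 12.0 GB for this MoE, non-XPU example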
@@ -947,20 +955,17 @@ def set_auto_device_map_for_block_with_tuning(
     if low_gpu_mem_usage:
         block_input_output_memory = 0
 
-    # Calculate total block memory from layer memory dict (including both param and output memory)
     total_block_param_memory = sum(info["param_memory"] for info in layer_memory_dict.values())
-    total_block_output_memory = layer_activation_memory
 
     # Average dispatch strategy
     # card_0_left_memory = card_0_mem - block_input_output_memory - additional_memory - layer_outputs_memory
-    logger.debug("Card 0 used memory details [Estimated]:")
+    card_0_used_memory = block_input_output_memory + layer_activation_memory + additional_memory
+    logger.debug(f"Card 0 used memory details [Estimated]: {card_0_used_memory} GB")
     logger.debug(f" Block input output cache memory: {block_input_output_memory} GB")
     logger.debug(f" Quantized layer outputs memory: {layer_activation_memory} GB")
     logger.debug(f" Additional_memory from other ops: {additional_memory} GB")
 
-    card_0_left_memory = max(
-        0, device_0_memory - block_input_output_memory - total_block_output_memory - additional_memory
-    )
+    card_0_left_memory = max(0, (device_0_memory - card_0_used_memory))
 
     # Calculate total available memory across all devices
     total_available_memory = card_0_left_memory
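The rewritten budget folds the three estimated components into a single card_0_used_memory term before subtracting it from device 0's capacity, and the max(0, ...) clamp keeps the planner from treating an over-budget card 0 as negative headroom. A short numeric sketch with assumed values (all GB figures are illustrative):

    device_0_memory = 80.0              # GB available on card 0, assumed
    block_input_output_memory = 6.0     # GB, assumed
    layer_activation_memory = 12.0      # GB, assumed
    additional_memory = 13.0            # GB, assumed

    card_0_used_memory = block_input_output_memory + layer_activation_memory + additional_memory  # 31 GB
    card_0_left_memory = max(0, device_0_memory - card_0_used_memory)                             # 49 GB
    # card_0_left_memory is the headroom left on card 0 for dispatching quantized layers.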