Skip to content

Commit d8917f9

Browse files
committed
add warning for memory estimation
Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent 29df357 commit d8917f9

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

auto_round/utils/device.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ def get_moe_memory_ratio(block: torch.nn.Module) -> float:
746746
Returns:
747747
float: Memory ratio (num_experts_per_tok / num_experts).
748748
Returns 1.0 for non-MoE models.
749+
tuple[float, bool]: the memory ratio together with a flag that is True if the model is MoE, False otherwise.
749750
750751
Examples:
751752
- Non-MoE model: returns 1.0
@@ -795,10 +796,10 @@ def get_moe_memory_ratio(block: torch.nn.Module) -> float:
795796
f"activation memory ratio: {moe_ratio:.2f}"
796797
)
797798
logger.debug(f"Using MoE memory ratio: {moe_ratio:.4f}")
798-
return moe_ratio
799+
return moe_ratio, True
799800
break # Only check once per block
800801

801-
return 1.0 # Default ratio for non-MoE models
802+
return 1.0, False # Default ratio for non-MoE models
802803

803804

804805
def estimate_tuning_block_mem(
@@ -829,7 +830,7 @@ def estimate_tuning_block_mem(
829830
seq_len = input_ids[0].shape[1] if input_ids and len(input_ids[0].shape) >= 2 else 1
830831
element_size = input_ids[0].element_size() if input_ids else 2 # Default to 2 bytes (fp16/bf16)
831832

832-
moe_ratio = get_moe_memory_ratio(block) # Get MoE memory ratio (1.0 for non-MoE models)
833+
moe_ratio, has_moe = get_moe_memory_ratio(block) # Get MoE memory ratio (1.0 for non-MoE models)
833834

834835
for name, module in block.named_modules():
835836
if check_to_quantized(module):
@@ -856,7 +857,8 @@ def estimate_tuning_block_mem(
856857

857858
# memory * 2, because it contains grad tensor.
858859
# Check if this is a MoE expert layer by layer name (e.g., "mlp.experts.0.gate_proj")
859-
is_moe_expert = "expert" in layer_name.lower()
860+
parent_module = get_module(block, layer_name.rsplit(".", 1)[0]) if "." in layer_name else block
861+
is_moe_expert = "expert" in layer_name.lower() and isinstance(parent_module, torch.nn.ModuleList)
860862
layer_memory_dict[layer_name] = {
861863
"param_memory": param_memory_gb * 2,
862864
"output_memory": output_memory_gb * 2,
@@ -873,22 +875,28 @@ def estimate_tuning_block_mem(
873875
for layer_name, info in layer_memory_dict.items():
874876
if info.get("is_moe_expert", False):
875877
# MoE expert layer: only a fraction of experts are active
876-
# memory * 2 records intermediate activation memory of GeLU, or etc.
877-
layer_activation_memory += info["output_memory"] * moe_ratio * 2
878+
layer_activation_memory += info["output_memory"] * moe_ratio
878879
else:
879880
# Non-MoE layer: use full activation memory
880881
layer_activation_memory += info["output_memory"]
881882

882883
# layer_activation_memory considers other ops activation memory
883884
# 1GB considers norm weight, sdpa, reference_output, etc.
884885
additional_memory = layer_activation_memory + 1 # GB
886+
if has_moe:
887+
# TODO: Cannot estimate the memory usage correctly for MoE models yet.
888+
# For MoE models, additional memory usage can be higher due to routing, gating,
889+
# and multiple expert activations. Here we use a conservative estimate.
890+
moe_additional_memory = additional_memory * 3 # GB
891+
additional_memory += moe_additional_memory
885892
if torch.xpu.is_available():
886893
# https://github.com/intel/torch-xpu-ops/issues/2232
887894
# TODO: XPU takes more memory than expected. for llama 8B, it's about 12 GB
888895
xpu_additional_memory = 12 # GB
889896
additional_memory += xpu_additional_memory
890-
logger.warning_once("XPU additional memory usage of SDPA is estimated to be 12 GB.")
891-
logger.warning_once("Remove it after https://github.com/intel/torch-xpu-ops/issues/2232 is fixed.")
897+
logger.warning_once(
898+
"[Memory Estimation]: If there is an abnormal memory issue, please collect log with AR_LOG_LEVEL=debug and raise issue to us."
899+
)
892900

893901
return layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory
894902

@@ -947,20 +955,17 @@ def set_auto_device_map_for_block_with_tuning(
947955
if low_gpu_mem_usage:
948956
block_input_output_memory = 0
949957

950-
# Calculate total block memory from layer memory dict (including both param and output memory)
951958
total_block_param_memory = sum(info["param_memory"] for info in layer_memory_dict.values())
952-
total_block_output_memory = layer_activation_memory
953959

954960
# Average dispatch strategy
955961
# card_0_left_memory = card_0_mem - block_input_output_memory - additional_memory - layer_outputs_memory
956-
logger.debug("Card 0 used memory details [Estimated]:")
962+
card_0_used_memory = block_input_output_memory + layer_activation_memory + additional_memory
963+
logger.debug(f"Card 0 used memory details [Estimated]: {card_0_used_memory} GB")
957964
logger.debug(f" Block input output cache memory: {block_input_output_memory} GB")
958965
logger.debug(f" Quantized layer outputs memory: {layer_activation_memory} GB")
959966
logger.debug(f" Additional_memory from other ops: {additional_memory} GB")
960967

961-
card_0_left_memory = max(
962-
0, device_0_memory - block_input_output_memory - total_block_output_memory - additional_memory
963-
)
968+
card_0_left_memory = max(0, (device_0_memory - card_0_used_memory))
964969

965970
# Calculate total available memory across all devices
966971
total_available_memory = card_0_left_memory

0 commit comments

Comments
 (0)