From e97b453645a03bc6901d74ec13be1c5d7f1a1fec Mon Sep 17 00:00:00 2001
From: Yejing-Lai
Date: Wed, 9 Oct 2024 02:16:04 +0800
Subject: [PATCH] Add llama3.2 vision autotp (#6577)

Llama3.2-11b and Llama3.2-90b consist of a vision model and a text model.
The two sub-models have different num_kv_heads, so num_kv_heads must be
set dynamically per sub-model.

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/module_inject/auto_tp.py        |  3 ++-
 deepspeed/module_inject/replace_module.py | 10 +++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index e6eea2183de5..52d7c95ec9d8 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -495,7 +495,8 @@ def get_model_num_kv_heads(self, config):
         num_kv_heads = None
         # multi_query_group_num is for chatglm2 & chatglm3
         kv_head_names = [
-            'multi_query_group_num', 'num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads'
+            'multi_query_group_num', 'num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads',
+            'attention_heads'
         ]
         for name in kv_head_names:
             if hasattr(config, name):
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 64dc5479940c..cf70c4530c82 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -274,7 +274,13 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

         # 3. Try to get num_key_heads from model_config.num_key_value_heads
-        num_kv_heads = _autotp.get_model_num_kv_heads(model_config)
+        if hasattr(model_config, "vision_config"):
+            if "MllamaVisionEncoderLayer" in str(module):
+                num_kv_heads = _autotp.get_model_num_kv_heads(model_config.vision_config)
+            else:
+                num_kv_heads = _autotp.get_model_num_kv_heads(model_config.text_config)
+        else:
+            num_kv_heads = _autotp.get_model_num_kv_heads(model_config)

         # 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
         set_num_kv_heads(num_kv_heads)
@@ -339,6 +345,8 @@ def set_lm_head(module):
                 "weight") and not module.embed_out.weight.is_meta and isinstance(
                     module.embed_out, torch.nn.Linear):
             module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
+        elif hasattr(module.language_model, "lm_head"):
+            module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
         return module

     def conv2d_parallel_shard_weights(model, rank, world_size):
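
Note (not part of the patch): below is a minimal, self-contained sketch of the
dispatch the hunks above implement, assuming a Mllama-style config object that
carries separate vision_config / text_config sub-configs. The helper names
get_num_kv_heads and pick_kv_heads and the SimpleNamespace configs are
illustrative stand-ins, not DeepSpeed or transformers API.

    from types import SimpleNamespace

    def get_num_kv_heads(config):
        # Same probing strategy as AutoTP.get_model_num_kv_heads: return the
        # first kv-head attribute that exists on the config.
        for name in ('multi_query_group_num', 'num_kv_heads', 'num_key_value_heads',
                     'num_attention_heads', 'n_heads', 'attention_heads'):
            if hasattr(config, name):
                return getattr(config, name)
        return None

    def pick_kv_heads(model_config, module):
        # Composite vision+text configs need a per-submodule lookup: vision
        # encoder layers read vision_config, everything else reads text_config.
        if hasattr(model_config, "vision_config"):
            if "MllamaVisionEncoderLayer" in str(module):
                return get_num_kv_heads(model_config.vision_config)
            return get_num_kv_heads(model_config.text_config)
        return get_num_kv_heads(model_config)

    # Toy config where the two sub-models disagree on kv heads.
    cfg = SimpleNamespace(vision_config=SimpleNamespace(attention_heads=16),
                          text_config=SimpleNamespace(num_key_value_heads=8))
    print(pick_kv_heads(cfg, "MllamaVisionEncoderLayer(...)"))  # 16
    print(pick_kv_heads(cfg, "LlamaDecoderLayer(...)"))         # 8

The 'attention_heads' entry is appended to kv_head_names because the Mllama
vision sub-config exposes its head count under that attribute name, which the
previous list did not cover.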