Skip to content

Commit

Permalink
add llama vision support
Browse files Browse the repository at this point in the history
  • Loading branch information
Yejing-Lai committed Sep 27, 2024
1 parent 2a56f53 commit d4e6eb8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,8 @@ def get_model_num_kv_heads(self, config):
num_kv_heads = None
# multi_query_group_num is for chatglm2 & chatglm3
kv_head_names = [
'multi_query_group_num', 'num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads'
'multi_query_group_num', 'num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads',
'attention_heads'
]
for name in kv_head_names:
if hasattr(config, name):
Expand Down
10 changes: 9 additions & 1 deletion deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,13 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

# 3. Try to get num_key_heads from model_config.num_key_value_heads
num_kv_heads = _autotp.get_model_num_kv_heads(model_config)
if hasattr(model_config, "vision_config"):
if "MllamaVisionEncoderLayer" in str(module):
num_kv_heads = _autotp.get_model_num_kv_heads(model_config.vision_config)
else:
num_kv_heads = _autotp.get_model_num_kv_heads(model_config.text_config)
else:
num_kv_heads = _autotp.get_model_num_kv_heads(model_config)

# 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
set_num_kv_heads(num_kv_heads)
Expand Down Expand Up @@ -339,6 +345,8 @@ def set_lm_head(module):
"weight") and not module.embed_out.weight.is_meta and isinstance(
module.embed_out, torch.nn.Linear):
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
elif hasattr(module.language_model, "lm_head"):
module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
return module

def conv2d_parallel_shard_weights(model, rank, world_size):
Expand Down

0 comments on commit d4e6eb8

Please sign in to comment.