Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llama3.2 vision autotp #6577

Merged
merged 7 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deepspeed/module_inject/auto_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,8 @@ def get_model_num_kv_heads(self, config):
num_kv_heads = None
# multi_query_group_num is for chatglm2 & chatglm3
kv_head_names = [
'multi_query_group_num', 'num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads'
'multi_query_group_num', 'num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads',
'attention_heads'
]
for name in kv_head_names:
if hasattr(config, name):
Expand Down
10 changes: 9 additions & 1 deletion deepspeed/module_inject/replace_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,13 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

# 3. Try to get num_key_heads from model_config.num_key_value_heads
num_kv_heads = _autotp.get_model_num_kv_heads(model_config)
if hasattr(model_config, "vision_config"):
if "MllamaVisionEncoderLayer" in str(module):
num_kv_heads = _autotp.get_model_num_kv_heads(model_config.vision_config)
else:
num_kv_heads = _autotp.get_model_num_kv_heads(model_config.text_config)
else:
num_kv_heads = _autotp.get_model_num_kv_heads(model_config)

# 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
set_num_kv_heads(num_kv_heads)
Expand Down Expand Up @@ -339,6 +345,8 @@ def set_lm_head(module):
"weight") and not module.embed_out.weight.is_meta and isinstance(
module.embed_out, torch.nn.Linear):
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
elif hasattr(module.language_model, "lm_head"):
module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
return module

def conv2d_parallel_shard_weights(model, rank, world_size):
Expand Down
Loading