
[distributed] Add Llama3-70B for distributed inference #1335

Merged: 2 commits, merged Oct 30, 2024

dist_run.py: 19 changes (13 additions & 6 deletions)

@@ -55,9 +55,12 @@
 # Using model name to identify the model to load, for example "llama2-7b-chat".
 # You can change it to other values listed below.
 # For details on the name-to-distribution mapping, see README.md or models.json.
+
+# Name : HF distribution name, dtype, and model dimension
 NAME_TO_DISTRIBUTION_AND_DTYPE = {
-    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16, 4096),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
+    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
 }
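For reference, a minimal standalone illustration of how the new three-element entries unpack (the dict literal here simply mirrors the "llama3-70b" entry added above):

```python
import torch

# Mirrors the new "llama3-70b" entry from the mapping above, shown standalone.
NAME_TO_DISTRIBUTION_AND_DTYPE = {
    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
}

distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE["llama3-70b"]
print(distribution, model_dtype, model_dimension)
# meta-llama/Meta-Llama-3-70B-Instruct torch.bfloat16 8192
```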


@@ -314,8 +317,12 @@ def main(args):
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
     logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
 
-    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
+    distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE[
+        model_name
+    ]
+    logger.info(
+        f"Using model weights from {distribution}, dtype {model_dtype} and model dimension {model_dimension}"
+    )
 
     # Model-level config
     model_config = ModelArgs.from_name(distribution)
@@ -338,6 +345,7 @@ def main(args):
 
     # Tensor parallel is enabled in this program
     tp_degree = world_size // pp_degree
+    logger.info(f"Using TP degree {tp_degree} and PP degree {pp_degree}")
 
     # Create device mesh
     mesh_dimensions = (pp_degree, tp_degree)
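The newly logged TP and PP degrees determine the shape of the device mesh created just below. As a rough sketch of that step (an assumption about the setup, not code taken from this PR; it presumes the script is launched under torchrun so the process-group environment variables are set):

```python
# Sketch (assumption): building a 2-D (pp, tp) device mesh from the degrees above.
from torch.distributed.device_mesh import init_device_mesh

world_size = 8                       # e.g. 8 GPUs in total
pp_degree = 2                        # pipeline-parallel stages
tp_degree = world_size // pp_degree  # leaves 4-way tensor parallelism
mesh_dimensions = (pp_degree, tp_degree)

device_mesh = init_device_mesh("cuda", mesh_dimensions, mesh_dim_names=("pp", "tp"))
```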
@@ -388,7 +396,6 @@
     # sense. Thus it is interchangeable with micro-batch size below.
     batch_size = len(prompt)
     seqlen_prefill = 1024  # sequence length
-    dim = 4096  # embedding dimension
 
     # Setup KV caches (after model distribution)
     # The number of cache lanes is the same as the maximum number of
@@ -419,7 +426,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
             0, config.vocab_size, (batch_size, seqlen), device=device
         )
        activation = torch.rand(
-            batch_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, model_dimension, device=device, dtype=model_dtype
         )
         logits = torch.rand(
             batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
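This is the reason the hard-coded dim = 4096 above was removed: the example activation used for pipeline-stage shape inference must match the model's embedding dimension, which is 8192 for Llama3-70B. A quick back-of-the-envelope check of what that example tensor costs (illustrative sizes, runnable on CPU):

```python
import torch

batch_size, seqlen_prefill = 1, 1024
model_dimension = 8192  # Llama3-70B, per the mapping added in this PR

activation = torch.rand(batch_size, seqlen_prefill, model_dimension, dtype=torch.bfloat16)
size_mib = activation.numel() * activation.element_size() / 2**20
print(f"{size_mib:.1f} MiB")  # 1 * 1024 * 8192 * 2 bytes = 16.0 MiB
```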
torchchat/distributed/force_download.py: 4 changes (2 additions & 2 deletions)

@@ -1,5 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
-model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct")
 print("Model weights and tokenizer downloaded")
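A practical note that is an assumption rather than part of this diff: the Meta-Llama-3-70B-Instruct repository is gated on Hugging Face, so force_download.py only succeeds after authenticating, for example:

```python
# Sketch (assumption): authenticate with Hugging Face before running force_download.py.
# The token value is a placeholder; running `huggingface-cli login` in a shell works as well.
from huggingface_hub import login

login(token="hf_xxx")
```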