forked from pytorch/torchtitan
merge upstream updates and adapt to the MPT fix in torch_spmd (#5)
* Set `record_shapes=True` for profiler
  ghstack-source-id: 6f1ed49d15ce311f1bf118820965cdb5309a8030
  Pull Request resolved: pytorch#419
* Improved `repeat_kv` eager perf
  ghstack-source-id: 39e484954814e61cdfb2ba661f0a98c83bc0ce60
  Pull Request resolved: pytorch#418
* Adding FSDP Memory Tracking and Estimation
  ghstack-source-id: c8ed20fc585957bd164dd963307616a53991615d
  Pull Request resolved: pytorch#425
* Adding integration test for FSDP Memory Tracking and Estimation
  ghstack-source-id: cc224db8951ec7a133fd769845a4765cbedc6454
  Pull Request resolved: pytorch#426
* by default disable heavy memory profiling
  ghstack-source-id: cad7b3c41fd60ec19c0e6e7d058e8aa00602a187
  Pull Request resolved: pytorch#430
* Add the option to turn on async-TP
  ghstack-source-id: 0a03379eeb3a63b2d1ad4dff84d0e61ca82b1bbf
  Pull Request resolved: pytorch#429
* Modifying memory estimation options and minor changes
  ghstack-source-id: 5f09824cddaed6585cc094095e1e95dd070d76f4
  Pull Request resolved: pytorch#435
* add comment pointing to Sequence Parallel optimization example
  ghstack-source-id: 6fa0dcd4bca876e10a6a8349283fb940a59ad234
  Pull Request resolved: pytorch#438
* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

  Summary: After pytorch-labs/float8_experimental#300, `Float8Linear` with default settings is equivalent to `Float8DynamicLinear`. This PR changes `torchtitan` to use `Float8Linear`. To support the new UX of `float8_experimental` better, I also switched the `fp8_linear` configuration to be a boolean on whether to swap the linears or not. In the future we can add new options on how to configure each linear (scaling type, scaling granularity, etc.), saving that for a future PR.

  Test Plan:
  ```
  // run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
  // verify performance and loss values do not change meaningfully between
  // baseline and this PR

  // baseline (before this PR)
  // 1. compile, bf16
  // 2. compile, float8
  // 3. compile, float8, fdsp_fp8_allgather=True
  // 4. compile, float8, fdsp_fp8_allgather=True, tp=2
  // logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

  // experiment (this PR): repeat all of the above, but with Float8Linear
  // logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
  ```

  Reviewers:
  Subscribers:
  Tasks:
  Tags:
* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`
  ghstack-source-id: 50b2d0c2b4c22e2f045cafd8630c16f3a8c6d35f
  Pull Request resolved: pytorch#444
* Reordered TP parallel plan to follow execution order
  ghstack-source-id: b4924952adeb5f16d08b60faa54690762841c422
  Pull Request resolved: pytorch#445
* Made some stylistic changes to `apply_dp`
  ghstack-source-id: fb78e9eb8aa406ba87d6ad6cf2229c1027dae42f
  Pull Request resolved: pytorch#446
* Refactored activation checkpointing
  ghstack-source-id: 785c7e47651cda97ea22d0147d14b8d061ce042d
  Pull Request resolved: pytorch#447
* compiled RMSNorm
  ghstack-source-id: c4efb81ec6acc5442955908cc376df3e6d889af3
  Pull Request resolved: pytorch#442
* Renamed parallel styles for transformer block weights
  ghstack-source-id: 5fb0bf3d08cacf27242ec0f85d5dd3cdc03b739e
  Pull Request resolved: pytorch#448
* Added type annotations and more stylistic changes
  ghstack-source-id: 1bd5b9d5abc8644785132f8eb2baaf8b1cfc5fb5
  Pull Request resolved: pytorch#449

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
Commit d8859ab · 1 parent b1f40ad · 16 changed files with 467 additions and 188 deletions.
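The Float8Linear change described in the commit message turns `fp8_linear` into a simple on/off switch. Below is a minimal, hedged sketch of how that boolean is consumed, mirroring the `build_fp8_linear(whole_model, job_config)` call in the estimation script further down; the toy model and the exact config spelling are assumptions, not part of this commit.

```python
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.float8_linear import build_fp8_linear

# Parse the usual torchtitan config; an entry such as training.fp8_linear = true
# (spelling assumed here) enables the swap.
job_config = JobConfig()
job_config.parse_args()

# Stand-in model: any module whose nn.Linear children should be swapped.
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))

if job_config.training.fp8_linear:
    # Swaps eligible nn.Linear modules for Float8Linear in place,
    # per the PR summary above.
    build_fp8_linear(model, job_config)
```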
The dependency list gains `tabulate`, which the memory tracker's tabulated output below relies on:

@@ -5,3 +5,4 @@ tensorboard
 sentencepiece
 tiktoken
 blobfile
+tabulate
A new 220-line memory estimation script is added:

@@ -0,0 +1,220 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import gc
import os

import torch
import torch.nn.functional as F
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed import destroy_process_group
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.tensor.parallel import loss_parallel
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import create_tokenizer
from torchtitan.float8_linear import build_fp8_linear
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from train import build_optimizers


def estimate_memory(job_config: JobConfig):
    init_logger()
    logger.info("Estimating memory usage...")
    gc.disable()
    gc.collect(1)

    # Get the world size
    world_size = int(os.environ["WORLD_SIZE"])

    # if tp > 1 or pp > 1, we exit
    if (
        job_config.training.tensor_parallel_degree > 1
        or job_config.experimental.pipeline_parallel_degree > 1
    ):
        logger.info(
            "Tensor parallelism and pipeline parallelism are not supported yet."
        )
        return

    # fake tensor doesn't work with fused rmsnorm
    if (
        job_config.model.norm_type == "fused_rmsnorm"
        and not job_config.memory_estimation.disable_fake_mode
    ):
        logger.info(
            "Fused RMSNorm is not supported yet under fake estimation mode. "
            "Switching to rmsnorm."
        )
        job_config.model.norm_type = "rmsnorm"

    if job_config.model.norm_type == "compiled_rmsnorm":
        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
        job_config.model.norm_type = "rmsnorm"

    if job_config.training.compile:
        logger.info("Compile mode is not supported yet. Switching to eager mode.")
        job_config.training.compile = False

    parallel_dims = ParallelDims(
        dp=job_config.training.data_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
    )

    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)

    # init fake pg
    store = FakeStore()
    torch.distributed.init_process_group(
        "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
    )
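    # The "fake" backend plus FakeStore lets this single process stand in for any
    # WORLD_SIZE: collectives are not actually executed, so no multi-GPU job is
    # needed to estimate memory for a larger cluster.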
    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    if not parallel_dims.dp_enabled:
        logger.info("Data parallelism is not enabled. Skipping memory estimation.")
        return

    model_name = job_config.model.name

    # build tokenizer
    tokenizer_type = model_name_to_tokenizer[model_name]
    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

    # loss_parallel enables dispatching to efficient loss operators
    loss_parallel_ctx = (
        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
    )

    # loss fn can be shared by pipeline-parallel or non-pp execution
    def loss_fn(pred, labels):
        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

    # build model (using meta init)
    model_cls = model_name_to_cls[model_name]
    model_config = models_config[model_name][job_config.model.flavor]
    # set the model configs from training inputs:
    # 1. norm type to decide which norm layer to use
    # 2. vocab size from tokenizer
    # 3. max_seq_len based on inputs
    model_config.norm_type = job_config.model.norm_type
    model_config.vocab_size = tokenizer.n_words
    model_config.max_seq_len = job_config.training.seq_len

    with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():
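        # Under FakeTensorMode, tensors carry shape/dtype metadata but no real
        # storage, so the model below can be materialized and stepped through
        # without allocating its actual memory.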
        logger.info(
            f"Building {model_name} {job_config.model.flavor} with {model_config}"
        )
        with torch.device("meta"):
            whole_model = model_cls.from_model_args(model_config)

        # apply fp8 linear module swap
        if job_config.training.fp8_linear:
            build_fp8_linear(whole_model, job_config)

        # apply PT-D DP/TP parallelisms and activation checkpointing
        model_parts = [whole_model]
        model_parts = [
            models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
            for m in model_parts
        ]

        init_device = "cuda"
        for model in model_parts:
            model.to_empty(device=init_device)

        if not active_fake_mode():
            whole_model.init_weights()

        # build optimizer after applying parallelisms to the model
        optimizers = build_optimizers(model_parts, job_config)
        lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)

        for model in model_parts:
            model.train()
        logger.info(f"Vocab size: {model_config.vocab_size}")
        # Create a dummy batch instead of loading from a dataset
        batch = (
            torch.randint(
                0,
                model_config.vocab_size,
                (job_config.training.batch_size, model_config.max_seq_len),
                device="cuda",
            ),
            torch.randint(
                0,
                model_config.vocab_size,
                (job_config.training.batch_size, model_config.max_seq_len),
                device="cuda",
            ),
        )
        fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
        fsdp_memtracker.track_inputs(batch)
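        # FSDPMemTracker attributes parameter, gradient, optimizer-state, and
        # activation memory to modules as the two iterations below run;
        # track_inputs registers the dummy batch so its memory is counted too.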
        with fsdp_memtracker:
            for iter_idx in range(2):
                input_ids, labels = batch
                # train step
                with loss_parallel_ctx():
                    pred = whole_model(input_ids)
                    loss = loss_fn(pred, labels)
                    del pred
                    loss.backward()

                # clip gradients
                for model in model_parts:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), job_config.training.max_norm, foreach=True
                    )
                # optimizer step
                optimizers.step()
                lr_schedulers.step()
                optimizers.zero_grad()
                print(f"Peak Memory at iter: {iter_idx}")
                fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
                if iter_idx == 0:
                    fsdp_memtracker.reset_mod_stats()  # iter 0 does not have optimizer state
                gc.collect(1)

        fsdp_memtracker.display_modulewise_snapshots(
            depth=3, units="MiB", tabulate=True
        )
        mem_stats = torch.cuda.memory_stats()
        peak_active = mem_stats["active_bytes.all.peak"]
        peak_reserved = mem_stats["reserved_bytes.all.peak"]
        num_retries = mem_stats["num_alloc_retries"]
        dev = torch.device(torch.cuda.current_device())
        tracker_peak = fsdp_memtracker.get_tracker_snapshot("peak")[dev]["Total"]
        gib = 1024**3
        print(
            f"peak active: {peak_active / gib} GiB | peak reserved:"
            f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
        )
        print(f"Tracker Max: {tracker_peak / gib} GiB")
        if job_config.memory_estimation.disable_fake_mode and peak_active > 0:
            print(f"Tracker Accuracy: {tracker_peak/peak_active}")
        gc.enable()


if __name__ == "__main__":
    config = JobConfig()
    config.parse_args()
    try:
        estimate_memory(config)
    finally:
        destroy_process_group()
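For orientation, a hedged sketch of how the entry point above might be driven from a single process. The module name `estimation` and the config path are placeholders, not taken from this commit, and because the process group is fake, `WORLD_SIZE` only describes the cluster being modeled.

```python
import os

# Environment that torchrun would normally provide; with the fake backend a
# single local process (and one visible GPU) can model any world size.
os.environ["WORLD_SIZE"] = "64"
os.environ["LOCAL_RANK"] = "0"

from torch.distributed import destroy_process_group

from estimation import estimate_memory  # hypothetical module name for the file above
from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args()  # e.g. --job.config_file ./train_configs/llama3_8b.toml
try:
    estimate_memory(config)
finally:
    destroy_process_group()
```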