merge upstream updates and adapt to the MPT fix in torch_spmd (#5)
* Set `record_shapes=True` for profiler

ghstack-source-id: 6f1ed49d15ce311f1bf118820965cdb5309a8030
Pull Request resolved: pytorch#419

* Improved `repeat_kv` eager perf

ghstack-source-id: 39e484954814e61cdfb2ba661f0a98c83bc0ce60
Pull Request resolved: pytorch#418

* Adding FSDP Memory Tracking and Estimation

ghstack-source-id: c8ed20fc585957bd164dd963307616a53991615d
Pull Request resolved: pytorch#425

* Adding integration test for FSDP Memory Tracking and Estimation

ghstack-source-id: cc224db8951ec7a133fd769845a4765cbedc6454
Pull Request resolved: pytorch#426

* by default disable heavy memory profiling

ghstack-source-id: cad7b3c41fd60ec19c0e6e7d058e8aa00602a187
Pull Request resolved: pytorch#430

* Add the option to turn on async-TP

ghstack-source-id: 0a03379eeb3a63b2d1ad4dff84d0e61ca82b1bbf
Pull Request resolved: pytorch#429

* Modifying memory estimation options and minor changes

ghstack-source-id: 5f09824cddaed6585cc094095e1e95dd070d76f4
Pull Request resolved: pytorch#435

* add comment pointing to Sequence Parallel optimization example

ghstack-source-id: 6fa0dcd4bca876e10a6a8349283fb940a59ad234
Pull Request resolved: pytorch#438

* switch float8 logic from Float8DynamicLinear to Float8Linear (pytorch#436)

Summary:

After pytorch-labs/float8_experimental#300,
`Float8Linear` with default settings is equivalent to
`Float8DynamicLinear`. This PR changes `torchtitan` to use
`Float8Linear`.

To support the new UX of `float8_experimental` better, I also switched
the `fp8_linear` configuration to be a boolean on whether to swap the
linears or not. In the future we can add new options on how to configure
each linear (scaling type, scaling granularity, etc) - saving that for a
future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fsdp_fp8_allgather=True
// 4. compile, float8, fsdp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:

* Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`

ghstack-source-id: 50b2d0c2b4c22e2f045cafd8630c16f3a8c6d35f
Pull Request resolved: pytorch#444

* Reordered TP parallel plan to follow execution order

ghstack-source-id: b4924952adeb5f16d08b60faa54690762841c422
Pull Request resolved: pytorch#445

* Made some stylistic changes to `apply_dp`

ghstack-source-id: fb78e9eb8aa406ba87d6ad6cf2229c1027dae42f
Pull Request resolved: pytorch#446

* Refactored activation checkpointing

ghstack-source-id: 785c7e47651cda97ea22d0147d14b8d061ce042d
Pull Request resolved: pytorch#447

* compiled RMSNorm

ghstack-source-id: c4efb81ec6acc5442955908cc376df3e6d889af3
Pull Request resolved: pytorch#442

* Renamed parallel styles for transformer block weights

ghstack-source-id: 5fb0bf3d08cacf27242ec0f85d5dd3cdc03b739e
Pull Request resolved: pytorch#448

* Added type annotations and more stylistic changes

ghstack-source-id: 1bd5b9d5abc8644785132f8eb2baaf8b1cfc5fb5
Pull Request resolved: pytorch#449

---------

Co-authored-by: Andrew Gu <andgu@fb.com>
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@meta.com>
Co-authored-by: Yifu Wang <yifu@fb.com>
Co-authored-by: Vasiliy Kuznetsov <vkuzo@users.noreply.github.com>
5 people authored Jul 11, 2024
1 parent b1f40ad commit d8859ab
Showing 16 changed files with 467 additions and 188 deletions.
1 change: 1 addition & 0 deletions .ci/docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -5,3 +5,4 @@ tensorboard
 sentencepiece
 tiktoken
 blobfile
+tabulate
220 changes: 220 additions & 0 deletions estimation.py
@@ -0,0 +1,220 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import gc
import os

import torch
import torch.nn.functional as F
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed import destroy_process_group
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.tensor.parallel import loss_parallel
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import create_tokenizer
from torchtitan.float8_linear import build_fp8_linear
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from train import build_optimizers


def estimate_memory(job_config: JobConfig):
    init_logger()
    logger.info("Estimating memory usage...")
    gc.disable()
    gc.collect(1)

    # Get the world size
    world_size = int(os.environ["WORLD_SIZE"])

    # exit early if tp > 1 or pp > 1
    if (
        job_config.training.tensor_parallel_degree > 1
        or job_config.experimental.pipeline_parallel_degree > 1
    ):
        logger.info(
            "Tensor parallelism and pipeline parallelism are not supported yet."
        )
        return

    # fake tensor doesn't work with fused rmsnorm
    if (
        job_config.model.norm_type == "fused_rmsnorm"
        and not job_config.memory_estimation.disable_fake_mode
    ):
        logger.info(
            "Fused RMSNorm is not supported yet under fake estimation mode. "
            "Switching to rmsnorm."
        )
        job_config.model.norm_type = "rmsnorm"

    if job_config.model.norm_type == "compiled_rmsnorm":
        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
        job_config.model.norm_type = "rmsnorm"

    if job_config.training.compile:
        logger.info("Compile mode is not supported yet. Switching to eager mode.")
        job_config.training.compile = False

    parallel_dims = ParallelDims(
        dp=job_config.training.data_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
    )

    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    torch.cuda.set_device(device)

    # init fake pg
    store = FakeStore()
    torch.distributed.init_process_group(
        "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
    )

    # build meshes
    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    if not parallel_dims.dp_enabled:
        logger.info("Data parallelism is not enabled. Skipping memory estimation.")
        return

    model_name = job_config.model.name

    # build tokenizer
    tokenizer_type = model_name_to_tokenizer[model_name]
    tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

    # loss_parallel enables dispatching to efficient loss operators
    loss_parallel_ctx = (
        loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext
    )

    # loss fn can be shared by pipeline-parallel or non-pp execution
    def loss_fn(pred, labels):
        return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

    # build model (using meta init)
    model_cls = model_name_to_cls[model_name]
    model_config = models_config[model_name][job_config.model.flavor]
    # set the model configs from training inputs:
    # 1. norm type to decide which norm layer to use
    # 2. vocab size from tokenizer
    # 3. max_seq_len based on inputs
    model_config.norm_type = job_config.model.norm_type
    model_config.vocab_size = tokenizer.n_words
    model_config.max_seq_len = job_config.training.seq_len

    with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():

        logger.info(
            f"Building {model_name} {job_config.model.flavor} with {model_config}"
        )
        with torch.device("meta"):
            whole_model = model_cls.from_model_args(model_config)

        # apply fp8 linear module swap
        if job_config.training.fp8_linear:
            build_fp8_linear(whole_model, job_config)

        # apply PT-D DP/TP parallelisms and activation checkpointing
        model_parts = [whole_model]
        model_parts = [
            models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
            for m in model_parts
        ]

        init_device = "cuda"
        for model in model_parts:
            model.to_empty(device=init_device)

        if not active_fake_mode():
            whole_model.init_weights()

        # build optimizer after applying parallelisms to the model
        optimizers = build_optimizers(model_parts, job_config)
        lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config)

        for model in model_parts:
            model.train()
        logger.info(f"Vocab size: {model_config.vocab_size}")
        # Create a dummy batch instead of loading from a dataset
        batch = (
            torch.randint(
                0,
                model_config.vocab_size,
                (job_config.training.batch_size, model_config.max_seq_len),
                device="cuda",
            ),
            torch.randint(
                0,
                model_config.vocab_size,
                (job_config.training.batch_size, model_config.max_seq_len),
                device="cuda",
            ),
        )
        fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
        fsdp_memtracker.track_inputs(batch)

        with fsdp_memtracker:
            for iter_idx in range(2):
                input_ids, labels = batch
                # train step
                with loss_parallel_ctx():
                    pred = whole_model(input_ids)
                    loss = loss_fn(pred, labels)
                    del pred
                    loss.backward()

                # clip gradients
                for model in model_parts:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), job_config.training.max_norm, foreach=True
                    )
                # optimizer step
                optimizers.step()
                lr_schedulers.step()
                optimizers.zero_grad()
                print(f"Peak Memory at iter: {iter_idx}")
                fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
                if iter_idx == 0:
                    fsdp_memtracker.reset_mod_stats()  # iter 0 does not have optimizer state
                gc.collect(1)

        fsdp_memtracker.display_modulewise_snapshots(
            depth=3, units="MiB", tabulate=True
        )
        mem_stats = torch.cuda.memory_stats()
        peak_active = mem_stats["active_bytes.all.peak"]
        peak_reserved = mem_stats["reserved_bytes.all.peak"]
        num_retries = mem_stats["num_alloc_retries"]
        dev = torch.device(torch.cuda.current_device())
        tracker_peak = fsdp_memtracker.get_tracker_snapshot("peak")[dev]["Total"]
        gib = 1024**3
        print(
            f"peak active: {peak_active / gib} GiB | peak reserved:"
            f" {peak_reserved / gib} GiB | num_retries: {num_retries}"
        )
        print(f"Tracker Max: {tracker_peak / gib} GiB")
        if job_config.memory_estimation.disable_fake_mode and peak_active > 0:
            print(f"Tracker Accuracy: {tracker_peak/peak_active}")
        gc.enable()


if __name__ == "__main__":
    config = JobConfig()
    config.parse_args()
    try:
        estimate_memory(config)
    finally:
        destroy_process_group()
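
The reporting step at the end of `estimate_memory` is plain arithmetic on byte counts, so it can be illustrated without a GPU. The following is a minimal sketch, not part of the commit: `summarize_memory` and the sample numbers are hypothetical, and the dictionary only mimics the three `torch.cuda.memory_stats()` keys that the script reads.

```python
GIB = 1024**3


def summarize_memory(mem_stats: dict, tracker_peak: int, real_run: bool) -> list:
    """Format the same three report lines that estimation.py prints.

    mem_stats    -- dict with the keys read from torch.cuda.memory_stats()
    tracker_peak -- peak bytes reported by the memory tracker
    real_run     -- True when fake mode is disabled, so CUDA stats are meaningful
    """
    peak_active = mem_stats["active_bytes.all.peak"]
    peak_reserved = mem_stats["reserved_bytes.all.peak"]
    num_retries = mem_stats["num_alloc_retries"]
    lines = [
        f"peak active: {peak_active / GIB} GiB | peak reserved:"
        f" {peak_reserved / GIB} GiB | num_retries: {num_retries}",
        f"Tracker Max: {tracker_peak / GIB} GiB",
    ]
    # The ratio is only reported when real allocations happened; under
    # FakeTensorMode the CUDA allocator sees nothing and peak_active is 0.
    if real_run and peak_active > 0:
        lines.append(f"Tracker Accuracy: {tracker_peak / peak_active}")
    return lines


# Made-up numbers, chosen only to keep the arithmetic easy to follow.
stats = {
    "active_bytes.all.peak": 8 * GIB,
    "reserved_bytes.all.peak": 10 * GIB,
    "num_alloc_retries": 0,
}
report = summarize_memory(stats, tracker_peak=6 * GIB, real_run=True)
for line in report:
    print(line)
```

With these inputs the tracker's peak is three quarters of the allocator's peak; in fake mode only the first two lines would be produced.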
21 changes: 16 additions & 5 deletions run_llama_train.sh
@@ -10,13 +10,14 @@ set -ex
 # libUV is a scalable backend for TCPStore which is used in processGroup
 # rendezvous. This is the recommended backend for distributed training.
 export USE_LIBUV=1
-TRAINER_DIR=${1:-/home/$USER/local/torchtitan}
+TRAINER_DIR=${TRAINER_DIR:-/home/$USER/local/torchtitan}

 # use envs as local overrides for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh

 NGPU=${NGPU:-"8"}
+NNODES=${NNODES:-"1"}

 # by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}
@@ -29,7 +30,17 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi

-# TORCH_TRACE="outputs/compile_trace" \
-torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
-    --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-    train.py --job.config_file ${CONFIG_FILE} $overrides
+# Check if --memory_estimation.enabled is in the arguments
+if echo "$overrides" | grep -q -- "--memory_estimation.enabled"; then
+    # Calculate WORLD_SIZE as the product of NGPU and NNODES
+    # Export WORLD_SIZE and LOCAL_RANK
+    export WORLD_SIZE=$((NGPU * NNODES))
+    export LOCAL_RANK=0
+    python estimation.py --job.config_file ${CONFIG_FILE} $overrides
+else
+    # Call train.py if not in estimation mode
+    # TORCH_TRACE="outputs/compile_trace" \
+    torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
+        --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+        train.py --job.config_file ${CONFIG_FILE} $overrides
+fi
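
The branch added to `run_llama_train.sh` can be mirrored in Python to make the dispatch rule explicit. This is an illustrative sketch only: `build_command` and the `cfg.toml` file name are hypothetical, and the function returns the environment and argument list rather than launching anything.

```python
import shlex


def build_command(overrides: str, config_file: str, ngpu: int = 8,
                  nnodes: int = 1, log_rank: str = "0"):
    """Return (extra_env, argv) for either estimation.py or train.py."""
    common = ["--job.config_file", config_file, *shlex.split(overrides)]
    # Same substring test as the `grep -q` in the script.
    if "--memory_estimation.enabled" in overrides:
        # A single process stands in for the whole job, so the world size
        # is supplied through the environment instead of by torchrun.
        env = {"WORLD_SIZE": str(ngpu * nnodes), "LOCAL_RANK": "0"}
        return env, ["python", "estimation.py", *common]
    argv = [
        "torchrun", f"--nproc_per_node={ngpu}",
        "--rdzv_backend", "c10d", "--rdzv_endpoint=localhost:0",
        "--local-ranks-filter", log_rank, "--role", "rank", "--tee", "3",
        "train.py", *common,
    ]
    return {}, argv


est_env, est_argv = build_command(
    "--memory_estimation.enabled --model.norm_type rmsnorm",
    "cfg.toml", ngpu=4, nnodes=2,
)
train_env, train_argv = build_command("", "cfg.toml", ngpu=4)
print(est_env, est_argv)
print(train_env, train_argv)
```

Because the check is a substring match, any override string that contains the flag routes to `estimation.py`, exactly as the `grep` does.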
10 changes: 10 additions & 0 deletions test_runner.py
@@ -263,6 +263,16 @@ def build_test_list():
             ],
             "Fused Optimizer Test",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--memory_estimation.enabled --model.norm_type rmsnorm",
+                ]
+            ],
+            "FSDP2 Memory Tracking and Estimation",
+            "fsdp2_mem_tracker",
+            ngpu=4,
+        ),
     ]
     return integration_tests_flavors
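
To show how an entry like the one added above could turn into a shell invocation, here is a rough sketch. Everything in it is an assumption made for illustration: the `OverrideDefinitions` dataclass is a stand-in with guessed field names (the diff shows only positional arguments and `ngpu`), and `build_commands` and `debug.toml` are hypothetical. Only the `CONFIG_FILE` and `NGPU` environment variables and the script name come from `run_llama_train.sh`.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class OverrideDefinitions:
    """Hypothetical stand-in; field names are guesses, not the real class."""
    override_args: Sequence[Sequence[str]]
    test_descr: str = "default"
    test_name: str = "default"
    ngpu: int = 4


def build_commands(test: OverrideDefinitions, config_file: str) -> list:
    """Build one run_llama_train.sh command line per group of overrides."""
    cmds = []
    for group in test.override_args:
        cmd = f"CONFIG_FILE={config_file} NGPU={test.ngpu} ./run_llama_train.sh"
        if group:
            cmd += " " + " ".join(group)
        cmds.append(cmd)
    return cmds


mem_test = OverrideDefinitions(
    [["--memory_estimation.enabled --model.norm_type rmsnorm"]],
    "FSDP2 Memory Tracking and Estimation",
    "fsdp2_mem_tracker",
    ngpu=4,
)
commands = build_commands(mem_test, "debug.toml")
print(commands[0])
```

Since the override string contains `--memory_estimation.enabled`, a command built this way would take the estimation branch of the launcher script.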

41 changes: 29 additions & 12 deletions torchtitan/config_manager.py
@@ -165,7 +165,7 @@ def __init__(self):
             "--model.norm_type",
             type=str,
             default="rmsnorm",
-            help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
+            help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, compiled_rmsnorm, fused_rmsnorm]",
         )
         self.parser.add_argument(
             "--model.tokenizer_path",
@@ -249,6 +249,12 @@ def __init__(self):
             action="store_true",
             help="Whether to use the experimental torch_spmd style parallelism",
         )
+        self.parser.add_argument(
+            "--experimental.enable_async_tensor_parallel",
+            default=False,
+            action="store_true",
+            help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
+        )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_degree",
             type=int,
@@ -341,15 +347,11 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--training.fp8_linear",
-            type=str,
-            default="",
-            choices=[
-                "dynamic",
-                "",
-            ],  # TODO: add "delayed" option back in when supported
+            action="store_true",
             help="""
-                Type of fp8 linear quantization to apply to the model ['', 'dynamic'].
-                This features requires you to install 'float8_experimental' which can be found
+                If true, swaps `torch.nn.Linear` with `Float8Linear` with
+                default settings (dynamic scaling).
+                This feature requires you to install 'float8_experimental' which can be found
                 here: https://github.com/pytorch-labs/float8_experimental
             """,
         )
@@ -488,6 +490,20 @@ def __init__(self):
             help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
         )

+        # memory estimation settings
+        self.parser.add_argument(
+            "--memory_estimation.enabled",
+            help="Whether to estimate memory usage for FSDP",
+            action="store_true",
+        )
+
+        self.parser.add_argument(
+            "--memory_estimation.disable_fake_mode",
+            help="Whether to estimate memory under FakeTensorMode",
+            default=False,
+            action="store_true",
+        )
+
     def parse_args(self, args_list: list = sys.argv[1:]):
         args, cmd_args = self.parse_args_from_command_line(args_list)
         config_file = getattr(args, "job.config_file", None)
@@ -524,10 +540,11 @@ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
             args_dict[first_level_key][second_level_key] = v
         return args_dict

-    def _validate_config(self) -> bool:
+    def _validate_config(self) -> None:
         # TODO: Add more mandatory validations
-        assert self.model.name and self.model.flavor and self.model.tokenizer_path
-        return True
+        assert self.model.name
+        assert self.model.flavor
+        assert self.model.tokenizer_path

     def parse_args_from_command_line(
         self, args_list
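
The new flags follow the existing `--section.name` convention, where `argparse` keeps the dot in the destination name and the values are later regrouped by section. Below is a self-contained sketch of that pattern using only the standard library. It is not the `JobConfig` implementation: `to_two_level_dict` is a simplified stand-in for `_args_to_two_level_dict`, and only three of the flags are declared.

```python
import argparse
from collections import defaultdict

parser = argparse.ArgumentParser()
# After this commit fp8_linear is a plain on/off switch.
parser.add_argument("--training.fp8_linear", action="store_true")
parser.add_argument("--memory_estimation.enabled", action="store_true")
parser.add_argument(
    "--memory_estimation.disable_fake_mode", default=False, action="store_true"
)


def to_two_level_dict(args: argparse.Namespace) -> dict:
    """Group 'section.name' destinations into {section: {name: value}}."""
    out = defaultdict(dict)
    for key, value in vars(args).items():
        section, name = key.split(".", 1)
        out[section][name] = value
    return out


args = parser.parse_args(["--memory_estimation.enabled"])
# Dotted destinations are not valid attribute names, hence getattr().
enabled = getattr(args, "memory_estimation.enabled")
config = to_two_level_dict(args)
print(enabled, dict(config))
```

A `store_true` flag defaults to `False` when omitted, which is why switching `--training.fp8_linear` from a string choice to a boolean needs no explicit `default`.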