
Commit 9045f08

Add revised benchmarking logic and results (#9)
* Revised estimation of batch count, retrieving it directly from len(train_dataloader). Deleted the unused timer_handle argument in Trainer. Revised handling of the "max_seq_len" override in benchmarking. Added support for automatic switching between the LoRA and full-rank sharding schemes in benchmarking.
* Revised handling of an unspecified max_seq_length. Added Llama-3 to the benchmark model_list.
* Benchmarking: revised the benchmark script to ensure a consistent per-device train batch size.
* Benchmarking: replaced trainer.step with trainer.train_step to avoid eval overhead in benchmarking. Revised the benchmark parsing logic; display the optimal batch size for each context-width value.
* Benchmarking: updated reference throughput based on the updated logic.
* Benchmarking: updated the reference throughput descriptions.
1 parent ce1eaa3 commit 9045f08

9 files changed: +138 −81 lines

docs/reference_throughput.md

+29 −26
@@ -1,33 +1,36 @@
  # Reference Throughput

  We've benchmarked VectorLM on the Vaughan cluster for a number of model architectures across a variety of node configurations.
- In experiments labelled as LoRA, we set hidden dimension to 8. During the testing, the NVIDIA driver version was 525.105.17, CUDA Runtime 12.1.105, and torch 2.2.2.
+ In experiments labelled as LoRA, we set hidden dimension to 8. Below are version numbers of the testing environment:

- For consistency, we use a batch size of 8 and the maximum context length that the pre-trained LLM supports, capped at 65536. Note that especially for smaller models, it might be possible to further increase throughput by switching to a larger batch size.
+ ```bash
+ $ pip3 freeze|grep -E "(torch|flash-attn|nvidia)"
+ flash-attn==2.5.8
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-ml-py==12.550.52
+ nvidia-nccl-cu12==2.19.3
+ nvidia-nvjitlink-cu12==12.3.101
+ nvidia-nvtx-cu12==12.1.105
+ torch==2.2.1
+ ```

- Entries that read NaN represent combinations where the node configuration does not have enough GPU memory for the training run to complete. An exception is gemma-2b, which currently does not support full-rank FSDP fine-tuning.
+ For each context width and hardware configuration, we experiment with a per-device batch size of 2, 4, and 8. In the table below, we report the batch size that maximizes training throughput. All values in the table represent the median training throughput in tokens/second across all training steps, aggregated across all GPU devices.

- All values in the table below represent the median training throughput in tokens per second across all training steps, aggregated across all GPU devices.
+ |                                      | Meta-Llama-3-8B (2048) | Meta-Llama-3-8B (4096) | Meta-Llama-3-8B (8192) |
+ | :----------------------------------- | :--------------------- | :--------------------- | :--------------------- |
+ | (full_rank) NVIDIA A100-SXM4-80GB x1 | 3550.48 (batch: 8) | 3461.64 (batch: 4) | 3204.21 (batch: 2) |
+ | (full_rank) NVIDIA A100-SXM4-80GB x2 | 6346.00 (batch: 8) | 6182.59 (batch: 4) | 5772.91 (batch: 2) |
+ | (full_rank) NVIDIA A100-SXM4-80GB x4 | 12688.44 (batch: 8) | 12249.74 (batch: 4) | 11463.46 (batch: 2) |
+ | (lora) NVIDIA A100-SXM4-80GB x1 | 4079.28 (batch: 8) | 3682.15 (batch: 4) | 3528.93 (batch: 2) |
+ | (lora) NVIDIA A100-SXM4-80GB x2 | 7182.97 (batch: 8) | 6955.58 (batch: 4) | 6452.96 (batch: 2) |
+ | (lora) NVIDIA A100-SXM4-80GB x4 | 14299.47 (batch: 8) | 13834.43 (batch: 4) | 12769.23 (batch: 2) |

- | | Llama-2-13b-hf | Llama-2-7b-hf | Mistral-7B-v0.1 | Mixtral-8x7B-Instruct-v0.1 | gemma-2b | opt-350m |
- | :----------------------------------- | -------------: | ------------: | --------------: | -------------------------: | -------: | -------: |
- | (full_rank) NVIDIA A100-SXM4-80GB x1 | 424.726 | 570.818 | 528.747 | nan | nan | 780.045 |
- | (full_rank) NVIDIA A100-SXM4-80GB x2 | 660.355 | 919.19 | 794.566 | 275.459 | nan | 1227.67 |
- | (full_rank) NVIDIA A100-SXM4-80GB x4 | 1309.4 | 1744.39 | 1577.09 | 817.162 | nan | 2181.46 |
- | (full_rank) NVIDIA A40 x1 | nan | 47.6435 | 107.503 | nan | nan | 666.881 |
- | (full_rank) NVIDIA A40 x2 | nan | 313.074 | 322.624 | nan | nan | 854.672 |
- | (full_rank) NVIDIA A40 x4 | 345.96 | 570.977 | 553.658 | nan | nan | 1765.49 |
- | (full_rank) Tesla T4 x1 | nan | nan | nan | nan | nan | 475.51 |
- | (full_rank) Tesla T4 x2 | nan | nan | nan | nan | nan | 768.008 |
- | (full_rank) Tesla T4 x4 | nan | nan | nan | nan | nan | 1383.6 |
- | (full_rank) Tesla T4 x8 | nan | nan | nan | nan | nan | 2414.68 |
- | (lora) NVIDIA A100-SXM4-80GB x1 | 560.167 | 646.801 | 525.802 | nan | 851.678 | 859.379 |
- | (lora) NVIDIA A100-SXM4-80GB x2 | 871.993 | 1157.17 | 1105.68 | 239.431 | 1724.57 | 1463.82 |
- | (lora) NVIDIA A100-SXM4-80GB x4 | 1783.53 | 2091.03 | 2150.06 | 1309.74 | 2719.24 | 2381.01 |
- | (lora) NVIDIA A40 x1 | 272.931 | 435.386 | 336.507 | nan | 983.256 | 652.611 |
- | (lora) NVIDIA A40 x2 | 105.442 | 457.183 | 356.263 | nan | 725.723 | 1136.17 |
- | (lora) NVIDIA A40 x4 | 543.22 | 715.416 | 642.642 | nan | 1302.62 | 1647.57 |
- | (lora) Tesla T4 x1 | nan | nan | nan | nan | 148.272 | 571.471 |
- | (lora) Tesla T4 x2 | nan | 101.126 | 102.859 | nan | 256.534 | 811.159 |
- | (lora) Tesla T4 x4 | nan | 188.575 | 190.127 | nan | 495.755 | 1506.05 |
- | (lora) Tesla T4 x8 | 196.709 | 372.375 | 351.361 | nan | 897.81 | 2945.86 |
+ We provide tools for evaluating throughput on different context windows and hardware/model configurations. Refer to the profiling folder in this repository to get started.
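As an illustration of how each table entry is obtained, the sketch below computes the median per-step throughput in tokens/second for each run and keeps the per-device batch size whose run is fastest. This is a minimal sketch, not the repository's parsing code; the helper names and the sample numbers are hypothetical.

```python
# Hypothetical sketch: each step record is (tokens processed across all devices, seconds).
from statistics import median


def median_throughput(steps: list[tuple[int, float]]) -> float:
    """Median tokens/second across all training steps of one run."""
    return median(tokens / seconds for tokens, seconds in steps)


def best_batch_size(runs: dict[int, list[tuple[int, float]]]) -> tuple[int, float]:
    """Pick the per-device batch size whose run has the highest median throughput."""
    results = {bs: median_throughput(steps) for bs, steps in runs.items()}
    best = max(results, key=results.get)
    return best, results[best]


if __name__ == "__main__":
    # Illustrative numbers only, not measurements from the table above.
    runs = {
        2: [(4096, 1.40), (4096, 1.38), (4096, 1.41)],
        4: [(8192, 2.60), (8192, 2.55), (8192, 2.62)],
        8: [(16384, 4.70), (16384, 4.65), (16384, 4.72)],
    }
    print(best_batch_size(runs))
```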

profiling/README.md

+1 −1
@@ -13,7 +13,7 @@ $ python3 launch_benchmark.py
  # to accept and automatically invoke the commands.
  ```

- After the SLURM jobs complete, profiler output can be found under `data/benchmark`. Invoke the following the to generate a Markdown summary of the results:
+ After the SLURM jobs complete, profiler output can be found under `data/benchmark`. Invoke the following to generate a Markdown summary of the results. If the benchmark results include multiple batch sizes for each (model, context window, hardware) pair, the table will list the "optimal" batch size associated with the highest training throughput for that combination.

  ```bash
  $ python3 profiling/parse_benchmark.py --folder data/benchmark
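The "optimal" batch size selection described above can be thought of as grouping results by (model, context window, hardware) and keeping the fastest run per group. The sketch below illustrates the idea; the record layout and the numbers are assumptions, not the actual `parse_benchmark.py` implementation.

```python
# Hypothetical records: (model, context_window, hardware, per_device_batch_size, tokens_per_second).
records = [
    ("Meta-Llama-3-8B", 2048, "NVIDIA A100-SXM4-80GB x1", 4, 3400.0),  # illustrative numbers
    ("Meta-Llama-3-8B", 2048, "NVIDIA A100-SXM4-80GB x1", 8, 3500.0),
]

best: dict[tuple[str, int, str], tuple[int, float]] = {}
for model, ctx, hw, batch_size, throughput in records:
    key = (model, ctx, hw)
    # Keep only the highest-throughput batch size per (model, context, hardware) pair.
    if key not in best or throughput > best[key][1]:
        best[key] = (batch_size, throughput)

for (model, ctx, hw), (batch_size, throughput) in best.items():
    print(f"{hw} | {model} ({ctx}): {throughput:.2f} (batch: {batch_size})")
```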

profiling/benchmark.py

+35 −25
@@ -25,7 +25,6 @@
  from vectorlm.utils.model_utils import (
      get_lora_model_from_base_model,
      get_submodule_by_pattern,
-     hook_activation_checkpointing,
      load_model_and_tokenizer,
      shard_model,
  )
@@ -67,7 +66,7 @@ def parse_args() -> Namespace:
          default=1000,
      )
      parser.add_argument("--max_length", type=int)
-     parser.add_argument("--training_batch_size", type=int)
+     parser.add_argument("--per_device_batch_size", type=int)
      return parser.parse_args()


@@ -273,9 +272,26 @@ def load_datasets(self) -> None:

  setup(config.train_parameters.output_dir)

- if args.training_batch_size is not None:
-     config.dataset.train_bs = args.training_batch_size
-     write_metrics("training_batch_size", args.training_batch_size)
+ training_args = config.train_parameters
+
+ # set a seed
+ set_seed(training_args.seed)
+
+ # set CUDA related dependencies
+ local_rank = int(os.environ["LOCAL_RANK"])
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+
+ if args.per_device_batch_size is not None:
+     config.dataset.train_bs = args.per_device_batch_size
+     config.dataset.eval_bs = args.per_device_batch_size
+
+ write_metrics("training_batch_size", config.dataset.train_bs)
+ write_metrics("eval_batch_size", config.dataset.eval_bs)
+ write_metrics(
+     "training_batch_size_global",
+     config.dataset.train_bs * world_size,
+ )

  print(f"Writing metrics to {output_path}")
  write_metrics("model_name", args.model_name)
@@ -291,16 +307,6 @@ def load_datasets(self) -> None:
      repeat=2,
  )

- training_args = config.train_parameters
-
- # set a seed
- set_seed(training_args.seed)
-
- # set CUDA related dependencies
- local_rank = int(os.environ["LOCAL_RANK"])
- rank = int(os.environ["RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
-
  with track_time("dist_init"):
      print(f"Rank: {rank}, World size: {world_size}")
      if dist.is_initialized():
@@ -314,17 +320,18 @@ def load_datasets(self) -> None:

  # load model and tokenizer
  lora_peft_config = config.train_parameters.get("lora_peft_config")
+ is_lora_enabled = lora_peft_config is not None

  with track_time("model_load"):
      model, tokenizer = load_model_and_tokenizer(
          args.model_name,
          training_args.use_mp,
          get_is_flash_attention_supported(),
-         training_args.max_seq_len,
+         args.max_length,
          local_rank,
          training_args.low_cpu_mem_usage,
      )
-     if lora_peft_config is not None:
+     if is_lora_enabled:
          print("Enabling LoRA Wrapper.")
          write_metrics("peft_method", "lora")
          model = get_lora_model_from_base_model(model, lora_peft_config)
@@ -348,12 +355,9 @@ def load_datasets(self) -> None:
          training_args.sharding_strategy,
          local_rank,
          training_args.low_cpu_mem_usage,
+         is_lora_enabled=is_lora_enabled,
      )

- with track_time("set_activation_checkpointing"):
-     if training_args.use_activation_checkpointing:
-         hook_activation_checkpointing(model, decoder_layer_module)
-
  # load dataset
  with track_time("dataset_load"):
      dataset = BenchmarkingDataset(
@@ -364,14 +368,17 @@ def load_datasets(self) -> None:
          max_length=args.max_length,
      )

+ print(
+     f"Sequence length: {dataset.max_length};"
+     f"Batch Size (per device): {config.dataset.train_bs}",
+ )
  write_metrics("max_length", dataset.max_length)

  # instantiate trainer
  trainer = Trainer(
      config=training_args,
      enable_wandb_logging=config.enable_wandb_logging,
      original_dataset_length=dataset.original_length,
-     timer_handle=track_time,
  )

  # load optimizer
@@ -412,15 +419,18 @@ def load_datasets(self) -> None:
  trainer.model.train()
  train_dl_iterator = iter(dataset.train_dataloader)
  for _ in tqdm(
-     range(args.num_train_examples),
+     range(len(dataset.train_dataloader)),
      disable=rank != 0,
      file=sys.__stdout__,
  ):
      batch = next(train_dl_iterator)
      num_tokens = len(batch["input_ids"].flatten())

-     with track_time("train_step", {"num_tokens": num_tokens}):
-         trainer.step(batch, epoch)
+     with track_time(
+         "train_step",
+         {"num_tokens": num_tokens * world_size},
+     ):
+         trainer.train_step(batch, epoch)

      profile_handle.step()
  write_metrics(
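In the revised loop, each `trainer.train_step` call is timed and tagged with the global token count (`num_tokens * world_size`), so throughput can later be computed as tokens divided by elapsed time. The actual `track_time` helper is imported from elsewhere in the repository and is not part of this diff; a minimal context-manager sketch of the same shape might look like the following (the output path and record fields are assumptions).

```python
import json
import time
from contextlib import contextmanager
from typing import Any


@contextmanager
def track_time(name: str, extra: dict[str, Any] | None = None):
    """Record the wall-clock duration of a named block, plus optional metadata."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        record = {"name": name, "seconds": elapsed, **(extra or {})}
        # Append one JSON line per measurement; the path here is illustrative.
        with open("benchmark_timings.jsonl", "a") as f:
            f.write(json.dumps(record) + "\n")


# Usage mirroring the benchmark loop: tokens across all devices for one step.
with track_time("train_step", {"num_tokens": 8192 * 4}):
    time.sleep(0.01)  # stand-in for trainer.train_step(batch, epoch)
```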

profiling/configs/benchmark.yaml

−1
@@ -6,7 +6,6 @@ wandb_config:

  train_parameters:
    output_dir: /dev/shm/lora-benchmark
-   max_seq_len: 128
    epochs: 1
    seed: 11

profiling/configs/lora-benchmark.yaml

−1
@@ -6,7 +6,6 @@ wandb_config:

  train_parameters:
    output_dir: /dev/shm/lora-benchmark
-   max_seq_len: 128
    epochs: 1
    seed: 11

profiling/launch_benchmark.py

+13 −11
@@ -22,12 +22,13 @@
  model_list = [
      "/model-weights/" + model_name
      for model_name in [
-         "opt-350m",
-         "gemma-2b",
-         "Llama-2-7b-hf",
-         "Llama-2-13b-hf",
-         "Mistral-7B-v0.1",
-         "Mixtral-8x7B-Instruct-v0.1",
+         # "opt-350m",
+         # "gemma-2b",
+         # "Llama-2-7b-hf",
+         "Meta-Llama-3-8B",
+         # "Llama-2-13b-hf",
+         # "Mistral-7B-v0.1",
+         # "Mixtral-8x7B-Instruct-v0.1",
      ]
  ]

@@ -37,27 +38,28 @@
  ]

  # Set to (-1) to fall back to the max context length of the pre-trained model.
- max_length_list = [1024, 2048, 4096, -1]
- batch_size = [8, 16, 32, 64, 128]
+ max_length_list = [8192, 4096, 2048]
+ # Per-device batch size for training
+ per_device_batch_size = [2, 4, 8]

  slurm_flags_options = {
      "nodes": [1],
      "mem-per-gpu": ["16GB"],
      "ntasks-per-node": [1],
      "cpus-per-gpu": [3],
-     "gres": [f"gpu:{n}" for n in [1, 2, 4, 8]],
+     "gres": [f"gpu:{n}" for n in [4, 2, 1]],
      "partition": partitions,
  }

- num_repeats = 2
+ num_repeats = 1
  slurm_flags_extra = {"time": "01:00:00", "qos": qos_selected}

  slurm_pos_args_options = [
      ["profiling/launch_benchmark.sh"],
      config_list,
      model_list,
      max_length_list,
-     batch_size,
+     per_device_batch_size,
  ]
  timestamp = int(time.time())
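The lists above define the benchmark sweep: one SLURM job per combination of config, model, context length, per-device batch size, and GPU count. The command-generation logic sits further down in `launch_benchmark.py` and is not part of this diff; conceptually it is a Cartesian product over the options, roughly as in the sketch below (the `sbatch` assembly shown here is an assumption, not the script's exact code).

```python
# Hypothetical sketch of expanding the sweep into sbatch commands.
import itertools

slurm_flags_options = {
    "nodes": [1],
    "gres": [f"gpu:{n}" for n in [4, 2, 1]],
}
slurm_pos_args_options = [
    ["profiling/launch_benchmark.sh"],
    ["profiling/configs/benchmark.yaml"],
    ["/model-weights/Meta-Llama-3-8B"],
    [8192, 4096, 2048],  # max_length
    [2, 4, 8],           # per-device batch size
]

flag_names = list(slurm_flags_options)
for flag_values in itertools.product(*slurm_flags_options.values()):
    flags = " ".join(f"--{name}={value}" for name, value in zip(flag_names, flag_values))
    for pos_args in itertools.product(*slurm_pos_args_options):
        # One sbatch invocation per (flags, positional args) combination.
        print(f"sbatch {flags} " + " ".join(str(arg) for arg in pos_args))
```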

profiling/launch_benchmark.sh

+1 −1
@@ -28,7 +28,7 @@ profiling/benchmark.py \
      --yaml_path $1 \
      --model_name $2 \
      --max_length $3 \
-     --training_batch_size $4
+     --per_device_batch_size $4

  # clean up benchmarking artifacts as ops have requested
  rm -rf /dev/shm/lora-benchmark
