
Conversation


@abhinavgoel95 (Contributor) commented on Jan 3, 2025

Description

This change adds support for sequence sharding along the tensor-parallel axis. The benefits of this approach are described in this paper. It has performance benefits over the default tensor sharding approach because LayerNorms and RMSNorms can be computed without an all-reduce.
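
For intuition, below is a minimal JAX sketch (not MaxText code; the mesh construction, tensor shapes, and rms_norm helper are illustrative assumptions) showing why an activation sharded along the sequence axis lets RMSNorm run with no cross-device collective: the normalization reduces only over the hidden axis, which stays whole on every device.

# A minimal sketch, assuming a single "tensor" mesh axis and made-up shapes.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("tensor",))

def rms_norm(x, scale, eps=1e-6):
  # Normalization statistics are per token, reduced over the last (hidden)
  # axis only, so a sequence-sharded input needs no cross-device reduction.
  var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(var + eps) * scale

batch, seq, hidden = 1, 4096, 4096
x = jnp.ones((batch, seq, hidden), dtype=jnp.bfloat16)
scale = jnp.ones((hidden,), dtype=jnp.bfloat16)

# Shard the sequence dimension across the "tensor" mesh axis and keep the
# hidden dimension whole; the jitted norm then runs collective-free.
x = jax.device_put(x, NamedSharding(mesh, P(None, "tensor", None)))
y = jax.jit(rms_norm)(x, scale)

The all-reduce that standard tensor parallelism performs to materialize the full activation before each norm is exactly the collective this sharding avoids.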

Tests

I have run the llama2-7b config with this change and observed a 7% end-to-end speedup on 8x H100 GPUs.

Baseline:

python3 /opt/workspace/maxtext/MaxText/train.py /opt/maxtext/MaxText/configs/base.yml model_name=llama2-7b per_device_batch_size=0.125 steps=15 scan_layers=true monitor_goodput=false enable_goodput_recording=false remat_policy=minimal_flash attention=dot_product max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false enable_checkpointing=false ici_fsdp_parallelism=1 ici_tensor_sequence_parallelism=1 ici_tensor_parallelism=8 base_output_directory=local_train dataset_path=local dataset_type=synthetic hardware=gpu run_name=testing-tp profiler=nsys skip_first_n_steps_for_profiler=4 profiler_steps=3

A snippet of output:

Memstats: step 14:
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:0
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:1
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:2
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:3
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:4
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:5
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:6
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:7
completed step: 14, seconds: 0.166, TFLOP/s/device: 142.405, Tokens/s/device: 3090.052, total_weights: 4096, loss: 0.000

With changes:

python3 /opt/workspace/maxtext/MaxText/train.py /opt/maxtext/MaxText/configs/base.yml model_name=llama2-7b per_device_batch_size=0.125 steps=15 scan_layers=true monitor_goodput=false enable_goodput_recording=false remat_policy=minimal_flash attention=dot_product max_target_length=4096 use_iota_embed=true logits_dot_in_fp32=false enable_checkpointing=false ici_fsdp_parallelism=1 ici_tensor_sequence_parallelism=8 ici_tensor_parallelism=1 base_output_directory=local_train dataset_path=local dataset_type=synthetic hardware=gpu run_name=testing-sp profiler=nsys skip_first_n_steps_for_profiler=4 profiler_steps=3
A snippet of output:

Memstats: step 14:
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:0
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:1
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:2
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:3
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:4
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:5
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:6
	Using (GB) 9.42 / 71.2 (13.230337%) on cuda:7
completed step: 14, seconds: 0.152, TFLOP/s/device: 154.844, Tokens/s/device: 3359.977, total_weights: 4096, loss: 0.000

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.


@gobbleturk (Collaborator) left a comment


Interesting, thank you for this contribution!

Can we either:

  1. Add an assert that checks that tensor_sequence is only used with the llama2 decoder block, or
  2. Add support for tensor_sequence for all decoder blocks?

@abhinavgoel95 (Contributor, Author) commented

@gobbleturk done (2). Thank you for your feedback.


@gobbleturk (Collaborator) left a comment


Awesome, thank you!

@abhinavgoel95 (Contributor, Author) commented

Fixed issues related to the linter.


@gobbleturk (Collaborator) left a comment


Looks like you need to add tensor_sequence to pyconfig.py in at least 3 places (you can search for tensor_parallelism to find the relevant places). For example:

raw_keys["dcn_fsdp_transpose_parallelism"],
raw_keys["dcn_sequence_parallelism"],
raw_keys["dcn_tensor_parallelism"],
raw_keys["dcn_tensor_sequence_parallelism"],
Another Collaborator comment:

We need to add "tensor_sequence" to the mesh axes and data sharding lists for PP right below:

maxtext/MaxText/pyconfig.py (lines 599 to 600 at 30f3988):

mesh_axes = ["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]
data_sharding = [["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]]

@copybara-service (bot) merged commit 577ca8d into AI-Hypercomputer:main on Jan 4, 2025.
2 of 3 checks passed.
@shuningjin mentioned this pull request on Aug 29, 2025.