
[Feature] Hide 75% of the communication in tensor parallelism using DoMiNo #292


Open · wants to merge 42 commits into base: main

Conversation

xrsrke
Member

@xrsrke xrsrke commented Mar 10, 2025

Reproducing the paper "Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping" https://arxiv.org/abs/2409.15241

The losses match after 20B tokens with a 2M batch size, 20k steps, and the FineWeb dataset, with 75% of the tensor-parallel communication hidden.

image

The first PR is ready for review (I split the work into two PRs). Remaining work for the next PR, with a sketch of the overlap pattern after the list:

  • intra-layer overlapping (currently we overlap communication within a layer); with full intra-layer overlapping we could hide almost all of the communication
  • create a fixed buffer for concatenating hidden states (1st image)
  • double-check whether there is CUDA stream-switching overhead (4.3.1 in the 2nd image)
  • minimize kernel launch overhead (CUDA graphs? 4.3.2 in the 2nd image)
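
As a reference for the list above, here is a minimal sketch of the Domino-style batch-split overlap, with hypothetical names (overlapped_forward, comm_stream, tp_group), assuming two batch chunks and a single tensor-parallel all-reduce per module; it is not the PR's actual code:

import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()

def overlapped_forward(module, hidden_states, tp_group):
    # split the batch into two chunks, Domino-style
    chunk0, chunk1 = hidden_states.chunk(2, dim=0)

    out0 = module(chunk0)
    # launch chunk 0's all-reduce on a separate stream so it runs in the background
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        work0 = dist.all_reduce(out0, group=tp_group, async_op=True)

    # chunk 1's compute overlaps with chunk 0's communication
    out1 = module(chunk1)
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        work1 = dist.all_reduce(out1, group=tp_group, async_op=True)

    # block only when the reduced outputs are actually needed
    work0.wait()
    work1.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)
    return torch.cat([out0, out1], dim=0)

Chunk 0's all-reduce is hidden behind chunk 1's compute; the fixed concat buffer from the list above would replace the torch.cat at the end.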

Profiling results:

image

image

/fsx/phuc/new_workspace/experiments/nanotron_domino/profilings/exp7a11_like_exp7a6_but_remove_fwd_pass_cuda_syncronization_and_remove_cuda_syncronize_in_wait_comm_bwd_and_add_comm_syncronize_in_waitcomm_and_remove_explicite_async_op_arg_and_commit_600f01/20250228-160428/ip-26-0-161-142_51797.1740758749919300440.pt.trace.json

xrsrke and others added 30 commits January 29, 2025 12:47
…_mbs2_and_gbs_300k_and_input_splitting_and_commit_23f2_but_remove_call_is_async_comm_twice_and_keep_not_async_bwd.layer_mlp_1__and_bwd.layer_attn_0
- execute backward comm in a separate stream
    - make the comm stream in the backward pass wait for the compute stream before running the backward comm
- make WaitComm’s compute stream wait for the comm stream
…omm, and remove torch.cuda.synchronize() in WaitComm
…e_cuda_syncronize_in_wait_comm_bwd_and_add_comm_syncronize_in_waitcomm_and_commit_543ef56
…x_stream_not_sync_exp2a1c7_and_commit_23f2_and_75_percent_bwd_overlapping_with_cuda_stream_sync_bwd
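
The stream-synchronization commits above boil down to the following pattern; this is a hedged sketch with hypothetical names (run_backward_comm, wait_comm, comm_stream), not the PR's WaitComm implementation:

import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()

def run_backward_comm(grad, tp_group):
    # the comm stream waits for the compute stream, then launches the backward all-reduce
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        work = dist.all_reduce(grad, group=tp_group, async_op=True)
    return work

def wait_comm(work):
    # the compute stream waits for the comm stream before consuming the reduced gradient
    work.wait()
    torch.cuda.current_stream().wait_stream(comm_stream)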
@xrsrke xrsrke changed the title Xrsrke/exp7a13b0 domino revert from fix stream not sync exp2a1c7 and commit 23f2 and 75 percent bwd overlapping with cuda stream sync bwd but remove stream manager ctx [Feature] Hide 75% of the communication in tensor parallelism using DoMiNo Mar 10, 2025
BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}"
BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}"

_operation_context = threading.local()
Member

is this necessary?

Member Author

BWD_ATTN_OP_NAME

because we reference these names in many places in the code; I want to keep them consistent, so if we change a name we don't have to manually replace it everywhere
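
For illustration, the templates are filled with the layer index and batch-chunk index at each call site (hypothetical usage, not necessarily the PR's exact code):

# e.g. the backward all-reduce of attention in layer 3, batch chunk 1
op_name = BWD_ATTN_OP_NAME.format(3, 1)   # -> "bwd.layer_attn_3_batch_1"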

Comment on lines 14 to 17
"""
Determine whether a module (e.g., mlp, attention)
performs all-reduce asynchronously in tensor parallelism
"""
Member

continue the description of this function: how do we determine it? what do we check?

Member Author

added
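
For context, a plausible shape for that check, assuming it is driven purely by the op names above; the set of non-async ops is an assumption for illustration, not taken from the PR:

# Hypothetical sketch: a few backward ops cannot be overlapped and stay synchronous.
NON_ASYNC_OPS = {"bwd.layer_mlp_1_batch_1", "bwd.layer_attn_0_batch_0"}  # assumed examples

def is_async_comm(op_name: str) -> bool:
    # Determine whether a module (e.g., mlp, attention) performs its
    # all-reduce asynchronously: everything is async except NON_ASYNC_OPS.
    return op_name not in NON_ASYNC_OPS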

Comment on lines +10 to +24
class AsyncCommBucket:
"""
    Store asynchronous communication operations.
"""

def __init__(self):
self._async_op: Dict[int, "dist.Work"] = {}
self._copy_async_op: Dict[int, "dist.Work"] = {}

def add(self, op_name: int, work: "dist.Work"):
assert op_name not in self._async_op, f"Operation with name: {op_name} already exists"
assert work is not None
self._async_op[op_name] = work
self._copy_async_op[op_name] = work

Member

are we sure we don't have an equivalent of this class in torch? o.O

Member Author

I checked; torch has a "GradBucket" [link], but it doesn't seem to have anything that suits our need to store the dist async ops
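
For illustration, a usage sketch of the bucket with a non-blocking collective (the module-level bucket and the tp_group name are hypothetical):

import torch.distributed as dist

bucket = AsyncCommBucket()

def reduce_async(op_name, tensor, tp_group):
    # launch a non-blocking all-reduce and stash the returned Work handle
    work = dist.all_reduce(tensor, group=tp_group, async_op=True)
    bucket.add(op_name, work)
    # later, the code path that needs the reduced tensor retrieves the handle
    # and calls work.wait() on it (assuming the bucket exposes a pop/wait method).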


not_finished = []
for k, v in self._copy_async_op.items():
    assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!"
Member

I don't like the mention of domino here. This CommBucket should be independent of domino.

Member Author

removed
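
One way to keep the bucket domino-agnostic, sketched here as an assumption about the fix rather than the PR's actual change, is to let the caller pass the async-ness predicate in (a method on AsyncCommBucket):

from typing import Callable, List

def get_not_finished(self, is_async: Callable[[str], bool]) -> List[str]:
    # the bucket no longer knows anything domino-specific; the caller supplies the policy
    not_finished = []
    for op_name, work in self._copy_async_op.items():
        assert is_async(op_name), f"Operation with name {op_name} wasn't executed asynchronously!"
        if not work.is_completed():
            not_finished.append(op_name)
    return not_finished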

@xrsrke xrsrke requested review from NouamaneTazi and zzhhjjj April 1, 2025 16:24