[Feature] Hide 75% of the communication in tensor parallelism using DoMiNo #292
base: main
Conversation
…_mbs2_and_gbs_300k_and_input_splitting_and_commit_23f2_but_remove_call_is_async_comm_twice_and_keep_not_async_bwd.layer_mlp_1__and_bwd.layer_attn_0
- execute backward comm in a separate stream
- make the comm stream in the backward pass wait for the compute stream before running backward comm
- make WaitComm's compute stream wait for the comm stream
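For readers skimming the thread, here is a minimal stream-ordering sketch of what this commit describes, assuming a dedicated comm stream; the function and tensor names are illustrative, not the PR's identifiers:

```python
import torch
import torch.distributed as dist

def backward_comm_sketch(grad_tensor: torch.Tensor) -> None:
    # Illustrative only: run the backward all-reduce on a separate comm stream.
    compute_stream = torch.cuda.current_stream()
    comm_stream = torch.cuda.Stream()

    # The comm stream waits for the compute stream before launching backward comm.
    comm_stream.wait_stream(compute_stream)
    with torch.cuda.stream(comm_stream):
        dist.all_reduce(grad_tensor, async_op=True)

    # WaitComm: the compute stream waits for the comm stream before any op
    # that consumes the reduced gradient.
    compute_stream.wait_stream(comm_stream)
```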
…omm, and remove torch.cuda.synchronize() in WaitComm
…e_cuda_syncronize_in_wait_comm_bwd_and_add_comm_syncronize_in_waitcomm_and_commit_543ef56
…x_stream_not_sync_exp2a1c7_and_commit_23f2_and_75_percent_bwd_overlapping_with_cuda_stream_sync_bwd
…eturning it directly in linear modules
BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}"
BWD_MLP_OP_NAME = "bwd.layer_mlp_{}_batch_{}"

_operation_context = threading.local()
is this necessary?
BWD_ATTN_OP_NAME
Because we reference these names in many places in the code, I want to keep them consistent, so if we change a name we don't have to manually replace it everywhere else.
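For example (illustrative use of the constant; the real call sites are elsewhere in the diff):

```python
BWD_ATTN_OP_NAME = "bwd.layer_attn_{}_batch_{}"

# The same constant is formatted wherever the op is registered or waited on,
# so renaming it only touches one place.
op_name = BWD_ATTN_OP_NAME.format(3, 0)  # "bwd.layer_attn_3_batch_0"
```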
""" | ||
Determine whether a module (e.g., mlp, attention) | ||
performs all-reduce asynchronously in tensor parallelism | ||
""" |
Continue the description of this function: how do we determine it? What do we check?
added
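For readers of the thread, a guess at the shape of that check (illustrative only, not the PR's exact logic): decide from the formatted op name whether its all-reduce is expected to be overlapped.

```python
import re

# Hypothetical patterns matching the backward attention/MLP op names above.
_ASYNC_OP_PATTERNS = (
    re.compile(r"bwd\.layer_attn_\d+_batch_\d+"),
    re.compile(r"bwd\.layer_mlp_\d+_batch_\d+"),
)

def is_async_comm(op_name: str) -> bool:
    """Return True if this op's all-reduce is expected to run asynchronously."""
    return any(p.fullmatch(op_name) for p in _ASYNC_OP_PATTERNS)
```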
class AsyncCommBucket:
    """
    Store asynchronous communication operations.
    """

    def __init__(self):
        self._async_op: Dict[str, "dist.Work"] = {}
        self._copy_async_op: Dict[str, "dist.Work"] = {}

    def add(self, op_name: str, work: "dist.Work"):
        assert op_name not in self._async_op, f"Operation with name: {op_name} already exists"
        assert work is not None
        self._async_op[op_name] = work
        self._copy_async_op[op_name] = work
are we sure we don't have an equivalent of this class in torch? o.O
I checked: they have a GradBucket [link] in torch, but it doesn't seem to have something that suits our need to store the dist async ops.
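For context, a minimal usage sketch of AsyncCommBucket, assuming the bucket stores the Work handle returned by dist.all_reduce(..., async_op=True); the helper name below is hypothetical:

```python
import torch
import torch.distributed as dist

bucket = AsyncCommBucket()

def all_reduce_async(tensor: torch.Tensor, op_name: str) -> torch.Tensor:
    # dist.all_reduce returns a dist.Work handle when async_op=True;
    # record it so it can be waited on later, before the tensor is consumed.
    work = dist.all_reduce(tensor, async_op=True)
    bucket.add(op_name, work)
    return tensor
```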
src/nanotron/parallel/comm.py
Outdated
not_finished = []
for k, v in self._copy_async_op.items():
    assert is_domino_async_comm(k) is True, f"Operation with name {k} wasn't executed asynchronously!"
I don't like the mention of domino here; this CommBucket should be independent of domino.
removed
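For context, a guess at the shape of the surrounding drain loop after that change (a sketch only; the real method lives in src/nanotron/parallel/comm.py):

```python
def wait(self) -> None:
    """Illustrative only: block until every recorded async op has finished."""
    not_finished = []
    for op_name, work in self._copy_async_op.items():
        if not work.is_completed():
            not_finished.append((op_name, work))
    for _, work in not_finished:
        work.wait()
    self._copy_async_op.clear()
```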
Reproducing the paper "Domino: Eliminating Communication in LLM Training via Generic Tensor Slicing and Overlapping" https://arxiv.org/abs/2409.15241
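As a rough illustration of the Domino idea being reproduced (not this PR's implementation): split the batch into two halves inside a layer, launch the all-reduce of the first half asynchronously, and overlap it with the compute of the second half.

```python
import torch
import torch.distributed as dist

def domino_style_forward(block, x: torch.Tensor) -> torch.Tensor:
    # Illustrative only. `block` stands for a tensor-parallel attention/MLP block
    # whose output needs an all-reduce across the TP group.
    x0, x1 = x.chunk(2, dim=0)

    out0 = block(x0)
    work0 = dist.all_reduce(out0, async_op=True)   # comm for half 0 starts...

    out1 = block(x1)                               # ...while half 1 computes
    work1 = dist.all_reduce(out1, async_op=True)

    work0.wait()
    work1.wait()
    return torch.cat([out0, out1], dim=0)
```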
The losses match after 20B tokens with a 2M batch size, 20k steps, and the FineWeb dataset, with 75% of the tensor-parallel communication hidden.
The first PR is ready for review (I split it into two PRs); some work is left for the next PR:
Profiling results:
/fsx/phuc/new_workspace/experiments/nanotron_domino/profilings/exp7a11_like_exp7a6_but_remove_fwd_pass_cuda_syncronization_and_remove_cuda_syncronize_in_wait_comm_bwd_and_add_comm_syncronize_in_waitcomm_and_remove_explicite_async_op_arg_and_commit_600f01/20250228-160428/ip-26-0-161-142_51797.1740758749919300440.pt.trace.json