
DeepCompile ZeRO-3: robust allgather for uneven shards; fix profiling… #7489


Open
juyterman1000 wants to merge 4 commits into master from fix/dc-zero3-allgather-uneven-shards

Conversation

juyterman1000

… meta key (max_mem)

@juyterman1000 juyterman1000 force-pushed the fix/dc-zero3-allgather-uneven-shards branch from eac514f to 1f39153 Compare August 15, 2025 00:32
@sfc-gh-truwase
Collaborator

@juyterman1000 can you please address the formatting issue using https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md#prerequisites

std::vector<int64_t> host_counts(world_size);
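// Each iteration does a synchronous device-to-host copy to find the largest shard count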
for (int i = 0; i < world_size; ++i) {
host_counts[i] = all_counts[i].to(torch::kCPU).item<int64_t>();
if (host_counts[i] > max_count) { max_count = host_counts[i]; }
Contributor

Could you elaborate more on when ds_tensor.numel() of the same parameter can differ across ranks? I think padding is already taken into account when the parameter is partitioned among the ranks (ref: https://github.com/deepspeedai/DeepSpeed/blob/master/deepspeed/runtime/zero/partition_parameters.py#L1664)

In case partition sizes do vary across ranks, can we fix that in partition_parameters.py to avoid synchronous communication here? launchAllGather() is on the critical path, so synchronous allgather can hurt performance.

Author

Thanks for the sharp catch. I’ve removed the synchronous size-allgather from the hot path in launchAllGather() and now use a fixed-count NCCL allgather, trimming any end padding down to the true param size. To keep the safety check without paying a runtime cost, I added a one-time assertion at registration that shard sizes match across ranks; if there’s ever a mismatch, we catch it at the source rather than synchronizing on the critical path. Changes are in the updated PR.
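For reference, a minimal sketch of what the one-time registration check could look like (the function name, NCCL handles, and dtype choices here are illustrative, not the actual PR code):

#include <torch/torch.h>
#include <nccl.h>
#include <cuda_runtime.h>

// Sketch: run once when a param shard is registered, off the hot path.
void assertUniformShardSize(const at::Tensor& ds_tensor,
                            int world_size,
                            ncclComm_t comm,
                            cudaStream_t stream)
{
    auto opts = torch::TensorOptions().dtype(torch::kInt64).device(ds_tensor.device());
    at::Tensor local = torch::full({1}, static_cast<int64_t>(ds_tensor.numel()), opts);
    at::Tensor all_counts = torch::empty({world_size}, opts);
    // One collective at registration time instead of one per launchAllGather().
    ncclAllGather(local.data_ptr(), all_counts.data_ptr(), 1, ncclInt64, comm, stream);
    cudaStreamSynchronize(stream);
    at::Tensor host_counts = all_counts.to(torch::kCPU);
    for (int i = 0; i < world_size; ++i) {
        TORCH_CHECK(host_counts[i].item<int64_t>() == ds_tensor.numel(),
                    "ZeRO-3 shard size mismatch across ranks for this param");
    }
}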

Contributor

I think we can further optimize the code by having allgatherParam() allocate a padded buffer in the first place (today it allocates a buffer of ds_shape, which is the true size of the gathered parameter). With that, we don't need any additional memcpy or GPU memory allocation/deallocation; instead we can slice the gathered output_buf before returning it. My understanding is that torch correctly tracks the refcount of the underlying storage even when live tensors use only part of it, but correct me if I'm wrong.

Collaborator

@eternalNight thanks for the suggestion. @juyterman1000 if you agree with this, do you want to address it in a follow-up PR? A benefit of a follow-up PR is that it could document the perf benefit of the optimization separately from the functionality fix.

Contributor

Hi @juyterman1000,

Thank you for the PR! Some of the changes are unclear to me; can you explain a bit more?
You added an assertion to ensure even sharding, which totally makes sense to me. Do we still need the changes in launchAllGather()? The additional memory allocation and copy might cause significant overhead in some cases.

Author

@eternalNight Yes, we can allocate a buffer sized to world_size * shard_elems up front and slice it to the true size on return. PyTorch views hold a reference to the underlying storage, so returning a sliced view does not break refcounting. We can also cache the padded buffer per param to avoid repeat allocations.

@sfc-gh-truwase Agreed on the follow-up. I’ll include micro-benchmarks showing the removal of one alloc + one memcpy per all-gather and any other gains.

@tohtana With the even-sharding assertion in place, we don’t need the extra copy logic in launchAllGather(). We can issue a direct AllGather with a uniform per-shard element count into the padded buffer and return a view of the first true_numel elements, reshaped to the original param shape. The symmetric-memory path will stay as-is. This avoids the additional copy overhead.
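Roughly, the shape of what I have in mind (a sketch only, assuming fp32 params and uniform shards; names are illustrative, and per-param caching of the padded buffer is omitted for brevity):

#include <torch/torch.h>
#include <nccl.h>
#include <cuda_runtime.h>

// Sketch: gather uniform shards into a padded buffer, return a trimmed view.
at::Tensor allGatherPaddedView(const at::Tensor& ds_tensor,   // local shard
                               at::IntArrayRef ds_shape,      // true (unpartitioned) shape
                               int64_t true_numel,            // product of ds_shape
                               int world_size,
                               ncclComm_t comm,
                               cudaStream_t stream)
{
    const int64_t shard_elems = ds_tensor.numel();
    // Padded output: world_size * shard_elems can exceed true_numel by the end padding.
    at::Tensor output_buf = torch::empty({world_size * shard_elems}, ds_tensor.options());
    at::Tensor send = ds_tensor.contiguous();  // NCCL needs contiguous storage
    ncclAllGather(send.data_ptr(), output_buf.data_ptr(),
                  shard_elems, ncclFloat, comm, stream);
    // Return a view of the first true_numel elements, reshaped to the param shape;
    // the view keeps the padded storage alive through normal refcounting.
    return output_buf.narrow(0, 0, true_numel).view(ds_shape);
}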

Contributor

@juyterman1000 The path using symmetric memory is experimental and not well optimized, so we need to keep the non-symmetric-memory path as the choice for best performance.
If the allocation and copy are there to handle uneven partitioning, and the assertion blocks such uneven partitioning, why can't we remove them?

@juyterman1000
Author

@juyterman1000 can you please address the formatting issue using https://github.com/deepspeedai/DeepSpeed/blob/master/CONTRIBUTING.md#prerequisites

Thanks for the check. I followed the formatting prerequisites in the contributing guide and ran the full pre-commit suite. I’ve pushed the updates.

@sfc-gh-truwase sfc-gh-truwase requested a review from tohtana August 18, 2025 16:13
const int64_t shard_elems = ds_tensor.numel();

// Perform all-gather directly into the pre-allocated padded output buffer
ncclResult_t result = ncclAllGather(ds_tensor.flatten().data_ptr(),
Contributor

Why replace .contiguous() with .flatten()? .contiguous() makes sure that the underlying storage is contiguous, which NCCL requires; .flatten() is a view change and does not guarantee that.

Note: I believe the sharded tensors are already contiguous, since they are defragmented by DeepSpeedZeroOptimizer_Stage3.defragment(), but adding .contiguous() does not hurt and may help later if the layout of sharded tensors changes.

}

at::Tensor output_buf;
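// Check the registry for an already-gathered buffer for this param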
if (param_registry_->hasGatheredParam(ds_id)) {
Contributor

I'm not sure when isValid(ds_id) is false while hasGatheredParam(ds_id) is true. They are both set at the end of launchAllGather(), and releasing a gathered param will unset the valid flag in unregisterGatheredParam().

… meta key (max_mem)

Signed-off-by: Abhishek <dalakotiashu150@gmail.com>
…s at registration (max_mem)

Signed-off-by: Abhishek <dalakotiashu150@gmail.com>
…iew; launchAllGather issues direct NCCL AllGather for uniform shards; add registration-time uniform-shard validation

Signed-off-by: Abhishek <dalakotiashu150@gmail.com>
Signed-off-by: Abhishek <dalakotiashu150@gmail.com>
@juyterman1000 juyterman1000 force-pushed the fix/dc-zero3-allgather-uneven-shards branch from 34df823 to ffa2aba Compare August 22, 2025 03:12