DeepCompile ZeRO-3: robust allgather for uneven shards; fix profiling… #7489

Open · wants to merge 4 commits into base: master
75 changes: 65 additions & 10 deletions csrc/compile/z3.cpp
@@ -69,9 +69,15 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
const at::Tensor& ds_tensor = param.getDSTensor();

if (symm_mem == nullptr) {
// Fast path: assume uniform shard sizes (ZeRO-3 partitions are padded to uniform size)
const int world_size = process_group_->getSize();
const int64_t shard_elems = ds_tensor.numel();

// Perform all-gather directly into the pre-allocated padded output buffer
// NCCL requires contiguous storage; use .contiguous() explicitly
ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
output_buf.data_ptr(),
ds_tensor.numel(),
shard_elems,
get_nccl_data_type(ds_tensor.scalar_type()),
nccl_comm_,
ag_stream_);
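For context, a minimal PyTorch sketch of the same fast path, assuming a CUDA build, the NCCL backend, and a torchrun launch; the function name and shard size are illustrative, not DeepSpeed APIs. Like ncclAllGather, all_gather_into_tensor expects every rank to contribute the same element count and writes into a pre-allocated flat buffer of world_size * shard_elems.

# Illustrative sketch, not DeepSpeed code: all-gather uniform shards into a padded flat buffer.
import torch
import torch.distributed as dist

def allgather_padded_shard(shard: torch.Tensor) -> torch.Tensor:
    """All-gather a uniformly sized shard from every rank into one flat padded buffer."""
    world_size = dist.get_world_size()
    out = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    # Collectives need contiguous storage, so request it explicitly (mirrors .contiguous() above).
    dist.all_gather_into_tensor(out, shard.contiguous())
    return out

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    shard = torch.full((3,), float(dist.get_rank()), device="cuda")  # every rank contributes 3 elements
    gathered = allgather_padded_shard(shard)                         # numel == world_size * 3
    dist.destroy_process_group()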
@@ -104,13 +110,30 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
at::Tensor allgatherParam(long ds_id,
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
{
if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); }

const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id)
? param_registry_->getGatheredParam(ds_id)
: torch::empty(param.getShape(), ds_tensor.options());
const int world_size = process_group_->getSize();
const int64_t shard_elems = ds_tensor.numel();
const int64_t padded_numel = static_cast<int64_t>(world_size) * shard_elems;
const int64_t true_numel = static_cast<int64_t>(productDim(param.getShape()));

if (param_registry_->isValid(ds_id)) {
// Return a view sliced to the true size with the original shape
auto base = param_registry_->getGatheredParam(ds_id);
return base.flatten()
.index({torch::indexing::Slice(0, true_numel)})
.view(param.getShape());
}

at::Tensor output_buf;
if (param_registry_->hasGatheredParam(ds_id)) {
Reviewer comment (Contributor):
I'm not sure when isValid(ds_id) is false while hasGatheredParam(ds_id) is true. They are both set at the end of launchAllGather(), and releasing a gathered param will unset the valid flag in unregisterGatheredParam().

auto existing = param_registry_->getGatheredParam(ds_id);
if (existing.defined() && existing.numel() == padded_numel) { output_buf = existing; }
}
if (!output_buf.defined()) {
at::cuda::CUDAStreamGuard guard(ag_stream_);
output_buf = torch::empty({padded_numel}, ds_tensor.options());
}

assert(hasKey(ag_comp_done_events_, ds_id));
ag_comp_done_events_[ds_id]->record();
@@ -119,7 +142,10 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
launchAllGather(output_buf, ds_id, symm_mem);

ag_comm_done_events_[ds_id]->record(ag_stream_);
return output_buf;
// Return a view of the gathered padded buffer matching the true param shape
return output_buf.flatten()
.index({torch::indexing::Slice(0, true_numel)})
.view(param.getShape());
}
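The return path above slices the gathered padded buffer down to the parameter's true element count and reshapes it to the original shape. A single-process sketch of that arithmetic, with illustrative names rather than DeepSpeed APIs:

# Illustrative sketch: recover the true-shaped parameter view from the padded gathered buffer.
import math
import torch

def unpad_gathered(padded_flat: torch.Tensor, true_shape: tuple) -> torch.Tensor:
    """Take the first true_numel elements of the padded flat buffer and view them as true_shape."""
    true_numel = math.prod(true_shape)
    return padded_flat[:true_numel].view(true_shape)

# Example: a (2, 5) parameter sharded across 4 ranks -> shards of 3 elements, so 2 padding elements.
world_size, shard_elems, true_shape = 4, 3, (2, 5)
padded = torch.arange(world_size * shard_elems, dtype=torch.float32)  # 12 elements: 10 real + 2 pad
param_view = unpad_gathered(padded, true_shape)
assert param_view.shape == torch.Size(true_shape)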

void prefetchParamsFused(std::vector<int64_t> ds_ids,
@@ -133,11 +159,19 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
std::unordered_map<long, at::Tensor> output_bufs;
for (long ds_id : invalid_ds_ids) {
const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
const int world_size = process_group_->getSize();
const int64_t shard_elems = ds_tensor.numel();
const int64_t padded_numel = static_cast<int64_t>(world_size) * shard_elems;

if (param_registry_->hasGatheredParam(ds_id)) {
output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id);
} else {
output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options());
auto existing = param_registry_->getGatheredParam(ds_id);
if (existing.defined() && existing.numel() == padded_numel) {
output_bufs[ds_id] = existing;
continue;
}
}
output_bufs[ds_id] = torch::empty({padded_numel}, ds_tensor.options());
}

for (long ds_id : invalid_ds_ids) {
@@ -383,6 +417,27 @@ void register_z3_param(long ds_id,
{
param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }

// Validate that shard sizes are uniform across ranks at registration time (not on the hot path)
// This ensures launchAllGather can assume uniform shard sizes without extra synchronization.
const int64_t local_count = ds_tensor.numel();
const int world_size = process_group->getSize();

auto count_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
at::Tensor local_count_tensor = torch::tensor({local_count}, count_options);
std::vector<at::Tensor> all_counts(world_size);
for (int i = 0; i < world_size; ++i) { all_counts[i] = torch::empty_like(local_count_tensor); }
process_group->allgather(all_counts, local_count_tensor)->wait();

int64_t reference = local_count;
for (int i = 0; i < world_size; ++i) {
int64_t c = all_counts[i].to(torch::kCPU).item<int64_t>();
if (c != reference) {
throw std::runtime_error(
"ZeRO-3 registration error: non-uniform shard sizes detected across ranks. "
"Please check parameter partitioning.");
}
}
}
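The registration-time check above has a compact Python analogue; a hedged sketch using all_gather_object, which is backend-agnostic and runs under torchrun with gloo. The function name is illustrative, not a DeepSpeed API.

# Illustrative sketch: verify every rank registered a shard with the same element count.
import torch
import torch.distributed as dist

def check_uniform_shard_sizes(shard: torch.Tensor) -> None:
    """Raise at registration time if any rank holds a different number of elements."""
    counts = [None] * dist.get_world_size()
    dist.all_gather_object(counts, shard.numel())
    if len(set(counts)) != 1:
        raise RuntimeError(f"non-uniform shard sizes across ranks: {counts}")

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    check_uniform_shard_sizes(torch.zeros(3))  # uniform: every rank registers 3 elements
    dist.destroy_process_group()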

at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
2 changes: 1 addition & 1 deletion deepspeed/compile/profilers/graph_profile.py
@@ -122,7 +122,7 @@ def run_node(self, n: torch.fx.Node) -> Any:
n.meta["device_time"] = 0.0
n.meta["wall_time"] = 0.0
n.meta["alloc_mem"] = 0
n.meta["max_memory"] = 0
n.meta["max_mem"] = 0
n.meta["tensor_size"] = _node_size(n)
return super().run_node(n)
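The profiler fix is a key rename so that the value the profiler records is the one later passes actually read. A tiny hedged illustration; the consumer below is hypothetical, not DeepSpeed's scheduler code.

# Hypothetical consumer: if the profiler writes "max_memory" but readers look up "max_mem",
# every node silently reports the fallback value of 0.
def peak_memory(node_meta: dict) -> int:
    return node_meta.get("max_mem", 0)

meta_before = {"max_memory": 4096}  # key written before the fix
meta_after = {"max_mem": 4096}      # key written after the fix
assert peak_memory(meta_before) == 0   # stale key is ignored
assert peak_memory(meta_after) == 4096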
