Skip to content

Commit

Permalink
[GPUPS]fix merge_grad&push_sparse (PaddlePaddle#25)
Browse files Browse the repository at this point in the history
* fix merge_grad

* fix push_sparse

* fix push_sparse

* fix size_t

* change typo
  • Loading branch information
zmxdream authored Jun 27, 2022
1 parent 6db0596 commit bdd38ba
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ __global__ void merge_gradient_kernel(const uint32_t* offset,

for (int j = 1; j < num; ++j) {
ori_index = index[start + j];
in = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.add_basic_field(lhs, in);
FeaturePushValue& rhs = *(FeaturePushValue*)(input + size_t(ori_index) * grad_value_size);
merger_.add_basic_field(lhs, rhs);
}
}

Expand Down Expand Up @@ -954,7 +954,7 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
auto stream = resource_->local_stream(gpu_num, 0);

size_t grad_value_size =
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));
TYPEALIGN(8, sizeof(FeaturePushValue) + (max_mf_dim_ * sizeof(float)));

// int h_left[total_gpu]; // NOLINT
// int h_right[total_gpu]; // NOLINT
Expand All @@ -976,14 +976,8 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
auto d_shard_keys = memory::Alloc(place, len * sizeof(KeyType));
KeyType* d_shard_keys_ptr = reinterpret_cast<KeyType*>(d_shard_keys->ptr());

GradType* d_shard_grads_ptr;
if (!multi_mf_dim_) {
auto d_shard_grads = memory::Alloc(place, len * sizeof(GradType));
d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
} else {
auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());
}
auto d_shard_grads = memory::Alloc(place, len * grad_value_size);
GradType* d_shard_grads_ptr = reinterpret_cast<GradType*>(d_shard_grads->ptr());

int uniq_len = len;
merge_grad(gpu_num, d_keys, d_grads, NULL, len, uniq_len);
Expand Down

0 comments on commit bdd38ba

Please sign in to comment.