Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct FeatureValue {
float lr_g2sum;
int mf_size;
float mf[MF_DIM + 1];
uint64_t cpu_ptr;

friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/hashtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <glog/logging.h>
#include <limits>
#include <memory>
#include <vector>
#include "common_value.h" // NOLINT
#include "thrust/pair.h"
//#include "cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
Expand Down Expand Up @@ -47,6 +49,7 @@ class HashTable {
void get(const KeyType* d_keys, ValType* d_vals, size_t len,
cudaStream_t stream);
void show();
void dump_to_cpu(int devid, cudaStream_t stream);

template <typename GradType, typename Sgd>
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
Expand All @@ -60,5 +63,5 @@ class HashTable {
};
} // end namespace framework
} // end namespace paddle
#include "hashtable.tpp"
#include "hashtable_inl.h"
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,41 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
d_vals, len);
}

// Copies every occupied entry of this hash table into the CPU-side feature
// value addressed by the entry's `cpu_ptr` field.
//
// devid:  device id passed to the final prefetch() once the copy is finished.
// stream: stream passed to both prefetch() calls.
//
// Layout written into each CPU value:
//   [0] = 0, [1] = delta_score, [2] = show, [3] = clk, [4] = lr,
//   [5] = lr_g2sum, [6] = slot, [7 .. 7 + mf_size) = mf[0 .. mf_size)
template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
  // Number of fixed (non-mf) floats at the front of a CPU value.
  constexpr int kFixedFieldNum = 7;

  // NOTE(review): if prefetch() is asynchronous with respect to the host, a
  // synchronization on `stream` may be required before the host reads below
  // -- confirm against the container implementation.
  container_->prefetch(cudaCpuDeviceId, stream);

  const size_t slot_count = container_->size();
  // Slots whose key equals this sentinel are skipped.
  const KeyType unused_key = std::numeric_limits<KeyType>::max();
  thrust::pair<KeyType, ValType>* entries = container_->data();

  for (size_t i = 0; i < slot_count; ++i) {
    if (entries[i].first == unused_key) {
      continue;
    }
    ValType& gpu_val = entries[i].second;
    auto* downpour_value =
        reinterpret_cast<paddle::ps::DownpourFixedFeatureValue*>(
            gpu_val.cpu_ptr);

    // Treat a non-positive mf_size as "no mf part".
    const int mf_len = gpu_val.mf_size > 0 ? gpu_val.mf_size : 0;
    const int required_size = kFixedFieldNum + mf_len;

    // Grow the CPU buffer whenever it cannot hold everything written below.
    // Previously it was grown only when its size was exactly 7, so any other
    // too-small size led to out-of-bounds writes. Buffers that are already
    // large enough are left untouched, as before.
    if (static_cast<int>(downpour_value->size()) < required_size) {
      downpour_value->resize(required_size);
    }

    // Fetch the data pointer only after the potential resize.
    float* cpu_val = downpour_value->data();
    cpu_val[0] = 0;
    cpu_val[1] = gpu_val.delta_score;
    cpu_val[2] = gpu_val.show;
    cpu_val[3] = gpu_val.clk;
    cpu_val[4] = gpu_val.lr;
    cpu_val[5] = gpu_val.lr_g2sum;
    cpu_val[6] = gpu_val.slot;
    for (int x = 0; x < mf_len; ++x) {
      cpu_val[kFixedFieldNum + x] = gpu_val.mf[x];
    }
  }

  container_->prefetch(devid, stream);
}

template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
Expand Down
9 changes: 7 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <thread>
#include <vector>
#include "cub/cub.cuh"
#include "hashtable.h"
#include "heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/place.h"
Expand Down Expand Up @@ -72,6 +73,10 @@ class HeterComm {
return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
}

// void dump_to_cpu(int index);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续可以把注释删掉


void end_pass();

// Maps `send_id` to (send_id + 4) % 8 and returns the result.
// NOTE(review): the constants presumably encode an 8-device layout split
// into two groups of 4 -- confirm against the topology this class targets.
int get_transfer_devid(int send_id) {
  const int shifted_id = send_id + 4;
  return shifted_id % 8;
}

struct Node {
Expand Down Expand Up @@ -110,5 +115,5 @@ class HeterComm {

} // end namespace framework
} // end namespace paddle
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,34 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
}
}

// Calls dump_to_cpu() on every table, using one worker thread per GPU index
// in [0, total_gpu). Returns only after all workers have been joined.
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::end_pass() {
  const int gpu_count = resource_->total_gpu();

  // Per-GPU job: look up the stream and device id for `gpu_index`, hold a
  // CUDADeviceGuard for that device, then dump that index's table.
  auto dump_table = [this](int gpu_index) {
    auto local_stream = resource_->local_stream(gpu_index, 0);
    int device = resource_->dev_id(gpu_index);
    platform::CUDADeviceGuard device_guard(device);
    tables_[gpu_index]->dump_to_cpu(device, local_stream);
  };

  std::vector<std::thread> workers;
  int gpu_index = 0;
  while (gpu_index < gpu_count) {
    workers.emplace_back(dump_table, gpu_index);
    ++gpu_index;
  }

  // Wait for every dump to finish before returning.
  for (std::thread& worker : workers) {
    worker.join();
  }
}

// template <typename KeyType, typename ValType, typename GradType>
// void HeterComm<KeyType, ValType, GradType>::dump_to_cpu(int index) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续可以把注释删掉

// auto stream = resource_->local_stream(index, 0);
// int dev_id = resource_->dev_id(index);
// platform::CUDADeviceGuard guard(dev_id);
// tables_[index]->dump_to_cpu(dev_id, stream);
//}

} // end namespace framework
} // end namespace paddle
#endif
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ int HeterPs::get_index_by_devid(int devid) {
return comm_->get_index_by_devid(devid);
}

void HeterPs::dump() {}
// Forwards the end-of-pass request to the underlying communicator.
void HeterPs::end_pass() {
  comm_->end_pass();
}

// Forwards to the communicator's show_one_table() for the given `gpu_num`.
void HeterPs::show_one_table(int gpu_num) {
  comm_->show_one_table(gpu_num);
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/heter_ps.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"

#ifdef PADDLE_WITH_PSLIB

Expand All @@ -35,7 +35,7 @@ class HeterPs : public HeterPsBase {
size_t len) override;
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) override;
virtual void dump() override;
virtual void end_pass() override;
virtual int get_index_by_devid(int devid) override;
virtual void show_one_table(int gpu_num) override;
virtual void push_sparse(int num, FeatureKey* d_keys,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class HeterPsBase {
virtual void build_ps(int num, FeatureKey* h_keys, FeatureValue* h_vals,
size_t len, size_t chunk_size, int stream_num) = 0;
virtual int get_index_by_devid(int devid) = 0;
virtual void dump() = 0;
virtual void end_pass() = 0;
virtual void show_one_table(int gpu_num) = 0;
virtual void push_sparse(int num, FeatureKey* d_keys,
FeaturePushValue* d_grads, size_t len) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <vector>
#include <curand_kernel.h>
#include <vector>
#include "optimizer_conf.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"

Expand Down Expand Up @@ -111,8 +111,8 @@ class Optimizer {
curandState state;
curand_init(clock64(), tid_x, 0, &state);
for (int i = 0; i < MF_DIM; ++i) {
val.mf[i + 1] = (curand_uniform(&state)) *
optimizer_config::mf_initial_range;
val.mf[i + 1] =
(curand_uniform(&state)) * optimizer_config::mf_initial_range;
}
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/test_comm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.h"
#include "paddle/fluid/framework/fleet/heter_ps/heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/platform/cuda_device_guard.h"

using namespace paddle::framework;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
val.slot = ptr_val[6];
val.lr = ptr_val[4];
val.lr_g2sum = ptr_val[5];
val.cpu_ptr = (uint64_t)(task_ptrs[dev][j]);

if (dim > 7) {
val.mf_size = MF_DIM + 1;
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ class PSGPUWrapper {
slot_vector_ = slot_vector;
}

// Forwards to HeterPs_->end_pass().
void EndPass() {
  HeterPs_->end_pass();
}
// Forwards to HeterPs_->show_one_table() with the given `index`.
void ShowOneTable(int index) {
  HeterPs_->show_one_table(index);
}

private:
static std::shared_ptr<PSGPUWrapper> s_instance_;
Dataset* dataset_;
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/ps_gpu_wrapper_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ void BindPSGPUWrapper(py::module* m) {
py::call_guard<py::gil_scoped_release>())
.def("init_gpu_ps", &framework::PSGPUWrapper::InitializeGPU,
py::call_guard<py::gil_scoped_release>())
.def("end_pass", &framework::PSGPUWrapper::EndPass,
py::call_guard<py::gil_scoped_release>())
.def("build_gpu_ps", &framework::PSGPUWrapper::BuildGPUPS,
py::call_guard<py::gil_scoped_release>());
} // end PSGPUWrapper
Expand Down