9 changes: 2 additions & 7 deletions paddle/fluid/framework/fleet/fleet_wrapper.cc
@@ -27,14 +27,9 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/io/fs.h"

#include "glog/logging.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/timer.h"

namespace paddle {
namespace framework {
54 changes: 53 additions & 1 deletion paddle/fluid/framework/fleet/heter_context.h
@@ -29,17 +29,69 @@ namespace framework {

class HeterContext {
public:
~HeterContext() {
for (size_t i = 0; i < mutex_.size(); ++i) {
delete mutex_[i];
}
mutex_.clear();
}
Scope* scope_{nullptr};
std::vector<std::vector<FeatureKey>> feature_keys_;
std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> value_ptr_;
std::vector<std::vector<FeatureValue>> feature_values_;
std::vector<std::vector<FeatureValue>> device_values_;
std::vector<std::vector<FeatureKey>> device_keys_;
std::vector<std::mutex*> mutex_;

uint32_t shard_num_ = 37;
uint64_t size() {
uint64_t total_size = 0;
for (auto& keys : feature_keys_) {
total_size += keys.size();
}
return total_size;
}
void SetShardNum(uint32_t shard_num) { shard_num_ = shard_num; }
uint32_t ShardNum() { return shard_num_; }
void init(int shard_num, int device_num) {
shard_num_ = shard_num;
feature_keys_.resize(shard_num_);
value_ptr_.resize(shard_num_);

device_values_.resize(device_num);
device_keys_.resize(device_num);
mutex_.resize(device_num);
for (size_t i = 0; i < mutex_.size(); ++i) {
mutex_[i] = new std::mutex();
}
}
void batch_add_keys(const std::vector<std::vector<uint64_t>>& thread_keys) {
assert(thread_keys.size() == feature_keys_.size());

for (uint32_t i = 0; i < shard_num_; i++) {
size_t idx = feature_keys_[i].size();
feature_keys_[i].resize(feature_keys_[i].size() + thread_keys[i].size());
for (uint64_t j = 0; j < thread_keys[i].size(); j++) {
feature_keys_[i][idx + j] = thread_keys[i][j];
}
}
}
void UniqueKeys() {
std::vector<std::thread> threads;
auto unique_func = [this](int i) {
auto& cur_keys = feature_keys_[i];
std::sort(cur_keys.begin(), cur_keys.end());
auto it = std::unique(cur_keys.begin(), cur_keys.end());
cur_keys.resize(std::distance(cur_keys.begin(), it));
};
for (uint32_t i = 0; i < shard_num_; i++) {
threads.push_back(std::thread(unique_func, i));
}
for (std::thread& t : threads) {
t.join();
}
}
};

} // end namespace framework
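For reference, a minimal standalone sketch of the shard-then-dedup flow that batch_add_keys and UniqueKeys implement, with plain uint64_t vectors standing in for the FeatureKey shards (the shard count and keys here are made up for illustration):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  const int shard_num = 4;
  // Per-shard key lists gathered by the feed threads (duplicates allowed).
  std::vector<std::vector<uint64_t>> thread_keys = {
      {8, 4, 8}, {1, 5, 1}, {2, 6, 2}, {3, 7, 3}};

  // batch_add_keys: append each shard's new keys at the old tail.
  std::vector<std::vector<uint64_t>> feature_keys(shard_num);
  for (int i = 0; i < shard_num; ++i) {
    size_t idx = feature_keys[i].size();
    feature_keys[i].resize(idx + thread_keys[i].size());
    for (size_t j = 0; j < thread_keys[i].size(); ++j)
      feature_keys[i][idx + j] = thread_keys[i][j];
  }

  // UniqueKeys: one thread per shard sorts and deduplicates in place.
  std::vector<std::thread> threads;
  for (int i = 0; i < shard_num; ++i) {
    threads.emplace_back([&feature_keys, i] {
      auto& keys = feature_keys[i];
      std::sort(keys.begin(), keys.end());
      keys.erase(std::unique(keys.begin(), keys.end()), keys.end());
    });
  }
  for (auto& t : threads) t.join();

  for (int i = 0; i < shard_num; ++i)
    std::cout << "shard " << i << ": " << feature_keys[i].size() << " keys\n";
}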
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/heter_ps/feature_value.h
@@ -33,6 +33,7 @@ struct FeatureValue {
float lr_g2sum;
int mf_size;
float mf[MF_DIM + 1];
uint64_t cpu_ptr;

friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
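The new cpu_ptr field stashes the host-side pointer to the matching paddle::ps::DownpourFixedFeatureValue in the GPU value as a 64-bit integer, which is what lets dump_to_cpu (below) write results back. A minimal sketch of that pointer round-trip, with a made-up CpuValue standing in for the Downpour value type:

#include <cstdint>
#include <vector>

struct CpuValue {  // stand-in for paddle::ps::DownpourFixedFeatureValue
  std::vector<float> data;
};

int main() {
  CpuValue host_val;
  // Building the GPU table: store the host pointer as an integer field.
  uint64_t cpu_ptr = reinterpret_cast<uint64_t>(&host_val);
  // Dumping back to the CPU: recover the pointer and write through it.
  auto* recovered = reinterpret_cast<CpuValue*>(cpu_ptr);
  recovered->data.push_back(1.0f);
  return recovered == &host_val ? 0 : 1;  // 0: the round-trip preserved it
}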
5 changes: 4 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/hashtable.h
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <glog/logging.h>
#include <limits>
#include <memory>
#include <vector>
#include "common_value.h" // NOLINT
#include "thrust/pair.h"
//#include "cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
@@ -47,6 +49,7 @@ class HashTable {
void get(const KeyType* d_keys, ValType* d_vals, size_t len,
cudaStream_t stream);
void show();
void dump_to_cpu(int devid, cudaStream_t stream);

template <typename GradType, typename Sgd>
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
@@ -60,5 +63,5 @@
};
} // end namespace framework
} // end namespace paddle
#include "hashtable.tpp"
#include "hashtable_inl.h"
#endif
paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h (renamed from hashtable.tpp)
@@ -108,6 +108,41 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
d_vals, len);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
container_->prefetch(cudaCpuDeviceId, stream);
size_t num = container_->size();
KeyType unuse_key = std::numeric_limits<KeyType>::max();
thrust::pair<KeyType, ValType>* kv = container_->data();
for (size_t i = 0; i < num; ++i) {
if (kv[i].first == unuse_key) {
continue;
}
ValType& gpu_val = kv[i].second;
auto* downpour_value =
reinterpret_cast<paddle::ps::DownpourFixedFeatureValue*>(gpu_val.cpu_ptr);
int downpour_value_size = downpour_value->size();
if (gpu_val.mf_size > 0 && downpour_value_size == 7) {
downpour_value->resize(gpu_val.mf_size + downpour_value_size);
}
float* cpu_val = downpour_value->data();
cpu_val[0] = 0;
cpu_val[1] = gpu_val.delta_score;
cpu_val[2] = gpu_val.show;
cpu_val[3] = gpu_val.clk;
cpu_val[4] = gpu_val.lr;
cpu_val[5] = gpu_val.lr_g2sum;
cpu_val[6] = gpu_val.slot;
if (gpu_val.mf_size > 0) {
for (int x = 0; x < gpu_val.mf_size; x++) {
cpu_val[x + 7] = gpu_val.mf[x];
}
}
}

container_->prefetch(devid, stream);
}

template <typename KeyType, typename ValType>
template <typename GradType, typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
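A standalone sketch of the flat CPU layout dump_to_cpu writes, with the offsets taken from the code above (seven dense stats, then the mf embedding starting at offset 7); GpuVal and dump_one are illustrative stand-ins, not Paddle APIs:

#include <cstdio>
#include <vector>

struct GpuVal {  // mirrors the FeatureValue fields used by dump_to_cpu
  float delta_score, show, clk, lr, lr_g2sum, slot;
  int mf_size;
  float mf[8 + 1];  // 8 stands in for MF_DIM
};

void dump_one(const GpuVal& g, std::vector<float>* cpu) {
  // Grow the CPU value only if it still holds just the 7 dense stats.
  if (g.mf_size > 0 && cpu->size() == 7) cpu->resize(7 + g.mf_size);
  float* v = cpu->data();
  v[0] = 0;  // zeroed on dump; its meaning lives on the CPU side
  v[1] = g.delta_score;
  v[2] = g.show;
  v[3] = g.clk;
  v[4] = g.lr;
  v[5] = g.lr_g2sum;
  v[6] = g.slot;
  for (int x = 0; x < g.mf_size; ++x) v[7 + x] = g.mf[x];
}

int main() {
  GpuVal g{0.1f, 5.0f, 1.0f, 0.05f, 0.9f, 3.0f, 2, {0.3f, 0.4f}};
  std::vector<float> cpu(7);
  dump_one(g, &cpu);
  std::printf("cpu value now holds %zu floats\n", cpu.size());  // 9
}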
77 changes: 75 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm.h
@@ -13,13 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <thread>
#include <vector>
#include "cub/cub.cuh"
#include "hashtable.h"
#include "heter_resource.h"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh"
#include "paddle/fluid/framework/fleet/heter_ps/optimizer.cuh.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/place.h"
#include "thrust/pair.h"

@@ -67,11 +69,38 @@ class HeterComm {
void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len,
Sgd& sgd);

template <typename Sgd>
void push_sparse_multi_node(int num, KeyType* d_keys, GradType* d_grads,
size_t len, Sgd& sgd);

template <typename Sgd>
void update_one_table(int num, KeyType* d_keys, GradType* d_grads, size_t len,
Sgd& sgd);

int gather_one_node_grad(int num, KeyType* d_keys, GradType* d_grads,
int len);

int gather_multi_node_grad(int num, KeyType* d_keys, GradType* d_grads,
int len);

int log2i(int x);

void set_nccl_comm_and_size(const std::vector<ncclComm_t>& inner_comms,
const std::vector<ncclComm_t>& inter_comms,
int comm_size) {
nccl_inner_comms_ = inner_comms;
nccl_inter_comms_ = inter_comms;
node_size_ = comm_size;
}

bool need_transfer(int send_id, int receive_id) {
return ((send_id / 4 != receive_id / 4) && (send_id + 4) % 8 != receive_id);
}

// void dump_to_cpu(int index);

void end_pass();

int get_transfer_devid(int send_id) { return (send_id + 4) % 8; }
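// Sketch of the topology need_transfer and get_transfer_devid encode
// (inferred from the arithmetic; the PR does not state it): GPUs 0-3 and
// 4-7 form two directly connected groups, and GPU i's only cross-group
// link is its peer (i + 4) % 8. So 0 -> 2 and 0 -> 4 go direct, while
// 0 -> 5 must be staged through 0's peer: need_transfer(0, 5) is true
// and get_transfer_devid(0) returns 4.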

struct Node {
@@ -89,6 +118,44 @@ class HeterComm {
std::vector<Node> nodes_;
};

struct LocalStorage {
LocalStorage() {}
void init(int size, int dev_id) {
place_ = platform::CUDAPlace(dev_id);
alloc(size, true);
}

void alloc(int size, bool force = false) {
if (force || size > all_keys_mem->size()) {
all_keys_mem.reset();
all_grads_mem.reset();
all_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
all_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
all_keys = reinterpret_cast<KeyType*>(all_keys_mem->ptr());
all_grads = reinterpret_cast<GradType*>(all_grads_mem->ptr());
}
if (force || size > local_keys_mem->size()) {
local_keys_mem.reset();
local_grads_mem.reset();
local_keys_mem = memory::AllocShared(place_, size * sizeof(KeyType));
local_grads_mem = memory::AllocShared(place_, size * sizeof(GradType));
local_keys = reinterpret_cast<KeyType*>(local_keys_mem->ptr());
local_grads = reinterpret_cast<GradType*>(local_grads_mem->ptr());
}
}

platform::CUDAPlace place_;
std::shared_ptr<memory::Allocation> all_keys_mem;
std::shared_ptr<memory::Allocation> all_grads_mem;
KeyType* all_keys;
GradType* all_grads;

std::shared_ptr<memory::Allocation> local_keys_mem;
std::shared_ptr<memory::Allocation> local_grads_mem;
KeyType* local_keys;
GradType* local_grads;
};

void init_path();
void create_storage(
int start_index, int end_index, int keylen, int vallen,
@@ -106,9 +173,15 @@
CustomGradMerger merger_;
int topo_aware_{1};
std::vector<std::vector<Path>> path_;
std::vector<LocalStorage> storage_;
int feanum_{1800 * 2048};
int multi_node_{1};
std::vector<ncclComm_t> nccl_inner_comms_;
std::vector<ncclComm_t> nccl_inter_comms_;
int node_size_;
};

} // end namespace framework
} // end namespace paddle
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm.tpp"
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h"
#endif
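As a usage note, LocalStorage::alloc implements a grow-only buffer policy: it reallocates only when forced (at init) or when a request exceeds the current allocation, so repeated passes with smaller batches reuse the same memory. A standalone sketch of that policy, with std::malloc standing in for memory::AllocShared:

#include <cstddef>
#include <cstdio>
#include <cstdlib>

struct GrowOnlyBuffer {  // illustrative stand-in for one LocalStorage buffer
  void* mem = nullptr;
  size_t cap = 0;

  void alloc(size_t size, bool force = false) {
    if (force || size > cap) {  // same test as LocalStorage::alloc
      std::free(mem);
      mem = std::malloc(size);
      cap = size;
    }
  }
  ~GrowOnlyBuffer() { std::free(mem); }
};

int main() {
  GrowOnlyBuffer buf;
  buf.alloc(1024, /*force=*/true);  // init: always allocate
  buf.alloc(512);                   // fits: no reallocation
  buf.alloc(4096);                  // grows: reallocate
  std::printf("capacity: %zu bytes\n", buf.cap);  // 4096
}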