
Commit b01e178

Merge pull request #10 from danleifeng/mul_dims
support dynamic mf dims
2 parents f24c3e9 + e6cfbde commit b01e178

30 files changed: +1303 / -446 lines

paddle/fluid/framework/data_set.cc

Lines changed: 22 additions & 5 deletions
@@ -121,6 +121,24 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
                                                 &data_feed_desc_);
 }
 
+template <typename T>
+std::vector<std::string> DatasetImpl<T>::GetSlots() {
+  auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
+  use_slots_.clear();
+  for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
+    const auto& slot = multi_slot_desc.slots(i);
+    if (slot.type() == "uint64" || slot.type() == "uint32") {
+      use_slots_.push_back(slot.name());
+    }
+  }
+  std::cout << "dataset use slots: ";
+  for (auto s : use_slots_) {
+    std::cout << s << " | ";
+  }
+  std::cout << " end " << std::endl;
+  return use_slots_;
+}
+
 template <typename T>
 void DatasetImpl<T>::SetChannelNum(int channel_num) {
   channel_num_ = channel_num;
@@ -303,12 +321,11 @@ static int compute_thread_batch_nccl(
   thread_avg_batch_num = static_cast<int>(offset.size() / thr_num);
 #ifdef PADDLE_WITH_GLOO
   auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
-  if (!gloo_wrapper->IsInitialized()) {
-    VLOG(0) << "GLOO is not inited";
-    gloo_wrapper->Init();
-  }
-
   if (gloo_wrapper->Size() > 1) {
+    if (!gloo_wrapper->IsInitialized()) {
+      VLOG(0) << "GLOO is not inited";
+      gloo_wrapper->Init();
+    }
     // adjust batch num per thread for NCCL
     std::vector<int> thread_avg_batch_num_vec(1, thread_avg_batch_num);
     std::vector<int64_t> total_instance_num_vec(1, total_instance_num);
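
The new GetSlots() walks the MultiSlotDesc of the data feed and keeps only the sparse slot names (type "uint64" or "uint32"); dense float slots are dropped. Below is a minimal sketch of the same filtering, using plain structs instead of the real MultiSlotDesc protobuf; the slot names and types are invented for illustration.

// Sketch only: mimics the GetSlots() filtering logic with plain structs
// instead of the real data_feed protobuf types. Slot names are invented.
#include <iostream>
#include <string>
#include <vector>

struct SlotDesc {
  std::string name;
  std::string type;  // "uint64", "uint32", or "float"
};

std::vector<std::string> FilterSparseSlots(const std::vector<SlotDesc>& slots) {
  std::vector<std::string> use_slots;
  for (const auto& slot : slots) {
    if (slot.type == "uint64" || slot.type == "uint32") {
      use_slots.push_back(slot.name);  // keep only sparse id slots
    }
  }
  return use_slots;
}

int main() {
  std::vector<SlotDesc> slots = {
      {"click", "float"}, {"6048", "uint64"}, {"6002", "uint64"}};
  for (const auto& s : FilterSparseSlots(slots)) std::cout << s << " | ";
  std::cout << " end" << std::endl;  // prints: 6048 | 6002 |  end
  return 0;
}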

paddle/fluid/framework/data_set.h

Lines changed: 5 additions & 1 deletion
@@ -152,7 +152,7 @@ class Dataset {
   virtual void DestroyPreLoadReaders() = 0;
   // set preload thread num
   virtual void SetPreLoadThreadNum(int thread_num) = 0;
-  // separate train thread and dataset thread
+  // seperate train thread and dataset thread
   virtual void DynamicAdjustChannelNum(int channel_num,
                                        bool discard_remaining_ins = false) = 0;
   virtual void DynamicAdjustReadersNum(int thread_num) = 0;
@@ -161,6 +161,8 @@ class Dataset {
   virtual void SetGraphDeviceKeys(
       const std::vector<int64_t>& h_device_keys) = 0;
 
+  virtual std::vector<std::string> GetSlots() = 0;
+
  protected:
   virtual int ReceiveFromClient(int msg_type, int client_id,
                                 const std::string& msg) = 0;
@@ -249,6 +251,7 @@ class DatasetImpl : public Dataset {
                                        bool discard_remaining_ins = false);
   virtual void DynamicAdjustReadersNum(int thread_num);
   virtual void SetFleetSendSleepSeconds(int seconds);
+  virtual std::vector<std::string> GetSlots();
   /* for enable_heterps_
   virtual void EnableHeterps(bool enable_heterps) {
     enable_heterps_ = enable_heterps;
@@ -324,6 +327,7 @@ class DatasetImpl : public Dataset {
   int64_t global_index_ = 0;
   std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
   std::vector<T> input_records_;  // only for paddleboxdatafeed
+  std::vector<std::string> use_slots_;
   bool enable_heterps_ = false;
   int gpu_graph_mode_ = 1;
   std::vector<std::vector<int64_t>> gpu_graph_device_keys_;

paddle/fluid/framework/fleet/fleet_wrapper.cc

Lines changed: 5 additions & 4 deletions
@@ -69,7 +69,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
                               int node_num, int index) {
 #ifdef PADDLE_WITH_PSLIB
   if (!is_initialized_) {
-    VLOG(3) << "Going to init worker";
+    VLOG(0) << "Going to init worker";
     pslib_ptr_ = std::shared_ptr<paddle::distributed::PSlib>(
         new paddle::distributed::PSlib());
     pslib_ptr_->init_worker(dist_desc,
@@ -126,7 +126,7 @@ void FleetWrapper::GatherServers(const std::vector<uint64_t>& host_sign_list,
 
 void FleetWrapper::GatherClients(const std::vector<uint64_t>& host_sign_list) {
 #ifdef PADDLE_WITH_PSLIB
-  VLOG(3) << "Going to gather client ips";
+  VLOG(0) << "Going to gather client ips";
   size_t len = host_sign_list.size();
   pslib_ptr_->gather_clients(const_cast<uint64_t*>(host_sign_list.data()), len);
 #endif
@@ -142,7 +142,7 @@ std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
 
 void FleetWrapper::CreateClient2ClientConnection() {
 #ifdef PADDLE_WITH_PSLIB
-  VLOG(3) << "Going to create client2client connection";
+  VLOG(0) << "Going to create client2client connection";
   pslib_ptr_->create_client2client_connection(client2client_request_timeout_ms_,
                                               client2client_connect_timeout_ms_,
                                               client2client_max_retry_);
@@ -1054,7 +1054,8 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
   int slot_offset = 0;
   int grad_dim = 0;
   // don't worry, user do not have to care about all these flags
-  if (accesor == "DownpourCtrAccessor") {
+  if (accesor == "DownpourCtrAccessor" ||
+      accesor == "DownpourCtrDymfAccessor") {
     dump_slot = true;
     slot_offset = 1;
     grad_dim = fea_dim - 2;

paddle/fluid/framework/fleet/heter_context.h

Lines changed: 0 additions & 23 deletions
@@ -95,24 +95,6 @@ class HeterContext {
   }
   void SetShardNum(uint32_t shard_num) { shard_num_ = shard_num; }
   uint32_t ShardNum() { return shard_num_; }
-  void init(int shard_num, int device_num) {
-    shard_num_ = shard_num;
-    feature_keys_.resize(shard_num_);
-    value_ptr_.resize(shard_num_);
-    device_task_ptr_.resize(shard_num_);
-    device_task_keys_.resize(shard_num_);
-    for (size_t i = 0; i < device_task_ptr_.size(); i++) {
-      device_task_ptr_[i].resize(device_num);
-      device_task_keys_[i].resize(device_num);
-    }
-
-    device_values_.resize(device_num);
-    device_keys_.resize(device_num);
-    mutex_.resize(device_num);
-    for (size_t i = 0; i < mutex_.size(); ++i) {
-      mutex_[i] = new std::mutex();
-    }
-  }
 
   void init(int shard_num, int device_num, int dim_num) {
     shard_num_ = shard_num;
@@ -129,11 +111,6 @@ class HeterContext {
     for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
       feature_dim_keys_[i].resize(dim_num);
       value_dim_ptr_[i].resize(dim_num);
-      if (i == 0) {
-        for (int j = 0; j < dim_num; j++) {
-          feature_dim_keys_[i][j].push_back(0);
-        }
-      }
     }
     device_values_.resize(device_num);
     device_dim_values_.resize(device_num);

paddle/fluid/framework/fleet/heter_ps/feature_value.h

Lines changed: 45 additions & 16 deletions
@@ -32,38 +32,67 @@ struct FeatureValue {
   float lr;
   float lr_g2sum;
   int mf_size;
-  float mf[MF_DIM + 1];
+  int mf_dim;
   uint64_t cpu_ptr;
+  float mf[0];
 
   friend std::ostream& operator<<(std::ostream& out, FeatureValue& val) {
     out << "show: " << val.show << " clk: " << val.clk << " slot: " << val.slot
-        << " lr: " << val.lr << " mf_size: " << val.mf_size << " mf:";
-    for (int i = 0; i < val.mf_size; ++i) {
+        << " lr: " << val.lr << " mf_dim: " << val.mf_dim
+        << "cpuptr: " << val.cpu_ptr << " mf_size: " << val.mf_size << " mf:";
+    for (int i = 0; i < val.mf_dim + 1; ++i) {
       out << " " << val.mf[i];
     }
     return out;
   }
+  __device__ __forceinline__ void operator=(const FeatureValue& in) {
+    delta_score = in.delta_score;
+    show = in.show;
+    clk = in.clk;
+    slot = in.slot;
+    lr = in.lr;
+    lr_g2sum = in.lr_g2sum;
+    mf_size = in.mf_size;
+    mf_dim = in.mf_dim;
+    cpu_ptr = in.cpu_ptr;
+    for (int i = 0; i < mf_dim + 1; i++) {
+      mf[i] = in.mf[i];
+    }
+  }
 };
 
 struct FeaturePushValue {
   float show;
   float clk;
   int slot;
   float lr_g;
-  float mf_g[MF_DIM];
+  int mf_dim;
+  float mf_g[0];
 
-  // __device__ __forceinline__ FeaturePushValue
-  // operator+(const FeaturePushValue& a) const {
-  //   FeaturePushValue out;
-  //   out.slot = a.slot;
-  //   out.show = a.show + show;
-  //   out.clk = a.clk + clk;
-  //   out.lr_g = a.lr_g + lr_g;
-  //   for (int i = 0; i < MF_DIM; ++i) {
-  //     out.mf_g[i] = a.mf_g[i] + mf_g[i];
-  //   }
-  //   return out;
-  // }
+  __device__ __forceinline__ FeaturePushValue
+  operator+(const FeaturePushValue& a) const {
+    FeaturePushValue out;
+    out.slot = a.slot;
+    out.mf_dim = a.mf_dim;
+    out.show = a.show + show;
+    out.clk = a.clk + clk;
+    out.lr_g = a.lr_g + lr_g;
+    // out.mf_g = a.mf_g;
+    for (int i = 0; i < out.mf_dim; ++i) {
+      out.mf_g[i] = a.mf_g[i] + mf_g[i];
+    }
+    return out;
+  }
+  __device__ __forceinline__ void operator=(const FeaturePushValue& in) {
+    show = in.show;
+    clk = in.clk;
+    slot = in.slot;
+    lr_g = in.lr_g;
+    mf_dim = in.mf_dim;
+    for (int i = 0; i < mf_dim; i++) {
+      mf_g[i] = in.mf_g[i];
+    }
+  }
 };
 
 }  // end namespace framework
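
Replacing the fixed float mf[MF_DIM + 1] array with an explicit mf_dim plus a zero-length trailing array (float mf[0]) is what makes the embedding width per-feature dynamic: a value no longer has a compile-time size, so each one occupies the fixed header plus (mf_dim + 1) floats inside a raw byte pool, and copies have to be done field by field (hence the explicit operator= overloads above). A minimal layout sketch, assuming the extra float slot holds the mf optimizer state; the struct and helper names are illustrative, not from the PR.

// Sketch only: how a variable-width feature value can be laid out in a raw
// byte pool. Struct and helper names are illustrative, not from the PR.
#include <cstddef>
#include <cstdint>
#include <vector>

struct FeatureValueSketch {
  float delta_score, show, clk;
  int slot;
  float lr, lr_g2sum;
  int mf_size;
  int mf_dim;        // per-feature embedding width
  uint64_t cpu_ptr;
  float mf[1];       // stands in for the zero-length trailing array mf[0]
};

// Bytes one value occupies: fixed header plus (mf_dim + 1) floats
// (the extra slot is assumed to hold the mf optimizer state).
size_t ValueBytes(int mf_dim) {
  return offsetof(FeatureValueSketch, mf) + (mf_dim + 1) * sizeof(float);
}

int main() {
  const int mf_dim = 64;
  const size_t stride = ValueBytes(mf_dim);
  std::vector<char> pool(10 * stride, 0);  // room for 10 values
  // Value i starts at pool.data() + i * stride; this stride plays the role of
  // the feature_value_size passed to the hashtable insert below.
  auto* v = reinterpret_cast<FeatureValueSketch*>(pool.data() + 3 * stride);
  v->mf_dim = mf_dim;
  return 0;
}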

paddle/fluid/framework/fleet/heter_ps/hashtable.h

Lines changed: 2 additions & 2 deletions
@@ -118,8 +118,8 @@ class HashTable {
               StreamType stream);
 
   template <typename StreamType>
-  void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index,
-              StreamType stream);
+  void insert(const KeyType* d_keys, size_t len, char* pool,
+              size_t feature_value_size, size_t start_index, StreamType stream);
 
   template <typename StreamType>
   void get(const KeyType* d_keys, ValType* d_vals, size_t len,

paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu

Lines changed: 47 additions & 13 deletions
@@ -50,15 +50,17 @@ __global__ void insert_kernel(Table* table,
 template <typename Table>
 __global__ void insert_kernel(Table* table,
                               const typename Table::key_type* const keys,
-                              size_t len, char* pool, int start_index) {
+                              size_t len, char* pool, size_t feature_value_size,
+                              int start_index) {
   ReplaceOp<typename Table::mapped_type> op;
   thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;
 
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
 
   if (i < len) {
     kv.first = keys[i];
-    kv.second = (Table::mapped_type)(pool + (start_index + i) * 80);
+    uint64_t offset = uint64_t(start_index + i) * feature_value_size;
+    kv.second = (Table::mapped_type)(pool + offset);
     auto it = table->insert(kv, op);
     assert(it != table->end() && "error: insert fails: table is full");
   }
@@ -81,14 +83,29 @@ __global__ void search_kernel(Table* table,
 template <typename Table>
 __global__ void dy_mf_search_kernel(Table* table,
                                     const typename Table::key_type* const keys,
-                                    char* const vals, size_t len,
+                                    char* vals, size_t len,
                                     size_t pull_feature_value_size) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  // return;
   if (i < len) {
     auto it = table->find(keys[i]);
 
     if (it != table->end()) {
-      *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
+      uint64_t offset = i * pull_feature_value_size;
+      FeatureValue* cur = (FeatureValue*)(vals + offset);
+      FeatureValue& input = *(FeatureValue*)(it->second);
+      cur->slot = input.slot;
+      cur->show = input.show;
+      cur->clk = input.clk;
+      cur->mf_dim = input.mf_dim;
+      cur->lr = input.lr;
+      cur->mf_size = input.mf_size;
+      cur->cpu_ptr = input.cpu_ptr;
+      cur->delta_score = input.delta_score;
+      cur->lr_g2sum = input.lr_g2sum;
+      for (int j = 0; j < cur->mf_dim + 1; ++j) {
+        cur->mf[j] = input.mf[j];
+      }
     }
   }
 }
@@ -121,7 +138,7 @@ __global__ void dy_mf_update_kernel(Table* table,
       FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
       sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
     } else {
-      printf("yxf::push miss key: %d", keys[i]);
+      printf("warning: push miss key: %d", keys[i]);
     }
   }
 }
@@ -201,7 +218,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
 template <typename KeyType, typename ValType>
 template <typename StreamType>
 void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
-                                         char* pool, size_t start_index,
+                                         char* pool, size_t feature_value_size,
+                                         size_t start_index,
                                          StreamType stream) {
   if (len == 0) {
     return;
@@ -210,8 +228,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
     return;
   }
   const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
-  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
-                                                       pool, start_index);
+  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
+      container_, d_keys, len, pool, feature_value_size, start_index);
 }
 
 template <typename KeyType, typename ValType>
@@ -319,10 +337,12 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
 }
 
 template class HashTable<unsigned long, paddle::framework::FeatureValue>;
+template class HashTable<unsigned long, paddle::framework::FeatureValue*>;
 template class HashTable<long, int>;
 template class HashTable<unsigned long, int>;
 template class HashTable<unsigned long, unsigned long>;
 template class HashTable<unsigned long, long>;
+template class HashTable<unsigned long, long*>;
 template class HashTable<long, long>;
 template class HashTable<long, unsigned long>;
 template class HashTable<long, unsigned int>;
@@ -332,6 +352,10 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
     paddle::framework::FeatureValue* d_vals, size_t len,
     cudaStream_t stream);
 
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::get<cudaStream_t>(
+    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream);
+
 template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                       int* d_vals, size_t len,
                                                       cudaStream_t stream);
@@ -357,6 +381,11 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
    const paddle::framework::FeatureValue* d_vals, size_t len,
    cudaStream_t stream);
 
+template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
+    insert<cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
+                         size_t feature_value_size, size_t start_index,
+                         cudaStream_t stream);
+
 template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                          const int* d_vals,
                                                          size_t len,
@@ -382,11 +411,6 @@ template void HashTable<long, unsigned int>::insert<cudaStream_t>(
     const long* d_keys, const unsigned int* d_vals, size_t len,
     cudaStream_t stream);
 
-// template void HashTable<unsigned long,
-// paddle::framework::FeatureValue>::insert<
-//     cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
-//     size_t start_index, cudaStream_t stream);
-
 template void HashTable<unsigned long, paddle::framework::FeatureValue>::
     dump_to_cpu<cudaStream_t>(int devid, cudaStream_t stream);
 
@@ -401,6 +425,16 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
     sgd,
     cudaStream_t stream);
 
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::update<
+    Optimizer<paddle::framework::FeatureValue,
+              paddle::framework::FeaturePushValue>,
+    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t len,
+                  Optimizer<paddle::framework::FeatureValue,
+                            paddle::framework::FeaturePushValue>
+                      sgd,
+                  cudaStream_t stream);
+
 // template void HashTable<unsigned long,
 // paddle::framework::FeatureValue>::update<
 //     Optimizer<paddle::framework::FeatureValue,
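
The extra feature_value_size argument threaded through HashTable::insert and insert_kernel replaces the hard-coded 80-byte stride, so the byte offset of each value in the pool now follows the actual (variable) value footprint, presumably sized from the largest mf_dim in use. A minimal sketch of that offset arithmetic; the stride value below is made up for illustration.

// Sketch only: the pool-offset arithmetic used by insert_kernel, with an
// illustrative feature_value_size. Real code derives the size from mf_dim.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  size_t feature_value_size = 304;  // e.g. fixed header + (64 + 1) floats, made up
  size_t start_index = 1000;
  size_t i = 7;  // thread index within the batch
  // Old code used offset = (start_index + i) * 80, i.e. one fixed 80-byte slot per value.
  uint64_t offset = uint64_t(start_index + i) * feature_value_size;
  std::printf("value %zu lives at pool + %llu bytes\n", i,
              static_cast<unsigned long long>(offset));
  return 0;
}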
