@@ -50,15 +50,17 @@ __global__ void insert_kernel(Table* table,
 template <typename Table>
 __global__ void insert_kernel(Table* table,
                               const typename Table::key_type* const keys,
-                              size_t len, char* pool, int start_index) {
+                              size_t len, char* pool, size_t feature_value_size,
+                              int start_index) {
   ReplaceOp<typename Table::mapped_type> op;
   thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;
 
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
 
   if (i < len) {
     kv.first = keys[i];
-    kv.second = (Table::mapped_type)(pool + (start_index + i) * 80);
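+    // Each key's value occupies a feature_value_size-byte slot in the
+    // preallocated pool; compute the byte offset in 64-bit to avoid overflow.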
+    uint64_t offset = uint64_t(start_index + i) * feature_value_size;
+    kv.second = (Table::mapped_type)(pool + offset);
     auto it = table->insert(kv, op);
     assert(it != table->end() && "error: insert fails: table is full");
   }
@@ -81,14 +83,29 @@ __global__ void search_kernel(Table* table,
 template <typename Table>
 __global__ void dy_mf_search_kernel(Table* table,
                                     const typename Table::key_type* const keys,
-                                    char* const vals, size_t len,
+                                    char* vals, size_t len,
                                     size_t pull_feature_value_size) {
   const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
   if (i < len) {
     auto it = table->find(keys[i]);
 
     if (it != table->end()) {
-      *(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
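+      // Copy the entry field by field: the mf array holds mf_dim + 1 floats,
+      // so the amount to copy is only known at runtime.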
+      uint64_t offset = i * pull_feature_value_size;
+      FeatureValue* cur = (FeatureValue*)(vals + offset);
+      FeatureValue& input = *(FeatureValue*)(it->second);
+      cur->slot = input.slot;
+      cur->show = input.show;
+      cur->clk = input.clk;
+      cur->mf_dim = input.mf_dim;
+      cur->lr = input.lr;
+      cur->mf_size = input.mf_size;
+      cur->cpu_ptr = input.cpu_ptr;
+      cur->delta_score = input.delta_score;
+      cur->lr_g2sum = input.lr_g2sum;
+      for (int j = 0; j < cur->mf_dim + 1; ++j) {
+        cur->mf[j] = input.mf[j];
+      }
     }
   }
 }
@@ -121,7 +138,7 @@ __global__ void dy_mf_update_kernel(Table* table,
       FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
       sgd.dy_mf_update_value(optimizer_config, (it.getter())->second, *cur);
     } else {
-      printf("yxf:: push miss key: %d", keys[i]);
+      printf("warning: push miss key: %lu", keys[i]);
     }
   }
 }
@@ -201,7 +218,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
 template <typename KeyType, typename ValType>
 template <typename StreamType>
 void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
-                                         char* pool, size_t start_index,
+                                         char* pool, size_t feature_value_size,
+                                         size_t start_index,
                                          StreamType stream) {
   if (len == 0) {
     return;
@@ -210,8 +228,8 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
     return;
   }
   const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
-  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
-                                                       pool, start_index);
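+  // One thread per key: thread i binds keys[i] to its slot in the value pool.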
+  insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
+      container_, d_keys, len, pool, feature_value_size, start_index);
 }
 
 template <typename KeyType, typename ValType>
@@ -319,10 +337,12 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
 }
 
 template class HashTable<unsigned long, paddle::framework::FeatureValue>;
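+// The pointer-valued table maps each key to a slot in an external device value
+// pool (see insert_kernel above) instead of storing the value inline.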
+template class HashTable<unsigned long, paddle::framework::FeatureValue*>;
 template class HashTable<long, int>;
 template class HashTable<unsigned long, int>;
 template class HashTable<unsigned long, unsigned long>;
 template class HashTable<unsigned long, long>;
+template class HashTable<unsigned long, long*>;
 template class HashTable<long, long>;
 template class HashTable<long, unsigned long>;
 template class HashTable<long, unsigned int>;
@@ -332,6 +352,10 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::get<
                   paddle::framework::FeatureValue* d_vals, size_t len,
                   cudaStream_t stream);
 
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::get<cudaStream_t>(
+    const unsigned long* d_keys, char* d_vals, size_t len, cudaStream_t stream);
+
 template void HashTable<long, int>::get<cudaStream_t>(const long* d_keys,
                                                        int* d_vals, size_t len,
                                                        cudaStream_t stream);
@@ -357,6 +381,11 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::insert<
                   const paddle::framework::FeatureValue* d_vals, size_t len,
                   cudaStream_t stream);
 
+template void HashTable<unsigned long, paddle::framework::FeatureValue*>::
+    insert<cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
+                         size_t feature_value_size, size_t start_index,
+                         cudaStream_t stream);
+
 template void HashTable<long, int>::insert<cudaStream_t>(const long* d_keys,
                                                           const int* d_vals,
                                                           size_t len,
@@ -382,11 +411,6 @@ template void HashTable<long, unsigned int>::insert<cudaStream_t>(
     const long* d_keys, const unsigned int* d_vals, size_t len,
     cudaStream_t stream);
 
-// template void HashTable<unsigned long,
-// paddle::framework::FeatureValue>::insert<
-//     cudaStream_t>(const unsigned long* d_keys, size_t len, char* pool,
-//                   size_t start_index, cudaStream_t stream);
-
 template void HashTable<unsigned long, paddle::framework::FeatureValue>::
     dump_to_cpu<cudaStream_t>(int devid, cudaStream_t stream);
 
@@ -401,6 +425,16 @@ template void HashTable<unsigned long, paddle::framework::FeatureValue>::update<
                       sgd,
                   cudaStream_t stream);
 
+template void
+HashTable<unsigned long, paddle::framework::FeatureValue*>::update<
+    Optimizer<paddle::framework::FeatureValue,
+              paddle::framework::FeaturePushValue>,
+    cudaStream_t>(const unsigned long* d_keys, const char* d_grads, size_t len,
+                  Optimizer<paddle::framework::FeatureValue,
+                            paddle::framework::FeaturePushValue>
+                      sgd,
+                  cudaStream_t stream);
+
 // template void HashTable<unsigned long,
 // paddle::framework::FeatureValue>::update<
 //     Optimizer<paddle::framework::FeatureValue,