Commit

[Refactor] Refine code and remove useless code.
liutongxuan committed Sep 14, 2022
1 parent eb3ed16 commit d183425
Showing 12 changed files with 455 additions and 301 deletions.
49 changes: 31 additions & 18 deletions tensorflow/core/framework/embedding/batch.cu.cc
@@ -7,14 +7,15 @@
 namespace tensorflow {
 template<class V>
 __global__ void BatchCopy(V** batch, V* val_base, int value_len,
-                          int limit, V** default_value, bool* init_flags) {
+    int limit, V** default_value, bool* init_flags) {
   int i = blockDim.x * blockIdx.x + threadIdx.x;
   int item_id = i / value_len;
   int item_pos = i % value_len;
 
   if (i < limit * value_len) {
     if (init_flags[item_id]) {
-      *(batch[item_id] + item_pos) = *(default_value[item_id] + item_pos);
+      *(batch[item_id] + item_pos) =
+          *(default_value[item_id] + item_pos);
     }
     val_base[i] = *(batch[item_id] + item_pos);
   }
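Note: the kernel flattens the copy into one thread per element — thread i serves element item_pos of row item_id. A minimal host-side launch sketch (the launch code is not part of this diff; variable names follow the kernel parameters, and the 256-thread block size is an assumption):

    // One thread per copied element, rounded up to whole blocks;
    // out-of-range threads fail the bounds check and exit.
    const int kThreadsPerBlock = 256;  // assumed tile size
    int total = limit * value_len;
    int blocks = (total + kThreadsPerBlock - 1) / kThreadsPerBlock;
    BatchCopy<float><<<blocks, kThreadsPerBlock>>>(
        batch, val_base, value_len, limit, default_value, init_flags);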
@@ -26,8 +27,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS_ALL_INDEX)
 #undef REGISTER_KERNELS_ALL_INDEX
 
 template<class V>
-__global__ void BatchUnpack(V** dev_value_address, V* memcpy_buffer_gpu,
-                            int value_len, int limit) {
+__global__ void BatchUnpack(V** dev_value_address,
+    V* memcpy_buffer_gpu, int value_len, int limit) {
   int i = blockDim.x * blockIdx.x + threadIdx.x;
   int item_id = i / value_len;
   int item_pos = i % value_len;
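BatchUnpack is the inverse staging step: per its signature, it scatters the contiguous buffer memcpy_buffer_gpu back out through the per-row pointers in dev_value_address. The guarded copy itself sits in lines elided between these hunks; a plausible reconstruction, mirroring BatchCopy above (an assumption, not text from this commit):

    // Assumed body: scatter buffer element i back to row item_id.
    if (i < limit * value_len) {
      *(dev_value_address[item_id] + item_pos) = memcpy_buffer_gpu[i];
    }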
@@ -37,15 +38,19 @@ __global__ void BatchUnpack(V** dev_value_address, V* memcpy_buffer_gpu,
   }
 }
 
-template __global__ void BatchUnpack<int>(int**, int*, int, int);
-template __global__ void BatchUnpack<float>(float**, float*, int, int);
-template __global__ void BatchUnpack<double>(double**, double*, int, int);
-template __global__ void BatchUnpack<long long>(long long**, long long*, int, int);
+template __global__ void BatchUnpack<int>(
+    int**, int*, int, int);
+template __global__ void BatchUnpack<float>(
+    float**, float*, int, int);
+template __global__ void BatchUnpack<double>(
+    double**, double*, int, int);
+template __global__ void BatchUnpack<long long>(
+    long long**, long long*, int, int);
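The explicit instantiations above are what make the declaration/definition split compile: the kernel templates are defined only in this .cu.cc, while batch.h carries bare declarations, so every element type a caller may use has to be instantiated here. A self-contained sketch of the same pattern (names are illustrative, not from this repository):

    // scale.cu.cc — template definition plus explicit instantiation.
    template<class V>
    __global__ void Scale(V* data, V factor, int n) {
      int i = blockDim.x * blockIdx.x + threadIdx.x;
      if (i < n) data[i] *= factor;  // one thread per element
    }
    template __global__ void Scale<float>(float*, float, int);

    // scale.h — declaration only; callers link against the
    // instantiation emitted above.
    template<class V>
    __global__ void Scale(V* data, V factor, int n);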

 template<class V>
 __global__ void SparseApplyAdagradGPU(V** a, V** v, V* g, float lr,
-                                      int embedding_dim, long long int limit,
-                                      bool* init_flags, V* default_value) {
+    int embedding_dim, long long int limit, bool* init_flags,
+    V* default_value) {
   int i = blockDim.x * blockIdx.x + threadIdx.x;
   int item_id = i / embedding_dim;
   int item_pos = i % embedding_dim;
@@ -55,25 +60,33 @@ __global__ void SparseApplyAdagradGPU(V** a, V** v, V* g, float lr,
       *(a[item_id] + item_pos) = default_value[item_pos];
     }
     *(a[item_id] + item_pos) += g[i] * g[i];
-    *(v[item_id] + item_pos) -= lr * g[i] * rsqrt(*(a[item_id] + item_pos));
+    *(v[item_id] + item_pos) -=
+        lr * g[i] * rsqrt(*(a[item_id] + item_pos));
   }
 }
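For reference, this is the standard Adagrad update applied per coordinate, with a the gradient accumulator and v the variable (rsqrt(x) = 1/sqrt(x)):

    a <- a + g^2
    v <- v - lr * g / sqrt(a)

Rows flagged in init_flags first have their accumulator seeded from default_value, so rsqrt never sees an uninitialized accumulator.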

-template __global__ void SparseApplyAdagradGPU<float>(float**, float**, float*, float, int, long long int, bool*, float*);
-template __global__ void SparseApplyAdagradGPU<double>(double**, double**, double*, float, int, long long int, bool*, double*);
+template __global__ void SparseApplyAdagradGPU<float>(
+    float**, float**, float*, float, int, long long int, bool*, float*);
+template __global__ void SparseApplyAdagradGPU<double>(
+    double**, double**, double*, float, int, long long int, bool*, double*);

 template<class V>
-__global__ void CopyEmbedding(V** batch, V* batch_data_space, int total_dims, int limit) {
+__global__ void CopyEmbedding(V** batch, V* batch_data_space,
+    int total_dims, int limit) {
   int i = blockDim.x * blockIdx.x + threadIdx.x;
   if (i < limit * total_dims) {
     batch_data_space[i] = *(batch[i / total_dims] + i % total_dims);
   }
 }
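CopyEmbedding is a plain gather: consecutive threads write consecutive addresses of batch_data_space (fully coalesced), while the reads chase the per-row pointers in batch. CPU-equivalent semantics, for clarity (illustrative only):

    // What the kernel computes, written as serial loops.
    for (int item = 0; item < limit; ++item)
      for (int d = 0; d < total_dims; ++d)
        batch_data_space[item * total_dims + d] = batch[item][d];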

-template __global__ void CopyEmbedding<int>(int**, int*, int, int);
-template __global__ void CopyEmbedding<float>(float**, float*, int, int);
-template __global__ void CopyEmbedding<double>(double**, double*, int, int);
-template __global__ void CopyEmbedding<long long>(long long**, long long*, int, int);
+template __global__ void CopyEmbedding<int>(
+    int**, int*, int, int);
+template __global__ void CopyEmbedding<float>(
+    float**, float*, int, int);
+template __global__ void CopyEmbedding<double>(
+    double**, double*, int, int);
+template __global__ void CopyEmbedding<long long>(
+    long long**, long long*, int, int);

 } // namespace tensorflow
 #endif // TENSORFLOW_USE_GPU_EV
11 changes: 6 additions & 5 deletions tensorflow/core/framework/embedding/batch.h
@@ -8,19 +8,20 @@ namespace tensorflow {

 template<class V>
 __global__ void BatchCopy(V** batch, V* val_base, int value_len,
-                          int limit, V** default_value, bool* init_flags);
+    int limit, V** default_value, bool* init_flags);
 
 template<class V>
 __global__ void BatchUnpack(V** dev_value_address, V* memcpy_buffer_gpu,
-                            int value_len, int limit);
+    int value_len, int limit);
 
 template<class V>
 __global__ void SparseApplyAdagradGPU(V** a, V** v, V* g, float lr,
-                                      int embedding_dim, long long int limit,
-                                      bool* init_flags, V* default_value);
+    int embedding_dim, long long int limit,
+    bool* init_flags, V* default_value);
 
 template<class V>
-__global__ void CopyEmbedding(V** batch, V* batch_data_space, int total_dims, int limit);
+__global__ void CopyEmbedding(V** batch, V* batch_data_space,
+    int total_dims, int limit);
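These declarations let other CUDA translation units launch the kernels without seeing their definitions; the matching definitions come from the explicit instantiations in batch.cu.cc above. A hypothetical call site (assumed, not part of this commit; blocks, threads, and stream are placeholders):

    #include "tensorflow/core/framework/embedding/batch.h"

    // The caller must itself be compiled as CUDA to use <<<...>>>.
    BatchUnpack<float><<<blocks, threads, 0, stream>>>(
        dev_value_address, memcpy_buffer_gpu, value_len, limit);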

 } // namespace tensorflow

28 changes: 19 additions & 9 deletions tensorflow/core/framework/embedding/embedding_config.h
@@ -24,12 +24,18 @@ struct EmbeddingConfig {
   bool record_freq;
   bool record_version;
 
-  EmbeddingConfig(int64 emb_index = 0, int64 primary_emb_index = 0,
-                  int64 block_num = 1, int slot_num = 0,
-                  const std::string& name = "", int64 steps_to_live = 0,
-                  int64 filter_freq = 0, int64 max_freq = 999999,
-                  float l2_weight_threshold = -1.0, const std::string& layout = "normal",
-                  int64 max_element_size = 0, float false_positive_probability = -1.0,
+  EmbeddingConfig(int64 emb_index = 0,
+                  int64 primary_emb_index = 0,
+                  int64 block_num = 1,
+                  int slot_num = 0,
+                  const std::string& name = "",
+                  int64 steps_to_live = 0,
+                  int64 filter_freq = 0,
+                  int64 max_freq = 999999,
+                  float l2_weight_threshold = -1.0,
+                  const std::string& layout = "normal",
+                  int64 max_element_size = 0,
+                  float false_positive_probability = -1.0,
                   DataType counter_type = DT_UINT64,
                   int64 default_value_dim = 4096,
                   float default_value_no_permission = .0,
@@ -52,7 +58,8 @@ struct EmbeddingConfig {
       record_version(record_version) {
     if (max_element_size != 0 && false_positive_probability != -1.0){
       kHashFunc = calc_num_hash_func(false_positive_probability);
-      num_counter = calc_num_counter(max_element_size, false_positive_probability);
+      num_counter = calc_num_counter(max_element_size,
+          false_positive_probability);
     } else {
       kHashFunc = 0;
       num_counter = 0;
@@ -62,7 +69,8 @@ struct EmbeddingConfig {
     }
   }
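Note that the filter-sizing members kHashFunc and num_counter are only computed when the caller supplies both max_element_size and a real false_positive_probability; otherwise the frequency filter is left disabled. A hypothetical construction that enables it (argument values are illustrative only):

    EmbeddingConfig conf(/*emb_index=*/0, /*primary_emb_index=*/0,
                         /*block_num=*/1, /*slot_num=*/0, "user_emb",
                         /*steps_to_live=*/0, /*filter_freq=*/3,
                         /*max_freq=*/999999, /*l2_weight_threshold=*/-1.0,
                         /*layout=*/"normal", /*max_element_size=*/1 << 20,
                         /*false_positive_probability=*/0.01);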

-  int64 calc_num_counter(int64 max_element_size, float false_positive_probability) {
+  int64 calc_num_counter(int64 max_element_size,
+      float false_positive_probability) {
     float loghpp = fabs(log(false_positive_probability));
     float factor = log(2) * log(2);
     int64 num_bucket = ceil(loghpp / factor * max_element_size);
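The line above matches the standard Bloom-filter sizing rule: for n = max_element_size expected keys and target false-positive rate p, the counter count is m = ceil(n * |ln p| / (ln 2)^2), which is exactly ceil(loghpp / factor * max_element_size). The companion rule k = |ln p| / ln 2 for the number of hash functions is presumably what calc_num_hash_func computes; its body falls outside this hunk. For example, n = 2^20 and p = 0.01 give m ≈ 10.05 million counters and k = ceil(6.64) = 7 hash functions.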
@@ -90,7 +98,9 @@ struct EmbeddingConfig {
   }
 
   int64 total_num(int alloc_len) {
-    return block_num * (1 + (1 - normal_fix_flag) * slot_num) * (1 + normal_fix_flag * (alloc_len * (slot_num + 1) - 1));
+    return block_num *
+        (1 + (1 - normal_fix_flag) * slot_num) *
+        (1 + normal_fix_flag * (alloc_len * (slot_num + 1) - 1));
   }
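The rewrapped return expression is easier to check case by case, since normal_fix_flag is 0 or 1:

- normal_fix_flag = 0: block_num * (1 + slot_num) * 1 = block_num * (slot_num + 1).
- normal_fix_flag = 1: block_num * 1 * (alloc_len * (slot_num + 1) - 1 + 1) = block_num * alloc_len * (slot_num + 1).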

   int64 get_filter_freq() {
(The remaining 9 changed files are not shown.)
