diff --git a/paddle/fluid/framework/fleet/box_wrapper_kernel.kps b/paddle/fluid/framework/fleet/box_wrapper_kernel.kps
index 24a8fa84468f4..34f25cb0389ce 100644
--- a/paddle/fluid/framework/fleet/box_wrapper_kernel.kps
+++ b/paddle/fluid/framework/fleet/box_wrapper_kernel.kps
@@ -30,7 +30,7 @@ limitations under the License. */
 #include "xpu/kernel/xtdk_simd.h"
 
 #ifdef TRACE_PROFILE
-// #include "xpu/kernel/xtdk_io.h"
+#include "xpu/kernel/xtdk_io.h"
 #include
 // The producer side.
@@ -70,6 +70,15 @@ struct ExpandPushGetOp {
   }
 };
 
+struct ExpandPushEmdGetOp {
+  __device__ float get(float* expand, const int& row,
+                       const int& expand_id,
+                       const int& hidden,
+                       const int& expand_dim) const {
+    // gradients are laid out as rows of (hidden + expand_dim) floats;
+    // the expand part starts `hidden` floats into the row
+    return expand[row * (hidden + expand_dim) + hidden + expand_id];
+  }
+};
+
 template <typename T>
 __device__ void set_byfloat(float* dest, const T& val) {
   (*reinterpret_cast<T*>(dest)) = val;
@@ -340,6 +349,152 @@ __global__ void PullCopyNNCross(const TEmbedxOp* op,
   }
 }
 
+template <typename TEmbedxOp>
+__global__ void PullCopyNNCrossWithEmb(const TEmbedxOp* op,
+                                       const float scale,
+                                       const boxps::FeaturePullOffset* info,
+                                       int* total_dims,
+                                       unsigned long long* dst_vals,
+                                       const int* key2slot,
+                                       float* total_values,
+                                       const uint32_t* restore_idx,
+                                       const int total_length,
+                                       const int max_cols_num,
+                                       const int hidden_size,
+                                       const int expand_embed_dim,
+                                       const int pull_float_num,
+                                       const int skip_offset,
+                                       const int cvm_offset,
+                                       const int slot_num) {
+  int cid = core_id();
+  int ncores = core_num();
+  if (cid >= ncores) {
+    return;
+  }
+  int thread_id = cluster_id() * ncores + cid;
+  int nthreads = cluster_num() * ncores;
+
+  // each thread copies its share of rows in blocks of buf_length
+  const int buf_length = 5;
+  int per_thread_len = roundup_div(total_length, nthreads);
+  int per_thread_loop_count = roundup_div(per_thread_len, buf_length);
+  int per_thread_per_loop_len = roundup_div(per_thread_len, per_thread_loop_count);
+
+  // local-memory staging buffers for one block of rows
+  __local__ float lm_total_values[buf_length * pull_float_num];
+  __local__ float lm_dst_vals[buf_length * hidden_size];
+  __local__ float lm_dst_expand_vals[buf_length * (hidden_size + expand_embed_dim)];
+  __local__ int lm_key2slot[buf_length];
+  __local__ int lm_total_dims[buf_length];
+  __local__ uint32_t lm_restore_idx[buf_length];
+  __local__ boxps::FeaturePullOffset lm_info[1];
+  __local__ TEmbedxOp lm_op[1];
+
+  // cache the per-slot destination pointers in shared memory
+  const int max_slot_num = 1000;
+  int sm_slot_len = min(max_slot_num, slot_num);
+  __shared__ uint64_t sm_dst_vals_ptr[max_slot_num];
+  __shared__ uint64_t sm_dst_expand_vals_ptr[max_slot_num];
+  for (int i = cid; i < sm_slot_len; i += ncores) {
+    GM2SM(dst_vals + i, sm_dst_vals_ptr + i, sizeof(uint64_t));
+    GM2SM(dst_vals + slot_num + i, sm_dst_expand_vals_ptr + i, sizeof(uint64_t));
+  }
+  mfence();
+  xpu_sync_all();
+
+  // base pointer of the flat output buffer (stored after the two
+  // per-slot pointer tables in dst_vals)
+  __local__ uint64_t lm_dst_vals_ptr[1];
+  for (int i = 0; i < 1; i++) {
+    GM2LM(dst_vals + 2 * slot_num + i, lm_dst_vals_ptr + i, sizeof(uint64_t));
+  }
+  GM2LM(info, lm_info, sizeof(boxps::FeaturePullOffset));
+  GM2LM(op, lm_op, sizeof(TEmbedxOp));
+
+  for (int i = 0; i < per_thread_loop_count; ++i) {
+    int gm_offset = thread_id * per_thread_len + i * per_thread_per_loop_len;
+    if (gm_offset >= total_length) {
+      return;
+    }
+
+    int len = min(per_thread_per_loop_len, total_length - gm_offset);
+    if (restore_idx != nullptr) {
+      GM2LM(restore_idx + gm_offset, lm_restore_idx, len * sizeof(uint32_t));
+    }
+    int pos = (restore_idx != nullptr) ? lm_restore_idx[0] : gm_offset;
+    GM2LM(total_values + pos * pull_float_num, lm_total_values, len * pull_float_num * sizeof(float));
+    GM2LM(total_dims + gm_offset, lm_total_dims, len * sizeof(int));
+    GM2LM(key2slot + gm_offset, lm_key2slot, len * sizeof(int));
+
+    for (int j = 0; j < len; j++) {
+      // mfence();
+      // cvm offset
+      for (int k = 0; k < cvm_offset; ++k) {
+        // TODO: consider xpu_value[slot_id] == nullptr?
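+        // A zero entry in the slot pointer tables means no output tensor is
+        // bound for that slot; the checks below skip both copies in that case.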
+        if (sm_dst_vals_ptr[lm_key2slot[j]] != 0) {
+          lm_dst_vals[j * hidden_size + k] =
+              lm_total_values[j * pull_float_num + lm_info[0].show + skip_offset + k];
+        }
+        if (sm_dst_expand_vals_ptr[lm_key2slot[j]] != 0) {
+          lm_dst_expand_vals[j * (hidden_size + expand_embed_dim) + k] =
+              lm_total_values[j * pull_float_num + lm_info[0].show + skip_offset + k];
+        }
+      }
+
+      // embedx
+      // embedx flags + expand flags && *(keys[x] + y) != 0 && *(keys[x] + y)
+      int embedx_size = *((int *)&(lm_total_values[j * pull_float_num + lm_info[0].embedx_size]));
+      // int embedx_size = 0;
+      // TODO: expand_size = expand_embed_dim?
+      int expand_size = *((int *)&(lm_total_values[j * pull_float_num + lm_info[0].expand_size]));
+      lm_total_dims[j] = static_cast<int>(embedx_size > 0) |
+                         static_cast<int>((expand_size > 0) << 1);
+
+      if (sm_dst_vals_ptr[lm_key2slot[j]] != 0) {
+        for (int k = cvm_offset; k < cvm_offset + embedx_size; ++k) {
+          lm_op[0].copy(lm_dst_vals + j * hidden_size + k,
+                        lm_total_values + j * pull_float_num + lm_info[0].embedx,
+                        k - cvm_offset,
+                        scale);
+        }
+        // zero-fill the tail when the row holds fewer embedx floats
+        for (int k = cvm_offset + embedx_size; k < hidden_size; ++k) {
+          lm_dst_vals[j * hidden_size + k] = 0;
+        }
+      }
+
+      if (sm_dst_expand_vals_ptr[lm_key2slot[j]] != 0) {
+        for (int k = cvm_offset; k < cvm_offset + embedx_size; ++k) {
+          lm_op[0].copy(lm_dst_expand_vals + j * (hidden_size + expand_embed_dim) + k,
+                        lm_total_values + j * pull_float_num + lm_info[0].embedx,
+                        k - cvm_offset,
+                        scale);
+        }
+        for (int k = cvm_offset + embedx_size; k < hidden_size; ++k) {
+          lm_dst_expand_vals[j * (hidden_size + expand_embed_dim) + k] = 0;
+        }
+      }
+
+      // expand
+      if (sm_dst_expand_vals_ptr[lm_key2slot[j]] == 0) {
+        continue;
+      }
+      for (int k = hidden_size; k < hidden_size + expand_size; ++k) {
+        lm_op[0].copy(lm_dst_expand_vals + j * (hidden_size + expand_embed_dim) + k,
+                      lm_total_values + j * pull_float_num + lm_info[0].expand,
+                      k - hidden_size,
+                      scale);
+      }
+      for (int k = hidden_size + expand_size; k < max_cols_num; ++k) {
+        lm_dst_expand_vals[j * (hidden_size + expand_embed_dim) + k] = 0;
+      }
+    }
+    mfence();
+
+    LM2GM(lm_total_dims, total_dims + gm_offset, len * sizeof(int));
+    LM2GM(lm_dst_vals,
+          ((__global_ptr__ float*)lm_dst_vals_ptr[0] + gm_offset * hidden_size),
+          len * hidden_size * sizeof(float));
+    LM2GM(lm_dst_expand_vals,
+          ((__global_ptr__ float*)lm_dst_vals_ptr[0] + total_length * hidden_size +
+           gm_offset * (hidden_size + expand_embed_dim)),
+          len * (hidden_size + expand_embed_dim) * sizeof(float));
+    mfence();
+  }
+}
+
 template <typename TEmbedxOp>
 inline void FeaturePullCopyNNCross(
     const paddle::platform::Place& place,
@@ -405,9 +560,22 @@ inline void FeaturePullCopyNNCross(
         cvm_offset,
         slot_num);
   } else {
-    // PullCopyNNCrossWithEmb
-    // TODO:
-    CHECK(false) << "PullCopyNNCrossWithEmb not implement";
+    PullCopyNNCrossWithEmb<<<8, 64, stream>>>(d_op,
+                                              scale,
+                                              info,
+                                              total_dims,
+                                              reinterpret_cast<unsigned long long*>(d_xpu_values),
+                                              key2slot,
+                                              total_values_xpu,
+                                              xpu_restore_idx,
+                                              total_length,
+                                              (hidden_size + expand_embed_dim),
+                                              hidden_size,
+                                              expand_embed_dim,
+                                              pull_float_num,
+                                              skip_offset,
+                                              cvm_offset,
+                                              slot_num);
   }
   xpu_free(d_xpu_values);
   xpu_wait(stream);
@@ -816,21 +984,18 @@ inline void FeaturePushCopyNNCross(
   auto ctx_xpu = static_cast<platform::XPUDeviceContext*>(dev_ctx)->x_context();
   auto stream = ctx_xpu->xpu_stream;
 
-  auto d_op_tmp = memory::Alloc(place, sizeof(TExpandPushGetOp));
-  TExpandPushGetOp* d_op = reinterpret_cast<TExpandPushGetOp*>(d_op_tmp->ptr());
-  memory::Copy(place,
-               d_op,
-               platform::CPUPlace(),
-               op,
-               sizeof(TExpandPushGetOp));
-
 #ifdef TRACE_PROFILE
   TRACE_SCOPE_START("PushCopyNNCross", xpu_wait(stream));
 #endif
   if (expand_only) {
-    // TODO:
-    // if (d_sort_idx != nullptr){
-    // }
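+    // Stage the get-op in device memory for the kernel. The expand_only path
+    // uses ExpandPushGetOp; the combined path below uses ExpandPushEmdGetOp,
+    // which indexes past the hidden part of a (hidden + expand_dim) row.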
+    ExpandPushGetOp op;
+    auto d_op_tmp = memory::Alloc(place, sizeof(ExpandPushGetOp));
+    ExpandPushGetOp* d_op = reinterpret_cast<ExpandPushGetOp*>(d_op_tmp->ptr());
+    memory::Copy(place,
+                 d_op,
+                 platform::CPUPlace(),
+                 &op,
+                 sizeof(ExpandPushGetOp));
     PushCopyNNCross<<<8, 64, stream>>>(d_op,
         info,
         reinterpret_cast<unsigned long long*>(gm_src),  // src
         total_dims,
         key2slot,
         slot_vector,
         slot_inner_offset,
         push_grad_values,  // dst
         total_length,
         hidden_size,
         expand_embed_dim,
         slot_num,
         push_float_num,
         cvm_offset,
@@ -848,9 +1013,30 @@ inline void FeaturePushCopyNNCross(
         skip_offset,
         bs);
   } else {
-    // PullCopyNNCrossWithEmb
-    // TODO:
-    CHECK(false) << "PullCopyNNCrossWithEmb not implement";
+    ExpandPushEmdGetOp op;
+    auto d_op_tmp = memory::Alloc(place, sizeof(ExpandPushEmdGetOp));
+    ExpandPushEmdGetOp* d_op = reinterpret_cast<ExpandPushEmdGetOp*>(d_op_tmp->ptr());
+    memory::Copy(place,
+                 d_op,
+                 platform::CPUPlace(),
+                 &op,
+                 sizeof(ExpandPushEmdGetOp));
+    PushCopyNNCross<<<8, 64, stream>>>(d_op,
+        info,
+        reinterpret_cast<unsigned long long*>(gm_src),  // src
+        total_dims,
+        key2slot,
+        slot_vector,
+        slot_inner_offset,
+        push_grad_values,  // dst
+        total_length,
+        hidden_size,
+        expand_embed_dim,
+        slot_num,
+        push_float_num,
+        cvm_offset,
+        skip_offset,
+        bs);
   }
 #ifdef TRACE_PROFILE
   xpu_wait(stream);