Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade for bscvrq to support non-quant #113

Open
wants to merge 3 commits into
base: paddlebox
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion paddle/fluid/operators/fused/fused_seqpool_cvm_kernel.kps
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,26 @@ struct sum_pooling_concate {
}
};

// add for bscvrq
// embedx_concate_filter:true && quant_ratio_valid=false && need_filter=true && embed_threshold_filter=false
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个跟上面的可以合并成一个,模版参数写成<false, need_filter, false>

template <typename T>
struct sum_pooling_concate<T, false, false, true> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need_filter=true or false?
need_filter和embed_threshold_filter是否反了?

static __device__ inline bool filter(T* local_x,
int in_dim_size,
float show_coeff,
float clk_coeff,
float threshold,
int cvm_offset,
float embed_threshold) {
auto &show = local_x[0];
auto &click = local_x[1];
if ((show - click) * show_coeff + click * clk_coeff < threshold) {
return true;
}
return false;
}
};

// embedx_concate_filter:true && quant_ratio_valid=true && need_filter=true && embed_threshold_filter=false
template <typename T>
struct sum_pooling_concate<T, true, true, false> {
Expand Down Expand Up @@ -844,6 +864,96 @@ struct sum_pooling_concate<T, true, true, true> {
}
};

// 1)FusedSeqpoolKernelEmbedFilterEmbedxConcate
// embedx_concate_filter:true && quant_ratio_valid=false && need_filter=true && embed_threshold_filter=true
// 2)FusedSeqpoolKernelFilterEmbedxConcate
// embedx_concate_filter:true && quant_ratio_valid=false && need_filter=true && embed_threshold_filter=false
// 3)FusedSeqpoolKernelEmbedxConcate
// embedx_concate_filter:true && quant_ratio_valid=false && need_filter=false
template <typename T, bool use_cvm, bool clk_filter, bool need_filter,
bool embed_threshold_filter, typename T2>
struct do_sum_pooling_and_cvm<T, use_cvm, clk_filter, need_filter, false, embed_threshold_filter, true, T2> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

缺少了do_sum_pooling_and_cvm_with_large_dim的实现

static __device__ inline void kernel(T* local_x,
T* local_result, int local_result_len,
float padding_value,
T2* sum_show_clk,
int batch_start, int batch_end,
int in_dim_size, int out_dim_size,
int dim_start_offset,
int seqid,
int max_seq_len,
int quant_ratio,
float quant_ratio_reciprocal,
float32x16_t &v_scale,
float show_coeff,
float clk_coeff,
float threshold,
int cvm_offset,
float embed_threshold,
int embed_thres_size,
int embedx_concate_size,
bool fix_ctr_to_click,
__global_ptr__ T* cur_x,
__global_ptr__ T* cur_y) {
int concate_index = 0;
for (int i = batch_start; i < batch_end; i += max_seq_len) {
// int len = min<int64_t>(batch_end - i, max_seq_len);
int len = min(batch_end - i, max_seq_len);
if (len <= 0)
continue;
mfence();

for (int j = 0; j < len; j++) {
mfence();
GM2LM(cur_x + (i + j) * in_dim_size, local_x, in_dim_size * sizeof(T));

bool is_filter = sum_pooling_concate<T, false, need_filter, embed_threshold_filter>::filter(
local_x, in_dim_size, show_coeff, clk_coeff, threshold, cvm_offset, embed_threshold);
if (is_filter) {
continue;
}

if (concate_index < embedx_concate_size) {
// first: sum pool
// copy

float32x16_t v_src1 = vload_lm_float32x16(local_x);
float32x16_t v_src2 = vload_lm_float32x16(local_x + 16);

vstore_lm_float32x16(local_result, v_src1);
vstore_lm_float32x16(local_result + 16, v_src2);

mfence_lm();
// cvm_offset = [0, 2]
for (int cvm_i = 0; cvm_i < cvm_offset; cvm_i++) {
local_result[cvm_i] = local_x[cvm_i];
}

// second: cvm
int cur_y_index = seqid * embedx_concate_size * out_dim_size + concate_index * out_dim_size;
cvm_engine<T, true, use_cvm, clk_filter, T2>::concat_cvm(local_result,
out_dim_size, dim_start_offset,
cur_y_index,
cur_y);
mfence();
concate_index += 1;
}
}
}

mfence();

// second: cvm
for (int i = concate_index; i < embedx_concate_size; i++) {
memset_value_float(local_result, local_result_len, padding_value);
int cur_y_index = seqid * embedx_concate_size * out_dim_size + i * out_dim_size;
LM2GM_ASYNC(local_result, cur_y + cur_y_index, out_dim_size * sizeof(T));
mfence();
}
}
};


// 1)FusedSeqpoolKernelEmbedQuantFilterEmbedxConcate
// embedx_concate_filter:true && quant_ratio_valid=true && need_filter=true && embed_threshold_filter=true
// 2)FusedSeqpoolKernelQuantFilterEmbedxConcate
Expand Down Expand Up @@ -3374,4 +3484,4 @@ template int sequence_sum_pool_cvm_with_conv_grad<float, int>(xpu::Context* ctx,
uint32_t slot_num,
int embedx_concate_size);
} // end namespace framework
} // end namespace paddle
} // end namespace paddle