Enable subwarp only for unweighted #2051

Closed · wants to merge 2 commits
44 changes: 36 additions & 8 deletions fbgemm_gpu/codegen/embedding_forward_split_kernel_v2_template.cu
```diff
@@ -240,8 +240,9 @@ __noinline__ __device__ void process_all_indices_small_Ls(
       const uint32_t num_offsets = smem[params_offset + SAVED_PARAMS::P_num_offsets];
       const uint32_t total_load_D = smem[params_offset + SAVED_PARAMS::P_total_load_D];
       // Write zeros to the sample that L = 0
+      Vec4StepT<1, emb_t> accumulator;
       for (uint32_t i = 0; i < num_offsets; ++i) {
-        memset(output + i * total_load_D, 0, sizeof(output_vec_t));
+        accumulator.store(output + i * total_load_D);
       }
     }
     return;
```
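
This hunk (and the two similar ones further down) replaces a per-element `memset` with a store from a freshly constructed `Vec4StepT<1, emb_t>` accumulator, relying on the accumulator starting out as zeros so that a single vectorized store writes the zero output. A minimal sketch of the pattern, using a simplified stand-in type rather than FBGEMM's actual `Vec4StepT`/`output_vec_t` definitions:

```cuda
#include <cstdint>

// Simplified stand-in for the accumulator: starts at zero and is written
// back with one vectorized 16-byte store instead of a byte-wise memset.
struct Vec4Zero {
  float4 acc{0.f, 0.f, 0.f, 0.f};
  __device__ void store(float4* dst) const { *dst = acc; }
};

__global__ void zero_rows(float4* output, uint32_t num_rows, uint32_t row_stride_vec) {
  // Mirrors the loop in the diff: write a zero vector into each row's slot
  // owned by this thread (threadIdx.x picks the slot within the row).
  Vec4Zero accumulator;
  for (uint32_t i = 0; i < num_rows; ++i) {
    accumulator.store(output + i * row_stride_vec + threadIdx.x);
  }
}

int main() {
  float4* out;
  cudaMalloc(&out, 4 * 64 * sizeof(float4));
  zero_rows<<<1, 32>>>(out, /*num_rows=*/4, /*row_stride_vec=*/64);
  cudaDeviceSynchronize();
  cudaFree(out);
}
```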
```diff
@@ -299,7 +300,8 @@ __noinline__ __device__ void process_all_indices_small_Ls(
     auto * __restrict__ const output = *reinterpret_cast<output_vec_t**>(&smem[params_offset + SAVED_PARAMS::P_outputs]);
     const auto total_load_D = static_cast<uint32_t>(smem[params_offset + SAVED_PARAMS::P_total_load_D]);
     if (process_d) {
-      memset(output + write_idx + threadIdx.x, 0, sizeof(output_vec_t));
+      Vec4StepT<1, emb_t> accumulator;
+      accumulator.store(output + write_idx + threadIdx.x);
     }
     write_idx += total_load_D;

```
```diff
@@ -746,13 +748,34 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
     return;
   }

-  bool is_small_L;
-  if (threadIdx.x == 0) {
-    // Use the small-L optimization if average L <= 8
-    is_small_L = (offsets[(t + 1) * B] - offsets[t * B]) <= (static_cast<index_t>(B) * 8);
+  const auto total_L = offsets[(t + 1) * B] - offsets[t * B];
+  const auto is_zero_total_L = total_L == 0;
+
+  // Short circuit for all zeros
+  if (is_zero_total_L) {
+    const uint32_t D_start = D_offsets[t] / VEC_WIDTH;
+    const uint32_t load_D = (D_offsets[t + 1] / VEC_WIDTH) - D_start;
+    const uint32_t num_warps_per_row = DIV_ROUND_UP(load_D, kWarpSize);
+    if (table_warp_id >= num_warps_per_row * B) {
+      return;
+    }
+    const uint32_t load_d = (table_warp_id % num_warps_per_row) * kWarpSize;
+    if (load_d + threadIdx.x < load_D) {
+      const uint32_t b = table_warp_id / num_warps_per_row;
+      const uint32_t total_load_D = D_offsets[T] / VEC_WIDTH;
+
+      output_vec_t* output_ptr = reinterpret_cast<output_vec_t*>(output) +
+          D_start + b * total_load_D + load_d + threadIdx.x;
+
+      // Write zeros to output
+      Vec4StepT<1, emb_t> accumulator;
+      accumulator.store(output_ptr);
+    }
+    return;
   }
-  is_small_L = shfl_sync(is_small_L, 0);

+  // Use the small-L optimization if average L <= 8
+  const auto is_small_L = total_L <= (static_cast<index_t>(B) * 8);
   const uint32_t num_warps_for_small_L = DIV_ROUND_UP(B, NUM_OFFSETS_PER_WARP);

   // Early exit for small-L to avoid D_offsets reads
```
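
For orientation: in the new zero-`total_L` short circuit, each warp assigned to the table covers one (sample `b`, 32-vector-wide column chunk) pair and every in-range lane writes one zero output vector before returning early; the small-L decision below it keeps the existing average-L <= 8 heuristic but now evaluates `total_L <= B * 8` on every thread instead of computing it on lane 0 and broadcasting with `shfl_sync`. A host-side sketch of the index arithmetic, with `div_round_up` as a stand-in for the kernel's `DIV_ROUND_UP` and all values in units of output vectors:

```cuda
#include <cstdint>
#include <cstdio>

constexpr uint32_t kWarpSize = 32;

// Stand-in for the kernel's DIV_ROUND_UP macro.
constexpr uint32_t div_round_up(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Maps (table_warp_id, lane) to the output vector index written with zeros on
// the zero-total_L path, or returns -1 if that lane has nothing to write.
// D_start, load_D, and total_load_D mirror the names in the diff.
int64_t zero_path_output_index(uint32_t table_warp_id, uint32_t lane, uint32_t B,
                               uint32_t D_start, uint32_t load_D, uint32_t total_load_D) {
  const uint32_t num_warps_per_row = div_round_up(load_D, kWarpSize);
  if (table_warp_id >= num_warps_per_row * B) {
    return -1;  // this warp has no work for the table
  }
  const uint32_t load_d = (table_warp_id % num_warps_per_row) * kWarpSize;
  if (load_d + lane >= load_D) {
    return -1;  // lane falls past the last vector of the row
  }
  const uint32_t b = table_warp_id / num_warps_per_row;  // which sample in the batch
  return static_cast<int64_t>(D_start) + b * total_load_D + load_d + lane;
}

int main() {
  // Hypothetical shapes: B = 4 samples, the table's row is 40 vectors wide and
  // starts at vector column 16 of a 128-vector-wide output row.
  printf("%lld\n", static_cast<long long>(
      zero_path_output_index(/*table_warp_id=*/3, /*lane=*/5,
                             /*B=*/4, /*D_start=*/16, /*load_D=*/40,
                             /*total_load_D=*/128)));  // prints 181
}
```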
```diff
@@ -879,7 +902,8 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
       if (L == 0) {
         if (load_d + threadIdx.x < load_D) {
           // Write zeros to output
-          memset(output_ptr, 0, sizeof(output_vec_t));
+          Vec4StepT<1, emb_t> accumulator;
+          accumulator.store(output_ptr);
         }
       }
       else {
```
```diff
@@ -916,6 +940,7 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(

     // Tail warp
     // STEP_MASK computation assumes STEP = 4
+    {% if not weighted %}
     if (load_D - load_d < kWarpSize) {
       const auto tail_warp_size = load_D % kWarpSize;
       if (tail_warp_size <= 8) {
@@ -931,6 +956,9 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
     else {
       INVOKE_PROCESS_ALL_INDICES(large_Ls, 32, 0xf)
     }
+    {% else %}
+    INVOKE_PROCESS_ALL_INDICES(large_Ls, 32, 0xf)
+    {% endif %}

 #undef INVOKE_PROCESS_ALL_INDICES_HELPER
 #undef INVOKE_PROCESS_ALL_INDICES
```
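
The Jinja guard in these two hunks is the core of the change: the subwarp tail-warp dispatch is emitted only when the template is rendered for the unweighted kernel, while the weighted kernel always uses the full 32-lane `INVOKE_PROCESS_ALL_INDICES(large_Ls, 32, 0xf)` invocation. A rough, compilable sketch of the two renderings; the macro is a printing stand-in and the collapsed tail-size branches are elided, so this illustrates only the control flow, not the real generated code:

```cuda
#include <cstdint>
#include <cstdio>

// Printing stand-in for the real macro, which dispatches into the
// process_all_indices_* device functions with the given step and step mask.
#define INVOKE_PROCESS_ALL_INDICES(tag, step, mask) \
  printf(#tag ": step=%d mask=0x%x\n", (step), (mask));

// Rendering with {% if not weighted %}: tail warps (fewer than kWarpSize
// vectors left in the row) can pick a narrower subwarp step/mask.
void tail_dispatch_unweighted(uint32_t load_D, uint32_t load_d, uint32_t kWarpSize) {
  if (load_D - load_d < kWarpSize) {
    const auto tail_warp_size = load_D % kWarpSize;
    if (tail_warp_size <= 8) {
      /* narrower subwarp invocations; collapsed in the diff view */
    }
    /* ...remaining tail-size branches, also collapsed in the diff view... */
  }
  else {
    INVOKE_PROCESS_ALL_INDICES(large_Ls, 32, 0xf)  // non-tail warps: full warp
  }
}

// Rendering with {% else %} (weighted kernel): no subwarp dispatch is
// generated at all; every warp uses the full 32-lane invocation.
void tail_dispatch_weighted(uint32_t /*load_D*/, uint32_t /*load_d*/, uint32_t /*kWarpSize*/) {
  INVOKE_PROCESS_ALL_INDICES(large_Ls, 32, 0xf)
}

int main() {
  tail_dispatch_unweighted(/*load_D=*/64, /*load_d=*/32, /*kWarpSize=*/32);
  tail_dispatch_weighted(64, 32, 32);
}
```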