Add optimized TBE training forward #1641

Closed
wants to merge 1 commit
7 changes: 7 additions & 0 deletions fbgemm_gpu/CMakeLists.txt
@@ -214,6 +214,12 @@ set(gen_gpu_kernel_source_files
"gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu"
"gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu")

if(NOT USE_ROCM)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_forward_split_weighted_v2_kernel.cu"
"gen_embedding_forward_split_unweighted_v2_kernel.cu")
endif()

foreach(wdesc dense split)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_forward_${wdesc}_unweighted_nobag_kernel_small.cu")
@@ -316,6 +322,7 @@ set(codegen_dependencies
${CMAKE_CODEGEN_DIR}/embedding_forward_split_cpu.cpp
${CMAKE_CODEGEN_DIR}/embedding_forward_split_cpu.h
${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_template.cu
${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_v2_template.cu
${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_nobag_small_template.cu
${CMAKE_CODEGEN_DIR}/embedding_forward_split_template.cu
${CMAKE_CODEGEN_DIR}/embedding_forward_template_helpers.cuh
50 changes: 33 additions & 17 deletions fbgemm_gpu/codegen/embedding_backward_code_generator.py
@@ -1408,17 +1408,26 @@ def lars_sgd() -> None:
def generate_forward_embedding_cuda(
template_filepath: str,
filename_format: str,
dense_options: List[bool],
nobag_options: List[bool],
vbe_options: List[bool],
) -> None:
template = env.get_template(template_filepath)
for dense in [True, False]:
for dense in dense_options:
for weighted in [True, False]:
for nobag in [True, False]:
for vbe in [True, False]:
for nobag in nobag_options:
for vbe in vbe_options:
if (not nobag or (not weighted and not vbe)) and (
not dense or not vbe
):
wdesc = f"{ 'dense' if dense else 'split'}_{ 'weighted' if weighted else 'unweighted' }{ '_nobag' if nobag else '' }{ '_vbe' if vbe else '' }"
filename = filename_format.format(wdesc)
dense_desc = f"{ 'dense' if dense else 'split'}"
weight_desc = f"{ 'weighted' if weighted else 'unweighted' }"
nobag_desc = f"{ '_nobag' if nobag else '' }"
vbe_desc = f"{ '_vbe' if vbe else '' }"
desc = (
f"{ dense_desc }_{ weight_desc }{ nobag_desc }{ vbe_desc }"
)
filename = filename_format.format(desc)
write(
filename,
template.render(
@@ -1430,23 +1439,30 @@ def generate_forward_embedding_cuda(

def forward_split() -> None:
# Generate the forward splits
template = env.get_template("embedding_forward_split_template.cu")
for dense in [True, False]:
for weighted in [True, False]:
for vbe in [True, False]:
if not dense or not vbe:
wdesc = f"{ 'dense' if dense else 'split' }_{ 'weighted' if weighted else 'unweighted' }{ '_vbe' if vbe else '' }"
filename = f"gen_embedding_forward_{wdesc}_codegen_cuda.cu"
write(
filename,
template.render(weighted=weighted, dense=dense, vbe=vbe),
)
print(f"[Forward Split]: {filename}")
generate_forward_embedding_cuda(
"embedding_forward_split_template.cu",
"gen_embedding_forward_{}_codegen_cuda.cu",
dense_options=[True, False],
nobag_options=[False], # nobag is not used
vbe_options=[True, False],
)

# Generate the kernels for the forward splits
generate_forward_embedding_cuda(
"embedding_forward_split_kernel_template.cu",
"gen_embedding_forward_{}_kernel.cu",
dense_options=[True, False],
nobag_options=[True, False],
vbe_options=[True, False],
)

# Generate the kernels for the forward splits v2
generate_forward_embedding_cuda(
"embedding_forward_split_kernel_v2_template.cu",
"gen_embedding_forward_{}_v2_kernel.cu",
dense_options=[False], # dense is not supported
nobag_options=[False], # nobag is not supported
vbe_options=[False], # vbe is not supported
)

# Generate the small kernels (for nobag only) for the forward splits
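For reference, a minimal sketch of the variant enumeration that the refactored generate_forward_embedding_cuda performs. This is a hypothetical standalone helper: the real generator renders the Jinja template and calls write() instead of returning names, but the loop and the guard condition are the same as in the diff above.

from typing import List

def enumerate_forward_filenames(
    filename_format: str,
    dense_options: List[bool],
    nobag_options: List[bool],
    vbe_options: List[bool],
) -> List[str]:
    # Mirror the generator's loops: nobag excludes the weighted and vbe
    # combinations, and dense excludes vbe.
    filenames = []
    for dense in dense_options:
        for weighted in [True, False]:
            for nobag in nobag_options:
                for vbe in vbe_options:
                    if (not nobag or (not weighted and not vbe)) and (
                        not dense or not vbe
                    ):
                        desc = (
                            f"{'dense' if dense else 'split'}"
                            f"_{'weighted' if weighted else 'unweighted'}"
                            f"{'_nobag' if nobag else ''}"
                            f"{'_vbe' if vbe else ''}"
                        )
                        filenames.append(filename_format.format(desc))
    return filenames

# The v2 call restricts every option list, so only the two files added to
# CMakeLists.txt are produced:
#   gen_embedding_forward_split_weighted_v2_kernel.cu
#   gen_embedding_forward_split_unweighted_v2_kernel.cu
print(enumerate_forward_filenames(
    "gen_embedding_forward_{}_v2_kernel.cu",
    dense_options=[False],
    nobag_options=[False],
    vbe_options=[False],
))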
23 changes: 12 additions & 11 deletions fbgemm_gpu/codegen/embedding_backward_dense_host.cpp
@@ -27,7 +27,7 @@ Tensor dense_embedding_codegen_forward_unweighted_cuda(
Tensor offsets,
int64_t pooling_mode,
int64_t output_dtype,
int64_t BT_block_size);
bool is_experimental);

Tensor dense_embedding_codegen_forward_weighted_cuda(
Tensor dev_weights,
@@ -40,7 +40,7 @@ Tensor dense_embedding_codegen_forward_weighted_cuda(
int64_t pooling_mode,
Tensor indice_weights,
int64_t output_dtype,
int64_t BT_block_size);
bool is_experimental);

Tensor dense_embedding_codegen_grad_indice_weights_cuda(
Tensor grad_output,
@@ -117,11 +117,6 @@ class SplitLookupFunction_Dense_Op
ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
ctx->saved_data["pooling_mode"] = pooling_mode;

#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
#else
constexpr int32_t BT_block_size = 32;
#endif
if (!indice_weights.has_value()) {
return {dense_embedding_codegen_forward_unweighted_cuda(
dev_weights,
@@ -133,7 +128,7 @@ class SplitLookupFunction_Dense_Op
offsets,
pooling_mode,
output_dtype,
BT_block_size)};
/*is_experimental=*/false)};
} else {
return {dense_embedding_codegen_forward_weighted_cuda(
dev_weights,
@@ -146,7 +141,7 @@ class SplitLookupFunction_Dense_Op
pooling_mode,
indice_weights.value(),
output_dtype,
BT_block_size)};
/*is_experimental=*/false)};
}
}

@@ -276,7 +271,7 @@ Tensor dense_embedding_nobag_codegen_forward_unweighted_cuda(
Tensor indices,
Tensor offsets,
int64_t output_dtype,
int64_t unused);
bool is_experimental);

Tensor split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda(
Tensor grad_output,
@@ -316,7 +311,13 @@ class SplitNoBagLookupFunction_Dense_Op
ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;

return {dense_embedding_nobag_codegen_forward_unweighted_cuda(
dev_weights, weights_offsets, D, indices, offsets, output_dtype, 0)};
dev_weights,
weights_offsets,
D,
indices,
offsets,
output_dtype,
/*is_experimental*/ false)};
}

static torch::autograd::variable_list backward(
50 changes: 21 additions & 29 deletions fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp
@@ -39,14 +39,11 @@ Tensor split_embedding_codegen_forward_unweighted{{ vbe_desc }}_cuda(
Tensor lxu_cache_locations,
int64_t output_dtype,
{% if vbe %}
int64_t BT_block_size,
const VBEMetadata& vbe_metadata,
const int32_t info_B_num_bits,
const uint32_t info_B_mask
{% else %}
int64_t BT_block_size
const uint32_t info_B_mask,
{% endif %}
);
bool is_experimental);

Tensor split_embedding_codegen_forward_weighted{{ vbe_desc }}_cuda(
Tensor dev_weights,
@@ -64,14 +61,11 @@ Tensor split_embedding_codegen_forward_weighted{{ vbe_desc }}_cuda(
Tensor lxu_cache_locations,
int64_t output_dtype,
{% if vbe %}
int64_t BT_block_size,
const VBEMetadata& vbe_metadata,
const int32_t info_B_num_bits,
const uint32_t info_B_mask
{% else %}
int64_t BT_block_size
const uint32_t info_B_mask,
{% endif %}
);
bool is_experimental);

Tensor split_embedding_codegen_grad_indice_weights{{ vbe_desc }}_cuda(
Tensor grad_output,
@@ -158,8 +152,7 @@ Tensor split_embedding_nobag_codegen_forward_unweighted_cuda(
Tensor offsets,
Tensor lxu_cache_locations,
int64_t output_dtype,
int64_t unused
);
bool is_experimental);

void split_embedding_nobag_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
Tensor grad_output,
@@ -226,6 +219,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
const int32_t max_B_feature_rank,
const int64_t vbe_output_size,
{% endif %}
bool is_experimental,
{{ args.split_function_args | join(", ") }}) {

const auto T = weights_offsets.numel();
@@ -307,11 +301,6 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
{% endfor %}

{% if not nobag %}
#ifdef __HIP_PLATFORM_HCC__
constexpr int32_t BT_block_size = 64;
#else
constexpr int32_t BT_block_size = 32;
#endif
if (!indice_weights) {
return {
split_embedding_codegen_forward_unweighted{{ vbe_desc }}_cuda(
@@ -329,13 +318,11 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
lxu_cache_locations,
output_dtype,
{% if vbe %}
BT_block_size,
vbe_metadata,
info_B_num_bits,
info_B_mask
{% else %}
BT_block_size
info_B_mask,
{% endif %}
is_experimental
)
};
} else {
@@ -356,13 +343,11 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
lxu_cache_locations,
output_dtype,
{% if vbe %}
BT_block_size,
vbe_metadata,
info_B_num_bits,
info_B_mask
{% else %}
BT_block_size
info_B_mask,
{% endif %}
is_experimental
)
};
}
Expand All @@ -379,7 +364,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
offsets,
lxu_cache_locations,
output_dtype,
0
/*is_experimental=*/false
)
};
{% endif %}
@@ -521,6 +506,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
Variable(), // max_B_feature_rank
Variable(), // vbe_output_size
{% endif %}
Variable(), // is_experimental
{{ args.split_variables | join(", ") }}
};
} else {
@@ -601,6 +587,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
Variable(), // max_B_feature_rank
Variable(), // vbe_output_size
{% endif %}
Variable(), // is_experimental
{{ args.split_variables | join(", ") }}
};
}
@@ -652,6 +639,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
Variable(), // max_B_feature_rank
Variable(), // vbe_output_size
{% endif %}
Variable(), // is_experimental
{{ args.split_variables | join(", ") }}
};
{% endif %}
@@ -691,7 +679,9 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
const c10::optional<Tensor>& vbe_B_offsets_rank_per_feature = c10::optional<Tensor>(),
const int64_t max_B = -1,
const int64_t max_B_feature_rank = -1,
const int64_t vbe_output_size = -1) {
const int64_t vbe_output_size = -1,
const bool is_experimental = false
) {
{% if has_gpu_support %}
{% for vbe in ([True, False] if has_vbe_support else [False]) %}
{% set vbe_class_desc = "VBE" if vbe else "" %}
@@ -721,6 +711,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
gradient_clipping,
max_gradient,
stochastic_rounding,
is_experimental,
{{ args.split_function_arg_names | join(", ") }})[0];
} else {
return Split{{ vbe_class_desc }}LookupFunction_{{ optimizer }}_Op::apply(
@@ -753,6 +744,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
max_B_feature_rank,
vbe_output_size,
{% endif %}
is_experimental,
{{ args.split_function_arg_names | join(", ") }})[0];
}
{% if has_vbe_support %}
@@ -767,12 +759,12 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(

// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1) -> Tensor");
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1, bool is_experimental=False) -> Tensor");
DISPATCH_TO_CUDA("split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function);
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1) -> Tensor");
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1, bool is_experimental=False) -> Tensor");
DISPATCH_TO_CUDA("split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function);
}
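As a quick way to confirm the schema change from user code, here is a minimal sketch. It assumes fbgemm_gpu is installed and that the sgd instantiation of the templated lookup op is registered; neither assumption is shown in this diff.

import torch
import fbgemm_gpu  # noqa: F401  # importing registers the fbgemm ops

op = torch.ops.fbgemm.split_embedding_codegen_lookup_sgd_function
schema = op.default._schema
# is_experimental is appended with a default of False, so existing call sites
# that omit it keep dispatching to the original (v1) forward kernels.
arg = next(a for a in schema.arguments if a.name == "is_experimental")
print(arg.name, arg.default_value)  # expected: is_experimental False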

@@ -337,7 +337,7 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }

/*
Explicitly instantiate the kernel function template. The instantiations are
based on the types enumerated by DISPATCH_EMB_GRAD_CACHE_TYPES macro used in
based on the types enumerated by DISPATCH_EMB_CACHE_TYPES macro used in
embedding_forward_split_template.cu
*/
