
Commit 75318b0

sryap authored and facebook-github-bot committed
Add optimized TBE training forward (pytorch#1641)
Summary: Pull Request resolved: pytorch#1641

This diff adds an optimized implementation of the TBE training forward, namely `split_embedding_codegen_forward_[weighted|unweighted]_v2_kernel`. The implementation currently supports only a subset of TBE use cases, including:

- Split TBE (`SplitTableBatchedEmbeddingBagsCodegen`)
- Pooled TBE (`pooling_mode`: `PoolingMode.SUM`, `PoolingMode.MEAN`)
- Weighted and unweighted TBE (`per_sample_weights`: `Tensor`, `None`)
- FP32 and FP16 weight types (`weights_precision`: `SparseType.FP32`, `SparseType.FP16`)
- FP32 and FP16 output types (`output_dtype`: `SparseType.FP32`, `SparseType.FP16`)
- Device, managed, and managed-caching embedding locations (`EmbeddingLocation`: `EmbeddingLocation.DEVICE`, `EmbeddingLocation.MANAGED`, `EmbeddingLocation.MANAGED_CACHING`)

Cases that the new implementation does **NOT** support:

- Dense TBE (`DenseTableBatchedEmbeddingBagsCodegen`)
- Sequence TBE (`pooling_mode`: `PoolingMode.NONE`)
- FP8, INT8, INT4, INT2, and BF16 weight types (`weights_precision`: `SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`, `SparseType.INT2`, `SparseType.BF16`)
- FP8, INT8, INT4, INT2, and BF16 output types (`output_dtype`: `SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`, `SparseType.INT2`, `SparseType.BF16`)
- Host embedding locations (`EmbeddingLocation`: `EmbeddingLocation.HOST`)

Note that this optimization is enabled for NVIDIA GPUs, but **not** for AMD GPUs.

**Usage**

The frontend changes are in D44479772.

The `FBGEMM_EXPERIMENTAL_TBE` environment variable is added for enabling/disabling the new implementation at runtime. If `FBGEMM_EXPERIMENTAL_TBE` is not set, TBE uses the original implementation. If `FBGEMM_EXPERIMENTAL_TBE=1`, TBE uses the new implementation. If a TBE use case is not supported by the new implementation, TBE falls back to the original implementation. By default, `FBGEMM_EXPERIMENTAL_TBE` is not set.

The new implementation can also be enabled by passing `use_experimental_tbe=True` when instantiating the TBE operator:

```
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=...,
    ...,
    use_experimental_tbe=True,
)
```

**Optimization**

The new implementation contains the following optimizations:

- Use multiple warps per bag for D > 128 to maintain a constant number of registers per thread
- Use subwarps to process subsets of input rows in a bag if D < 128
- Cooperatively compute weight pointers and store them in shared memory
- Save state variables in shared memory instead of registers to free registers for compiler optimizations
- Use the upper-bound number of warps for all tables to avoid complex warp offset computation
- Process multiple samples (up to kWarpSize samples) in a warp for small Ls

Note: D = embedding dimension, L = pooling factor.

Reviewed By: jianyuh

Differential Revision: D43634651

fbshipit-source-id: 96ad56f0e5567959fd28c72a649f862e1f5dd307
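To make the **Usage** behavior concrete, here is a minimal standalone sketch of the selection logic described above. The helper names (`resolve_is_experimental`, `supported_by_v2`) are illustrative only and are not the actual frontend code, which lands in D44479772.

```python
import os

def resolve_is_experimental(use_experimental_tbe: bool = False) -> bool:
    # Sketch: the experimental forward is requested when either the
    # FBGEMM_EXPERIMENTAL_TBE=1 environment variable is set or the operator
    # was constructed with use_experimental_tbe=True.
    return os.environ.get("FBGEMM_EXPERIMENTAL_TBE") == "1" or use_experimental_tbe

def supported_by_v2(weights_precision: str, output_dtype: str,
                    pooling_mode: str, location: str) -> bool:
    # Sketch of the support matrix listed above; any unsupported combination
    # falls back to the original implementation.
    return (
        pooling_mode in ("SUM", "MEAN")
        and weights_precision in ("FP32", "FP16")
        and output_dtype in ("FP32", "FP16")
        and location in ("DEVICE", "MANAGED", "MANAGED_CACHING")
    )
```

In the diffs below, the resolved flag is simply threaded through to the CUDA host functions as a trailing `is_experimental` argument.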
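As a worked example for the first two optimization bullets, assume one warp covers up to 128 embedding elements. That per-warp width is an illustrative figure consistent with the stated D > 128 and D < 128 thresholds, not a number taken from the diff:

```python
import math

def warps_per_bag(D: int, elems_per_warp: int = 128) -> int:
    # Illustrative model only: rows wider than elems_per_warp are split
    # across multiple warps so the per-thread register count stays constant.
    return max(1, math.ceil(D / elems_per_warp))

for D in (64, 128, 256, 512):
    print(f"D={D:3d} -> {warps_per_bag(D)} warp(s)")
# D <= 128 uses a single warp (with subwarps handling rows when D < 128);
# D = 256 would use 2 warps; D = 512 would use 4.
```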
1 parent e9d7e3e commit 75318b0

8 files changed: +1541, -71 lines

fbgemm_gpu/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
```diff
@@ -214,6 +214,12 @@ set(gen_gpu_kernel_source_files
     "gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu"
     "gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu")
 
+if(NOT USE_ROCM)
+  list(APPEND gen_gpu_kernel_source_files
+      "gen_embedding_forward_split_weighted_v2_kernel.cu"
+      "gen_embedding_forward_split_unweighted_v2_kernel.cu")
+endif()
+
 foreach(wdesc dense split)
   list(APPEND gen_gpu_kernel_source_files
       "gen_embedding_forward_${wdesc}_unweighted_nobag_kernel_small.cu")
@@ -316,6 +322,7 @@ set(codegen_dependencies
   ${CMAKE_CODEGEN_DIR}/embedding_forward_split_cpu.cpp
   ${CMAKE_CODEGEN_DIR}/embedding_forward_split_cpu.h
   ${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_template.cu
+  ${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_v2_template.cu
   ${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_nobag_small_template.cu
   ${CMAKE_CODEGEN_DIR}/embedding_forward_split_template.cu
   ${CMAKE_CODEGEN_DIR}/embedding_forward_template_helpers.cuh
```

fbgemm_gpu/codegen/embedding_backward_code_generator.py

Lines changed: 33 additions & 17 deletions
```diff
@@ -1408,17 +1408,26 @@ def lars_sgd() -> None:
 def generate_forward_embedding_cuda(
     template_filepath: str,
     filename_format: str,
+    dense_options: List[bool],
+    nobag_options: List[bool],
+    vbe_options: List[bool],
 ) -> None:
     template = env.get_template(template_filepath)
-    for dense in [True, False]:
+    for dense in dense_options:
         for weighted in [True, False]:
-            for nobag in [True, False]:
-                for vbe in [True, False]:
+            for nobag in nobag_options:
+                for vbe in vbe_options:
                     if (not nobag or (not weighted and not vbe)) and (
                         not dense or not vbe
                     ):
-                        wdesc = f"{ 'dense' if dense else 'split'}_{ 'weighted' if weighted else 'unweighted' }{ '_nobag' if nobag else '' }{ '_vbe' if vbe else '' }"
-                        filename = filename_format.format(wdesc)
+                        dense_desc = f"{ 'dense' if dense else 'split'}"
+                        weight_desc = f"{ 'weighted' if weighted else 'unweighted' }"
+                        nobag_desc = f"{ '_nobag' if nobag else '' }"
+                        vbe_desc = f"{ '_vbe' if vbe else '' }"
+                        desc = (
+                            f"{ dense_desc }_{ weight_desc }{ nobag_desc }{ vbe_desc }"
+                        )
+                        filename = filename_format.format(desc)
                         write(
                             filename,
                             template.render(
@@ -1430,23 +1439,30 @@ def generate_forward_embedding_cuda(
 
 def forward_split() -> None:
     # Generate the forward splits
-    template = env.get_template("embedding_forward_split_template.cu")
-    for dense in [True, False]:
-        for weighted in [True, False]:
-            for vbe in [True, False]:
-                if not dense or not vbe:
-                    wdesc = f"{ 'dense' if dense else 'split' }_{ 'weighted' if weighted else 'unweighted' }{ '_vbe' if vbe else '' }"
-                    filename = f"gen_embedding_forward_{wdesc}_codegen_cuda.cu"
-                    write(
-                        filename,
-                        template.render(weighted=weighted, dense=dense, vbe=vbe),
-                    )
-                    print(f"[Forward Split]: {filename}")
+    generate_forward_embedding_cuda(
+        "embedding_forward_split_template.cu",
+        "gen_embedding_forward_{}_codegen_cuda.cu",
+        dense_options=[True, False],
+        nobag_options=[False],  # nobag is not used
+        vbe_options=[True, False],
+    )
 
     # Generate the kernels for the forward splits
     generate_forward_embedding_cuda(
         "embedding_forward_split_kernel_template.cu",
         "gen_embedding_forward_{}_kernel.cu",
+        dense_options=[True, False],
+        nobag_options=[True, False],
+        vbe_options=[True, False],
+    )
+
+    # Generate the kernels for the forward splits v2
+    generate_forward_embedding_cuda(
+        "embedding_forward_split_kernel_v2_template.cu",
+        "gen_embedding_forward_{}_v2_kernel.cu",
+        dense_options=[False],  # dense is not supported
+        nobag_options=[False],  # nobag is not supported
+        vbe_options=[False],  # vbe is not supported
     )
 
     # Generate the small kernels (for nobag only) for the forward splits
```
fbgemm_gpu/codegen/embedding_backward_dense_host.cpp

Lines changed: 12 additions & 11 deletions
```diff
@@ -27,7 +27,7 @@ Tensor dense_embedding_codegen_forward_unweighted_cuda(
     Tensor offsets,
     int64_t pooling_mode,
     int64_t output_dtype,
-    int64_t BT_block_size);
+    bool is_experimental);
 
 Tensor dense_embedding_codegen_forward_weighted_cuda(
     Tensor dev_weights,
@@ -40,7 +40,7 @@ Tensor dense_embedding_codegen_forward_weighted_cuda(
     int64_t pooling_mode,
     Tensor indice_weights,
     int64_t output_dtype,
-    int64_t BT_block_size);
+    bool is_experimental);
 
 Tensor dense_embedding_codegen_grad_indice_weights_cuda(
     Tensor grad_output,
@@ -117,11 +117,6 @@ class SplitLookupFunction_Dense_Op
     ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
     ctx->saved_data["pooling_mode"] = pooling_mode;
 
-#ifdef __HIP_PLATFORM_HCC__
-    constexpr int32_t BT_block_size = 64;
-#else
-    constexpr int32_t BT_block_size = 32;
-#endif
     if (!indice_weights.has_value()) {
       return {dense_embedding_codegen_forward_unweighted_cuda(
           dev_weights,
@@ -133,7 +128,7 @@ class SplitLookupFunction_Dense_Op
           offsets,
           pooling_mode,
           output_dtype,
-          BT_block_size)};
+          /*is_experimental=*/false)};
     } else {
       return {dense_embedding_codegen_forward_weighted_cuda(
           dev_weights,
@@ -146,7 +141,7 @@ class SplitLookupFunction_Dense_Op
           pooling_mode,
           indice_weights.value(),
           output_dtype,
-          BT_block_size)};
+          /*is_experimental=*/false)};
     }
   }
 
@@ -276,7 +271,7 @@ Tensor dense_embedding_nobag_codegen_forward_unweighted_cuda(
     Tensor indices,
     Tensor offsets,
     int64_t output_dtype,
-    int64_t unused);
+    bool is_experimental);
 
 Tensor split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda(
     Tensor grad_output,
@@ -316,7 +311,13 @@ class SplitNoBagLookupFunction_Dense_Op
     ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
 
     return {dense_embedding_nobag_codegen_forward_unweighted_cuda(
-        dev_weights, weights_offsets, D, indices, offsets, output_dtype, 0)};
+        dev_weights,
+        weights_offsets,
+        D,
+        indices,
+        offsets,
+        output_dtype,
+        /*is_experimental*/ false)};
   }
 
   static torch::autograd::variable_list backward(
```

fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp

Lines changed: 21 additions & 29 deletions
```diff
@@ -39,14 +39,11 @@ Tensor split_embedding_codegen_forward_unweighted{{ vbe_desc }}_cuda(
     Tensor lxu_cache_locations,
     int64_t output_dtype,
     {% if vbe %}
-    int64_t BT_block_size,
     const VBEMetadata& vbe_metadata,
     const int32_t info_B_num_bits,
-    const uint32_t info_B_mask
-    {% else %}
-    int64_t BT_block_size
+    const uint32_t info_B_mask,
     {% endif %}
-    );
+    bool is_experimental);
 
 Tensor split_embedding_codegen_forward_weighted{{ vbe_desc }}_cuda(
     Tensor dev_weights,
@@ -64,14 +61,11 @@ Tensor split_embedding_codegen_forward_weighted{{ vbe_desc }}_cuda(
     Tensor lxu_cache_locations,
     int64_t output_dtype,
     {% if vbe %}
-    int64_t BT_block_size,
     const VBEMetadata& vbe_metadata,
     const int32_t info_B_num_bits,
-    const uint32_t info_B_mask
-    {% else %}
-    int64_t BT_block_size
+    const uint32_t info_B_mask,
     {% endif %}
-    );
+    bool is_experimental);
 
 Tensor split_embedding_codegen_grad_indice_weights{{ vbe_desc }}_cuda(
     Tensor grad_output,
@@ -158,8 +152,7 @@ Tensor split_embedding_nobag_codegen_forward_unweighted_cuda(
     Tensor offsets,
     Tensor lxu_cache_locations,
     int64_t output_dtype,
-    int64_t unused
-    );
+    bool is_experimental);
 
 void split_embedding_nobag_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
     Tensor grad_output,
@@ -226,6 +219,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
       const int32_t max_B_feature_rank,
       const int64_t vbe_output_size,
       {% endif %}
+      bool is_experimental,
       {{ args.split_function_args | join(", ") }}) {
 
     const auto T = weights_offsets.numel();
@@ -307,11 +301,6 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
     {% endfor %}
 
     {% if not nobag %}
-#ifdef __HIP_PLATFORM_HCC__
-    constexpr int32_t BT_block_size = 64;
-#else
-    constexpr int32_t BT_block_size = 32;
-#endif
     if (!indice_weights) {
       return {
           split_embedding_codegen_forward_unweighted{{ vbe_desc }}_cuda(
@@ -329,13 +318,11 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
               lxu_cache_locations,
               output_dtype,
               {% if vbe %}
-              BT_block_size,
               vbe_metadata,
               info_B_num_bits,
-              info_B_mask
-              {% else %}
-              BT_block_size
+              info_B_mask,
               {% endif %}
+              is_experimental
           )
       };
     } else {
@@ -356,13 +343,11 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
               lxu_cache_locations,
               output_dtype,
               {% if vbe %}
-              BT_block_size,
               vbe_metadata,
               info_B_num_bits,
-              info_B_mask
-              {% else %}
-              BT_block_size
+              info_B_mask,
               {% endif %}
+              is_experimental
           )
       };
     }
@@ -379,7 +364,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
             offsets,
             lxu_cache_locations,
             output_dtype,
-            0
+            /*is_experimental=*/false
         )
     };
     {% endif %}
@@ -521,6 +506,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
         Variable(), // max_B_feature_rank
         Variable(), // vbe_output_size
         {% endif %}
+        Variable(), // is_experimental
         {{ args.split_variables | join(", ") }}
     };
   } else {
@@ -601,6 +587,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
        Variable(), // max_B_feature_rank
        Variable(), // vbe_output_size
        {% endif %}
+       Variable(), // is_experimental
       {{ args.split_variables | join(", ") }}
     };
   }
@@ -652,6 +639,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
        Variable(), // max_B_feature_rank
        Variable(), // vbe_output_size
        {% endif %}
+       Variable(), // is_experimental
       {{ args.split_variables | join(", ") }}
     };
     {% endif %}
@@ -691,7 +679,9 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
     const c10::optional<Tensor>& vbe_B_offsets_rank_per_feature = c10::optional<Tensor>(),
     const int64_t max_B = -1,
     const int64_t max_B_feature_rank = -1,
-    const int64_t vbe_output_size = -1) {
+    const int64_t vbe_output_size = -1,
+    const bool is_experimental = false
+) {
   {% if has_gpu_support %}
   {% for vbe in ([True, False] if has_vbe_support else [False]) %}
   {% set vbe_class_desc = "VBE" if vbe else "" %}
@@ -721,6 +711,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
         gradient_clipping,
         max_gradient,
         stochastic_rounding,
+        is_experimental,
         {{ args.split_function_arg_names | join(", ") }})[0];
   } else {
     return Split{{ vbe_class_desc }}LookupFunction_{{ optimizer }}_Op::apply(
@@ -753,6 +744,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
         max_B_feature_rank,
         vbe_output_size,
         {% endif %}
+        is_experimental,
         {{ args.split_function_arg_names | join(", ") }})[0];
   }
   {% if has_vbe_support %}
@@ -767,12 +759,12 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
 
 // Deprecated for fb namespace! Please use fbgemm namespace instead!
 TORCH_LIBRARY_FRAGMENT(fb, m) {
-    m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1) -> Tensor");
+    m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1, bool is_experimental=False) -> Tensor");
     DISPATCH_TO_CUDA("split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function);
 }
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-    m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1) -> Tensor");
+    m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1, bool is_experimental=False) -> Tensor");
     DISPATCH_TO_CUDA("split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function);
 }
 
```
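The template change above removes the `{% else %}` branch that carried `BT_block_size` and makes `is_experimental` the unconditional trailing parameter. A small self-contained Jinja2 sketch (assuming only that the `jinja2` package is installed; this is not FBGEMM's codegen driver) shows the rendered signatures for both `vbe` settings:

```python
from jinja2 import Template

# Minimal reproduction of the signature pattern in
# embedding_backward_split_host_template.cpp after this change: the vbe-only
# arguments are conditional, and is_experimental is always the last parameter.
sig = Template(
    "Tensor split_embedding_codegen_forward_unweighted{{ vbe_desc }}_cuda(\n"
    "    Tensor dev_weights,\n"
    "    int64_t output_dtype,\n"
    "{% if vbe %}"
    "    const VBEMetadata& vbe_metadata,\n"
    "    const int32_t info_B_num_bits,\n"
    "    const uint32_t info_B_mask,\n"
    "{% endif %}"
    "    bool is_experimental);\n"
)

for vbe in (False, True):
    # vbe=False renders the plain signature; vbe=True adds the VBE arguments.
    print(sig.render(vbe=vbe, vbe_desc="_vbe" if vbe else ""))
```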
fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu

Lines changed: 1 addition & 1 deletion
```diff
@@ -278,7 +278,7 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }
 
 /*
     Explicitly instantiate the kernel function template. The instantiations are
-    based on the types enumerated by DISPATCH_EMB_GRAD_CACHE_TYPES macro used in
+    based on the types enumerated by DISPATCH_EMB_CACHE_TYPES macro used in
     embedding_forward_split_template.cu
 */
 
```