Skip to content

Commit b163a5e

Browse files
sryap authored and facebook-github-bot committed
Add optimized TBE training forward (#1641)
Summary: Pull Request resolved: #1641 This diff adds an optimized implementation of TBE training forward, namely `split_embedding_codegen_forward_[weighted|unweighted]_v2_kernel`. The implementation currently supports only a subset of TBE use cases, including: - Split TBE (`SplitTableBatchedEmbeddingBagsCodegen`) - Pooled TBE (`pooling_mode`: `PoolingMode.SUM`, `PoolingMode.MEAN`) - Weighted and unweighted TBE (`per_sample_weights`: `Tensor`, `None`) - FP32 and FP16 weight types (`weights_precision`: `SparseType.FP32`, `SparseType.FP16`) - FP32 and FP16 output types (`output_dtype`: `SparseType.FP32`, `SparseType.FP16`) - Device, managed, and managed caching embedding locations (`EmbeddingLocation`: `EmbeddingLocation.DEVICE`, `EmbeddingLocation.MANAGED`, `EmbeddingLocation.MANAGED_CACHING`) Cases that the new implementation does **NOT** support: - Dense TBE (`DenseTableBatchedEmbeddingBagsCodegen`) - Sequence TBE (`pooling_mode`: `PoolingMode.NONE`) - FP8, INT8, INT4, INT2, and BF16 weight types (`weights_precision`: `SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`, `SparseType.INT2`, `SparseType.BF16`) - FP8, INT8, INT4, INT2, and BF16 output types (`output_dtype`: `SparseType.FP8`, `SparseType.INT8`, `SparseType.INT4`, `SparseType.INT2`, `SparseType.BF16`) - Host embedding locations (`EmbeddingLocation`: `EmbeddingLocation.HOST`) The `FBGEMM_EXPERIMENTAL_TBE` environment variable flag is added for enabling/disabling the new implementation at runtime. If `FBGEMM_EXPERIMENTAL_TBE` is not set, TBE will use the original implementation. If `FBGEMM_EXPERIMENTAL_TBE=1`, TBE will use the new implementation. If a TBE use case is not supported by the new implementation, TBE will fall back to the original implementation. By default, `FBGEMM_EXPERIMENTAL_TBE` is not set. 
The new implementation contains the following optimizations: - Use multiple warps per bag for D > 128 to maintain a constant number of registers per thread - Use subwarps to process subsets of input rows in a bag if D < 128 - Cooperatively compute weight pointers and store them in shared memory - Save state variables in shared memory instead of registers to free registers for compiler optimizations - Use the upper bound number of warps for all tables to avoid complex warp offset computation - Process multiple samples (up to kWarpSize samples) in a warp for small Ls Note: D = embedding dimension, L = pooling factor Differential Revision: D43634651 fbshipit-source-id: 42b8c5b853dd30df9bb3b2f808668d1ebf0db9a7
1 parent 32a9e37 commit b163a5e

8 files changed

+1503
-72
lines changed

fbgemm_gpu/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,9 @@ set(gen_gpu_kernel_source_files
206206
"gen_embedding_backward_split_indice_weights_codegen_cuda.cu"
207207
"gen_embedding_backward_split_grad.cu"
208208
"gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu"
209-
"gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu")
209+
"gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu"
210+
"gen_embedding_forward_split_weighted_v2_kernel.cu"
211+
"gen_embedding_forward_split_unweighted_v2_kernel.cu")
210212

211213
foreach(wdesc dense split)
212214
list(APPEND gen_gpu_kernel_source_files
@@ -310,6 +312,7 @@ set(codegen_dependencies
310312
${CMAKE_CODEGEN_DIR}/embedding_forward_split_cpu.cpp
311313
${CMAKE_CODEGEN_DIR}/embedding_forward_split_cpu.h
312314
${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_template.cu
315+
${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_v2_template.cu
313316
${CMAKE_CODEGEN_DIR}/embedding_forward_split_kernel_nobag_small_template.cu
314317
${CMAKE_CODEGEN_DIR}/embedding_forward_split_template.cu
315318
${CMAKE_CODEGEN_DIR}/embedding_forward_template_helpers.cuh

fbgemm_gpu/codegen/embedding_backward_code_generator.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,17 +1353,26 @@ def lars_sgd() -> None:
13531353
def generate_forward_embedding_cuda(
13541354
template_filepath: str,
13551355
filename_format: str,
1356+
dense_options: List[bool],
1357+
nobag_options: List[bool],
1358+
vbe_options: List[bool],
13561359
) -> None:
13571360
template = env.get_template(template_filepath)
1358-
for dense in [True, False]:
1361+
for dense in dense_options:
13591362
for weighted in [True, False]:
1360-
for nobag in [True, False]:
1361-
for vbe in [True, False]:
1363+
for nobag in nobag_options:
1364+
for vbe in vbe_options:
13621365
if (not nobag or (not weighted and not vbe)) and (
13631366
not dense or not vbe
13641367
):
1365-
wdesc = f"{ 'dense' if dense else 'split'}_{ 'weighted' if weighted else 'unweighted' }{ '_nobag' if nobag else '' }{ '_vbe' if vbe else '' }"
1366-
filename = filename_format.format(wdesc)
1368+
dense_desc = f"{ 'dense' if dense else 'split'}"
1369+
weight_desc = f"{ 'weighted' if weighted else 'unweighted' }"
1370+
nobag_desc = f"{ '_nobag' if nobag else '' }"
1371+
vbe_desc = f"{ '_vbe' if vbe else '' }"
1372+
desc = (
1373+
f"{ dense_desc }_{ weight_desc }{ nobag_desc }{ vbe_desc }"
1374+
)
1375+
filename = filename_format.format(desc)
13671376
write(
13681377
filename,
13691378
template.render(
@@ -1375,23 +1384,30 @@ def generate_forward_embedding_cuda(
13751384

13761385
def forward_split() -> None:
13771386
# Generate the forward splits
1378-
template = env.get_template("embedding_forward_split_template.cu")
1379-
for dense in [True, False]:
1380-
for weighted in [True, False]:
1381-
for vbe in [True, False]:
1382-
if not dense or not vbe:
1383-
wdesc = f"{ 'dense' if dense else 'split' }_{ 'weighted' if weighted else 'unweighted' }{ '_vbe' if vbe else '' }"
1384-
filename = f"gen_embedding_forward_{wdesc}_codegen_cuda.cu"
1385-
write(
1386-
filename,
1387-
template.render(weighted=weighted, dense=dense, vbe=vbe),
1388-
)
1389-
print(f"[Forward Split]: {filename}")
1387+
generate_forward_embedding_cuda(
1388+
"embedding_forward_split_template.cu",
1389+
"gen_embedding_forward_{}_codegen_cuda.cu",
1390+
dense_options=[True, False],
1391+
nobag_options=[False], # nobag is not used
1392+
vbe_options=[True, False],
1393+
)
13901394

13911395
# Generate the kernels for the forward splits
13921396
generate_forward_embedding_cuda(
13931397
"embedding_forward_split_kernel_template.cu",
13941398
"gen_embedding_forward_{}_kernel.cu",
1399+
dense_options=[True, False],
1400+
nobag_options=[True, False],
1401+
vbe_options=[True, False],
1402+
)
1403+
1404+
# Generate the kernels for the forward splits v2
1405+
generate_forward_embedding_cuda(
1406+
"embedding_forward_split_kernel_v2_template.cu",
1407+
"gen_embedding_forward_{}_v2_kernel.cu",
1408+
dense_options=[False], # dense is not supported
1409+
nobag_options=[False], # nobag is not supported
1410+
vbe_options=[False], # vbe is not supported
13951411
)
13961412

13971413
# Generate the small kernels (for nobag only) for the forward splits

fbgemm_gpu/codegen/embedding_backward_dense_host.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Tensor dense_embedding_codegen_forward_unweighted_cuda(
2727
Tensor offsets,
2828
int64_t pooling_mode,
2929
int64_t output_dtype,
30-
int64_t BT_block_size);
30+
bool is_experimental);
3131

3232
Tensor dense_embedding_codegen_forward_weighted_cuda(
3333
Tensor dev_weights,
@@ -40,7 +40,7 @@ Tensor dense_embedding_codegen_forward_weighted_cuda(
4040
int64_t pooling_mode,
4141
Tensor indice_weights,
4242
int64_t output_dtype,
43-
int64_t BT_block_size);
43+
bool is_experimental);
4444

4545
Tensor dense_embedding_codegen_grad_indice_weights_cuda(
4646
Tensor grad_output,
@@ -117,11 +117,6 @@ class SplitLookupFunction_Dense_Op
117117
ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
118118
ctx->saved_data["pooling_mode"] = pooling_mode;
119119

120-
#ifdef __HIP_PLATFORM_HCC__
121-
constexpr int32_t BT_block_size = 64;
122-
#else
123-
constexpr int32_t BT_block_size = 32;
124-
#endif
125120
if (!indice_weights.has_value()) {
126121
return {dense_embedding_codegen_forward_unweighted_cuda(
127122
dev_weights,
@@ -133,7 +128,7 @@ class SplitLookupFunction_Dense_Op
133128
offsets,
134129
pooling_mode,
135130
output_dtype,
136-
BT_block_size)};
131+
/*is_experimental=*/false)};
137132
} else {
138133
return {dense_embedding_codegen_forward_weighted_cuda(
139134
dev_weights,
@@ -146,7 +141,7 @@ class SplitLookupFunction_Dense_Op
146141
pooling_mode,
147142
indice_weights.value(),
148143
output_dtype,
149-
BT_block_size)};
144+
/*is_experimental=*/false)};
150145
}
151146
}
152147

@@ -276,7 +271,7 @@ Tensor dense_embedding_nobag_codegen_forward_unweighted_cuda(
276271
Tensor indices,
277272
Tensor offsets,
278273
int64_t output_dtype,
279-
int64_t unused);
274+
bool is_experimental);
280275

281276
Tensor split_embedding_nobag_backward_codegen_dense_unweighted_exact_cuda(
282277
Tensor grad_output,
@@ -316,7 +311,13 @@ class SplitNoBagLookupFunction_Dense_Op
316311
ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits;
317312

318313
return {dense_embedding_nobag_codegen_forward_unweighted_cuda(
319-
dev_weights, weights_offsets, D, indices, offsets, output_dtype, 0)};
314+
dev_weights,
315+
weights_offsets,
316+
D,
317+
indices,
318+
offsets,
319+
output_dtype,
320+
/*is_experimental*/ false)};
320321
}
321322

322323
static torch::autograd::variable_list backward(

fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,11 @@ Tensor split_embedding_codegen_forward_unweighted{{ vbe_desc }}_cuda(
3939
Tensor lxu_cache_locations,
4040
int64_t output_dtype,
4141
{% if vbe %}
42-
int64_t BT_block_size,
4342
const VBEMetadata& vbe_metadata,
4443
const int32_t info_B_num_bits,
45-
const uint32_t info_B_mask
46-
{% else %}
47-
int64_t BT_block_size
44+
const uint32_t info_B_mask,
4845
{% endif %}
49-
);
46+
bool is_experimental);
5047

5148
Tensor split_embedding_codegen_forward_weighted{{ vbe_desc }}_cuda(
5249
Tensor dev_weights,
@@ -64,14 +61,11 @@ Tensor split_embedding_codegen_forward_weighted{{ vbe_desc }}_cuda(
6461
Tensor lxu_cache_locations,
6562
int64_t output_dtype,
6663
{% if vbe %}
67-
int64_t BT_block_size,
6864
const VBEMetadata& vbe_metadata,
6965
const int32_t info_B_num_bits,
70-
const uint32_t info_B_mask
71-
{% else %}
72-
int64_t BT_block_size
66+
const uint32_t info_B_mask,
7367
{% endif %}
74-
);
68+
bool is_experimental);
7569

7670
Tensor split_embedding_codegen_grad_indice_weights{{ vbe_desc }}_cuda(
7771
Tensor grad_output,
@@ -158,8 +152,7 @@ Tensor split_embedding_nobag_codegen_forward_unweighted_cuda(
158152
Tensor offsets,
159153
Tensor lxu_cache_locations,
160154
int64_t output_dtype,
161-
int64_t unused
162-
);
155+
bool is_experimental);
163156

164157
void split_embedding_nobag_backward_codegen_{{ optimizer }}_unweighted_exact_cuda(
165158
Tensor grad_output,
@@ -226,6 +219,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
226219
const int32_t max_B_feature_rank,
227220
const int64_t vbe_output_size,
228221
{% endif %}
222+
bool is_experimental,
229223
{{ args.split_function_args | join(", ") }}) {
230224

231225
const auto T = weights_offsets.numel();
@@ -307,11 +301,6 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
307301
{% endfor %}
308302

309303
{% if not nobag %}
310-
#ifdef __HIP_PLATFORM_HCC__
311-
constexpr int32_t BT_block_size = 64;
312-
#else
313-
constexpr int32_t BT_block_size = 32;
314-
#endif
315304
if (!indice_weights) {
316305
return {
317306
split_embedding_codegen_forward_unweighted{{ vbe_desc }}_cuda(
@@ -329,13 +318,11 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
329318
lxu_cache_locations,
330319
output_dtype,
331320
{% if vbe %}
332-
BT_block_size,
333321
vbe_metadata,
334322
info_B_num_bits,
335-
info_B_mask
336-
{% else %}
337-
BT_block_size
323+
info_B_mask,
338324
{% endif %}
325+
is_experimental
339326
)
340327
};
341328
} else {
@@ -356,13 +343,11 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
356343
lxu_cache_locations,
357344
output_dtype,
358345
{% if vbe %}
359-
BT_block_size,
360346
vbe_metadata,
361347
info_B_num_bits,
362-
info_B_mask
363-
{% else %}
364-
BT_block_size
348+
info_B_mask,
365349
{% endif %}
350+
is_experimental
366351
)
367352
};
368353
}
@@ -379,7 +364,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
379364
offsets,
380365
lxu_cache_locations,
381366
output_dtype,
382-
0
367+
/*is_experimental=*/false
383368
)
384369
};
385370
{% endif %}
@@ -521,6 +506,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
521506
Variable(), // max_B_feature_rank
522507
Variable(), // vbe_output_size
523508
{% endif %}
509+
Variable(), // is_experimental
524510
{{ args.split_variables | join(", ") }}
525511
};
526512
} else {
@@ -601,6 +587,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
601587
Variable(), // max_B_feature_rank
602588
Variable(), // vbe_output_size
603589
{% endif %}
590+
Variable(), // is_experimental
604591
{{ args.split_variables | join(", ") }}
605592
};
606593
}
@@ -652,6 +639,7 @@ class Split{{ "NoBag" if nobag else "" }}{{ "VBE" if vbe else "" }}LookupFunctio
652639
Variable(), // max_B_feature_rank
653640
Variable(), // vbe_output_size
654641
{% endif %}
642+
Variable(), // is_experimental
655643
{{ args.split_variables | join(", ") }}
656644
};
657645
{% endif %}
@@ -691,7 +679,9 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
691679
const c10::optional<Tensor>& vbe_B_offsets_rank_per_feature = c10::optional<Tensor>(),
692680
const int64_t max_B = -1,
693681
const int64_t max_B_feature_rank = -1,
694-
const int64_t vbe_output_size = -1) {
682+
const int64_t vbe_output_size = -1,
683+
const bool is_experimental = false
684+
) {
695685
{% if has_gpu_support %}
696686
{% for vbe in ([True, False] if has_vbe_support else [False]) %}
697687
{% set vbe_class_desc = "VBE" if vbe else "" %}
@@ -721,6 +711,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
721711
gradient_clipping,
722712
max_gradient,
723713
stochastic_rounding,
714+
is_experimental,
724715
{{ args.split_function_arg_names | join(", ") }})[0];
725716
} else {
726717
return Split{{ vbe_class_desc }}LookupFunction_{{ optimizer }}_Op::apply(
@@ -753,6 +744,7 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
753744
max_B_feature_rank,
754745
vbe_output_size,
755746
{% endif %}
747+
is_experimental,
756748
{{ args.split_function_arg_names | join(", ") }})[0];
757749
}
758750
{% if has_vbe_support %}
@@ -767,12 +759,12 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
767759

768760
// Deprecated for fb namespace! Please use fbgemm namespace instead!
769761
TORCH_LIBRARY_FRAGMENT(fb, m) {
770-
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1) -> Tensor");
762+
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1, bool is_experimental=False) -> Tensor");
771763
DISPATCH_TO_CUDA("split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function);
772764
}
773765

774766
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
775-
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1) -> Tensor");
767+
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function(Tensor placeholder_autograd_tensor, Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0, Tensor? B_offsets=None, Tensor? vbe_output_offsets_feature_rank=None, Tensor? vbe_B_offsets_rank_per_feature=None, int max_B=-1, int max_B_feature_rank=-1, int vbe_output_size=-1, bool is_experimental=False) -> Tensor");
776768
DISPATCH_TO_CUDA("split_embedding_codegen_lookup_{{ optimizer }}_function", split_embedding_codegen_lookup_{{ optimizer }}_function);
777769
}
778770

fbgemm_gpu/codegen/embedding_forward_split_kernel_template.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ void {{ "dense" if dense else "split" }}_embedding{{ "_nobag" if nobag else "" }
278278

279279
/*
280280
Explicitly instantiate the kernel function template. The instantiations are
281-
based on the types enumerated by DISPATCH_EMB_GRAD_CACHE_TYPES macro used in
281+
based on the types enumerated by DISPATCH_EMB_CACHE_TYPES macro used in
282282
embedding_forward_split_template.cu
283283
*/
284284

0 commit comments

Comments
 (0)