
Commit f8943cc

q10 authored and facebook-github-bot committed
Add support for int32_t indices in TBE training (1/N) (pytorch#3324)
Summary:
X-link: facebookresearch/FBGEMM#418

- Add `index_t` support to TBE training forward kernels

Differential Revision: D65457179
1 parent: 7d76ba7

File tree: 6 files changed (+79, -51 lines)
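Background note for the diffs below: the new dispatch plumbing rests on PyTorch's AT_DISPATCH_INDEX_TYPES macro, which expands its lambda once per supported index dtype (int32 and int64) and binds the alias index_t to the concrete C++ type inside the lambda. A minimal, self-contained sketch of the pattern; sum_indices and sum_indices_kernel are hypothetical names, not FBGEMM code:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical kernel templated on the index type, mirroring how the TBE
// forward kernels below take index_t as a template parameter.
template <typename index_t>
int64_t sum_indices_kernel(const index_t* data, int64_t n) {
  int64_t acc = 0;
  for (int64_t i = 0; i < n; ++i) {
    acc += static_cast<int64_t>(data[i]);
  }
  return acc;
}

int64_t sum_indices(const at::Tensor& indices) {
  int64_t result = 0;
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "sum_indices", [&] {
    // Inside the lambda, index_t is int32_t or int64_t, matching
    // indices.scalar_type().
    result = sum_indices_kernel<index_t>(
        indices.data_ptr<index_t>(), indices.numel());
  });
  return result;
}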

fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp

Lines changed: 41 additions & 22 deletions
@@ -16,7 +16,6 @@
 #include "fbgemm_gpu/utils/ops_utils.h"
 #ifdef FBCODE_CAFFE2
 #include <libdivide.h>
-#include "folly/container/F14Map.h"
 #else
 #include <omp.h>
 #endif
@@ -29,7 +28,12 @@
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
-template <typename weights_t, typename ind_weights_t, typename output_t>
+template <
+    typename weights_t,
+    typename ind_weights_t,
+    typename index_t,
+    typename offset_t,
+    typename output_t>
 void split_embedding_forward_cpu_kernel(
     Tensor weights,
     Tensor weights_offsets,
@@ -56,8 +60,8 @@ void split_embedding_forward_cpu_kernel(
 
   const auto D_offsets_data = D_offsets.accessor<int, 1>();
   const auto weights_offsets_data = weights_offsets.accessor<int64_t, 1>();
-  const auto indices_data = indices.data_ptr<int64_t>();
-  const auto offsets_data = offsets.data_ptr<int64_t>();
+  const auto indices_data = indices.data_ptr<index_t>();
+  const auto offsets_data = offsets.data_ptr<offset_t>();
   const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
 
   const auto weights_data = weights.data_ptr<weights_t>();
@@ -97,8 +101,8 @@ void split_embedding_forward_cpu_kernel(
           weights_t>::type;
       auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
           fbgemm_weight_t,
-          /*IndexType=*/int64_t,
-          /*OffsetType=*/int64_t>(
+          /*IndexType=*/index_t,
+          /*OffsetType=*/offset_t>(
           D,
           indice_weights.defined(),
           static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN,
@@ -203,29 +207,44 @@ Tensor split_embedding_codegen_forward_cpu(
   // It is assumed that the indice_weights will always be float
   TORCH_CHECK(
       !indice_weights.defined() || indice_weights.scalar_type() != at::kHalf);
+
   FBGEMM_DISPATCH_FLOAT_AND_HALF(
-      output.scalar_type(), "split_embedding_cpu_forward", [&]() {
+      output.scalar_type(), "split_embedding_cpu_forward_1", [&]() {
         using output_t = scalar_t;
+
         FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
-            weights.scalar_type(), "split_embedding_cpu_forward", [&] {
+            weights.scalar_type(), "split_embedding_cpu_forward_2", [&] {
              using ind_weights_t = std::conditional<
                  std::is_same<scalar_t, double>::value,
                  double,
                  float>::type;
-              split_embedding_forward_cpu_kernel<
-                  scalar_t,
-                  ind_weights_t,
-                  output_t>(
-                  weights,
-                  weights_offsets,
-                  D_offsets,
-                  total_D,
-                  hash_size_cumsum,
-                  indices,
-                  offsets,
-                  pooling_mode,
-                  indice_weights,
-                  output);
+
+              AT_DISPATCH_INDEX_TYPES(
+                  offsets.scalar_type(), "split_embedding_cpu_forward_3", [&] {
+                    using offset_t = index_t;
+
+                    AT_DISPATCH_INDEX_TYPES(
+                        indices.scalar_type(),
+                        "split_embedding_cpu_forward_4",
+                        [&] {
+                          split_embedding_forward_cpu_kernel<
+                              scalar_t,
+                              ind_weights_t,
+                              index_t,
+                              offset_t,
+                              output_t>(
+                              weights,
+                              weights_offsets,
+                              D_offsets,
+                              total_D,
+                              hash_size_cumsum,
+                              indices,
+                              offsets,
+                              pooling_mode,
+                              indice_weights,
+                              output);
+                        });
+                  });
             });
       });
   return output;
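Note the `using offset_t = index_t;` line above: every AT_DISPATCH_INDEX_TYPES invocation introduces an alias named index_t, so the outer dispatch's type has to be saved under a different name before the inner dispatch shadows it. A stripped-down sketch of the nesting; dispatch_offsets_and_indices is a made-up example, not the FBGEMM source:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <cstdio>

// All four {int32, int64} x {int32, int64} combinations of offset and index
// dtypes are reachable through the nested dispatch.
void dispatch_offsets_and_indices(
    const at::Tensor& offsets,
    const at::Tensor& indices) {
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "outer", [&] {
    // Capture the outer alias before the inner dispatch rebinds index_t.
    using offset_t = index_t;
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "inner", [&] {
      // Here offset_t is the offsets dtype and index_t the indices dtype.
      std::printf(
          "offset_t: %zu bytes, index_t: %zu bytes\n",
          sizeof(offset_t),
          sizeof(index_t));
    });
  });
}

The distinct suffixes _1 through _4 on the dispatch labels are debug strings; unique labels make dtype-mismatch errors easier to attribute to a specific dispatch site.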

fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_nobag_small_template.cu

Lines changed: 7 additions & 6 deletions
@@ -95,14 +95,14 @@ batch_index_select_dim0_codegen_forward_small_kernel(
     indices_start = total_L_start + L_start;
     L = (total_L - L_start >= fixed_L_per_warp) ? fixed_L_per_warp : (total_L - L_start);
     {%- else %}
-    index_t indices_start = offsets[b_t];
-    int32_t L = offsets[b_t + 1] - indices_start;
+    const auto indices_start = offsets[b_t];
+    const auto L = offsets[b_t + 1] - indices_start;
     {%- endif %}
 
     {%- if is_index_select %}
-    const int32_t D_start = D_offsets[t];
-    const int32_t D_end = D_offsets[t + 1];
-    const int32_t D = D_end - D_start;
+    const auto D_start = D_offsets[t];
+    const auto D_end = D_offsets[t + 1];
+    const auto D = D_end - D_start;
 
     // Check D in the kernel to avoid iterating through the list on host
     CUDA_KERNEL_ASSERT(D % 4 == 0 && "The column size must be multiple of 4");
@@ -221,7 +221,7 @@ batch_index_select_dim0_codegen_forward_small_kernel(
 {%- for emb_type in ['float', 'at::Half'] %}
 {%- for cache_type in ['float', 'at::Half'] %}
 {%- for kEmbeddingSize in [4, 8, 16, 32] %}
-{%- set index_type = 'int64_t' %}
+{%- for index_type in ['int32_t', 'int64_t'] %}
 
 template __launch_bounds__(kForwardMaxThreads) __global__ void
 {%- if is_index_select %}
@@ -268,3 +268,4 @@ batch_index_select_dim0_codegen_forward_small_kernel
 {%- endfor %}
 {%- endfor %}
 {%- endfor %}
+{%- endfor %}
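Swapping the single `{%- set index_type = 'int64_t' %}` for a loop (with the matching `{%- endfor %}` appended above) doubles the explicit instantiations emitted into the generated .cu file, one per index type. Schematically, the generated output contains pairs like the following; gather_rows_kernel is a hypothetical stand-in, not the real kernel signature:

// Hypothetical row-gather kernel templated on index_t.
template <typename index_t>
__global__ void gather_rows_kernel(
    const float* __restrict__ src,
    const index_t* __restrict__ indices,
    float* __restrict__ dst,
    int D) {
  const index_t row = indices[blockIdx.x];
  for (int d = threadIdx.x; d < D; d += blockDim.x) {
    // Widen the address arithmetic to 64 bits: int32_t indices bound the
    // row count, but row * D can still exceed INT32_MAX.
    dst[static_cast<int64_t>(blockIdx.x) * D + d] =
        src[static_cast<int64_t>(row) * D + d];
  }
}

// One explicit instantiation per index type, as the new Jinja loop emits:
template __global__ void gather_rows_kernel<int32_t>(
    const float*, const int32_t*, float*, int);
template __global__ void gather_rows_kernel<int64_t>(
    const float*, const int64_t*, float*, int);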

fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu

Lines changed: 7 additions & 3 deletions
@@ -560,6 +560,7 @@ batch_index_select_dim0_codegen_forward_kernel(
     emb_type,
     cache_type,
     output_type,
+    index_type,
     use_cache,
     kMaxVecsPerThread,
     kThreadGroupSize)
@@ -577,7 +578,7 @@ batch_index_select_dim0_codegen_forward_kernel
 {%- if not dense %}
     {{ use_cache }},
 {%- endif %}
-    int64_t,
+    {{ index_type }},
 {%- if not nobag %}
     {{ kMaxVecsPerThread }},
 {%- endif %}
@@ -603,9 +604,9 @@ batch_index_select_dim0_codegen_forward_kernel
 {%- else %}
     FixedDivisor fd_B,
 {%- endif %}
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices,
+    const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> indices,
 {%- if not is_index_select %}
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets,
+    const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> offsets,
 {%- endif %}
 {%- if not nobag %}
     int64_t pooling_mode,
@@ -638,17 +639,20 @@ batch_index_select_dim0_codegen_forward_kernel
 {%- for emb_type in ['float', 'at::Half'] %}
 {%- for cache_type in ['float', 'at::Half'] %}
 {%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %}
+{%- for index_type in ['int32_t', 'int64_t'] %}
     {{ template_instantiation(
         emb_type,
         cache_type,
         output_type,
+        index_type,
         use_cache,
         kMaxVecsPerThread,
         kThreadGroupSize)
     }}
 {%- endfor %}
 {%- endfor %}
 {%- endfor %}
+{%- endfor %}
 {%- endmacro %}
 
 {%- macro instantiate_templates(use_subwarp_shuffle) %}
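With the accessor element type templated, each instantiation type-checks against whatever index tensor the host side passes at launch. A minimal sketch of a kernel reading indices through a PackedTensorAccessor32 parameterized on index_t; it uses the stock at:: accessor rather than FBGEMM's pta:: wrapper, and count_valid_kernel is a made-up name:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

template <typename index_t>
__global__ void count_valid_kernel(
    const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
    int* __restrict__ num_valid) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  // Count non-negative entries; -1 marks pruned ids in some TBE paths.
  if (i < indices.size(0) && indices[i] >= 0) {
    atomicAdd(num_valid, 1);
  }
}

void launch_count_valid(const at::Tensor& indices, int* num_valid) {
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "count_valid", [&] {
    constexpr int kThreads = 256;
    const int blocks =
        static_cast<int>((indices.numel() + kThreads - 1) / kThreads);
    count_valid_kernel<index_t><<<blocks, kThreads>>>(
        indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
        num_valid);
  });
}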

fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu

Lines changed: 5 additions & 3 deletions
@@ -986,6 +986,7 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
 */
 
 {%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %}
+{%- for index_type in ['int32_t', 'int64_t'] %}
 {%- for emb_type in ['float', 'at::Half'] %}
 {%- for cache_type in ['float', 'at::Half'] %}
 {%- for use_cache in ['true', 'false'] %}
@@ -996,7 +997,7 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel
     {{ emb_type }},
     {{ cache_type }},
     {{ output_type }},
-    int64_t, // index_t
+    {{ index_type }},
     {{ use_cache }}
 > (
     const {{ emb_type }}* __restrict__ const dev_weights,
@@ -1008,11 +1009,11 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel
     const bool mean_pooling,
     const uint32_t max_D_cache,
     const FixedDivisor fd_num_warps_per_table,
-    const int64_t* __restrict__ const indices,
+    const {{ index_type }}* __restrict__ const indices,
 {%- if weighted %}
     const float* __restrict__ const index_weights,
 {%- endif %}
-    const int64_t* __restrict__ const offsets,
+    const {{ index_type }}* __restrict__ const offsets,
     const uint32_t* __restrict__ const D_offsets,
     const int64_t* __restrict__ const weights_offsets,
     const int32_t* __restrict__ const lxu_cache_locations,
@@ -1022,3 +1023,4 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel
 {%- endfor %}
 {%- endfor %}
 {%- endfor %}
+{%- endfor %}

fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu

Lines changed: 17 additions & 15 deletions
@@ -549,6 +549,8 @@ batch_index_select_dim0_codegen_forward_cuda(
     return output;
   }
 
+
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "batched_embedding{{ ndesc }}_forward_kernel_1", [&] {
   DISPATCH_EMB_CACHE_OUTPUT_TYPES(
       dev_weights.scalar_type(),
 {%- if not dense %}
@@ -590,7 +592,7 @@ batch_index_select_dim0_codegen_forward_cuda(
         emb_t,
         cache_t,
         output_t,
-        int64_t,
+        index_t,
         kEmbeddingSize / 4>
         <<<
             div_round_up(total_B, kForwardMaxThreads / kWarpSize),
@@ -611,9 +613,9 @@ batch_index_select_dim0_codegen_forward_cuda(
             D,
 {%- endif %}
             FixedDivisor(B),
-            MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
 {%- if not is_index_select %}
-            MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
 {%- endif %}
 {%- if not dense %}
             MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
@@ -644,9 +646,9 @@ batch_index_select_dim0_codegen_forward_cuda(
 
     {{ nobag_kernel }}
 {%- if dense or is_index_select %}
-        <emb_t, cache_t, output_t, int64_t>
+        <emb_t, cache_t, output_t, index_t>
 {%- else %}
-        <emb_t, cache_t, output_t, use_cache_t, int64_t>
+        <emb_t, cache_t, output_t, use_cache_t, index_t>
 {%- endif %}
         <<<
             div_round_up(total_B, kForwardMaxThreads / kWarpSize),
@@ -667,9 +669,9 @@ batch_index_select_dim0_codegen_forward_cuda(
             D,
 {%- endif %}
             FixedDivisor(B),
-            MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
 {%- if not is_index_select %}
-            MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
 {%- endif %}
 {%- if not dense %}
             MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
@@ -717,7 +719,7 @@ batch_index_select_dim0_codegen_forward_cuda(
 {%- if not dense%}
             use_cache_t,
 {%- endif %}
-            int64_t,
+            index_t,
             kMaxVecsPerThread,
             kThreadGroupSize>
         <<<
@@ -742,8 +744,8 @@ batch_index_select_dim0_codegen_forward_cuda(
 {%- else %}
             FixedDivisor(B),
 {%- endif %}
-            MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
             pooling_mode,
 {%- if weighted %}
             MAKE_PTA_ACC_WITH_NAME(func_name, indice_weights, cache_t, 1, 32),
@@ -784,9 +786,9 @@ batch_index_select_dim0_codegen_forward_cuda(
 
       const auto kernel_func =
           (use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel<
-                               emb_t, cache_t, output_t, int64_t, true>
+                               emb_t, cache_t, output_t, index_t, true>
                          : split_embedding_codegen_forward_{{ wdesc }}_v2_kernel<
-                               emb_t, cache_t, output_t, int64_t, false>);
+                               emb_t, cache_t, output_t, index_t, false>);
 
       kernel_func
           <<<
@@ -804,12 +806,12 @@ batch_index_select_dim0_codegen_forward_cuda(
           static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN,
           use_lxu_cache ? lxu_cache_weights.size(1) : 0,
           FixedDivisor(num_warps_per_table),
-          indices.data_ptr<int64_t>(),
+          indices.data_ptr<index_t>(),
 {%- if weighted %}
           // TODO: update indice_weights type
           indice_weights.data_ptr<float>(),
 {%- endif %}
-          offsets.data_ptr<int64_t>(),
+          offsets.data_ptr<index_t>(),
           reinterpret_cast<uint32_t*>(D_offsets.data_ptr<int32_t>()),
           weights_offsets.data_ptr<int64_t>(),
           lxu_cache_locations.data_ptr<int32_t>(),
@@ -819,7 +821,7 @@ batch_index_select_dim0_codegen_forward_cuda(
   }
 {%- endif %} // if has_experimental
   });
-
+  });
   return output;
 }
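The v2 branch picks between the cached and uncached instantiations with a runtime ternary, and index_t is now part of that selection. A condensed sketch of the host-side pattern; v2_like_kernel and launch_v2_like are hypothetical:

#include <cuda_runtime.h>
#include <cstdint>

// The bool template parameter stands in for use_cache.
template <typename index_t, bool kUseCache>
__global__ void v2_like_kernel(
    const index_t* __restrict__ indices,
    float* __restrict__ out) {
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    out[0] = kUseCache ? static_cast<float>(indices[0]) : 0.0f;
  }
}

void launch_v2_like(
    bool use_cache,
    const int32_t* indices,
    float* out,
    cudaStream_t stream) {
  // Both instantiations share one signature, so the ternary yields a single
  // kernel function pointer, exactly as with use_lxu_cache above.
  const auto kernel_func = use_cache
      ? v2_like_kernel<int32_t, true>
      : v2_like_kernel<int32_t, false>;
  kernel_func<<<1, 32, 0, stream>>>(indices, out);
}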

fbgemm_gpu/include/fbgemm_gpu/utils/cpu_utils.h

Lines changed: 2 additions & 2 deletions
@@ -22,13 +22,13 @@ namespace fbgemm_gpu {
  * scale_bias_last == false that can take -1 indices (output from
  * pruned embedding id mapping)
  */
-template <typename IndexType>
+template <typename IndexType, typename OffsetType>
 void report_embedding_error(
     int t,
     int B,
     int b_begin,
     int b_end,
-    const IndexType* offsets_data,
+    const OffsetType* offsets_data,
     const IndexType* indices_data,
     int64_t hash_size,
     bool allow_minus_one = false) {
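Splitting the single IndexType into independent IndexType and OffsetType parameters lets the reporter accept mixed widths, e.g. int64_t offsets alongside int32_t indices, matching the new dispatch combinations on the CPU path. A hedged call sketch with made-up values:

#include "fbgemm_gpu/utils/cpu_utils.h"

void example_report() {
  // Offsets and indices now deduce their integer widths independently.
  const int64_t offsets_data[] = {0, 2, 4};
  const int32_t indices_data[] = {7, 99, 3, 5}; // 99 >= hash_size below
  // Presumably raises a descriptive error pointing at the offending index.
  fbgemm_gpu::report_embedding_error(
      /*t=*/0,
      /*B=*/2,
      /*b_begin=*/0,
      /*b_end=*/2,
      offsets_data,
      indices_data,
      /*hash_size=*/16);
}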
