
Commit ca63ea1

q10 authored and facebook-github-bot committed
Add support for int64_t indices and offsets in TBE inference [8/N] (pytorch#331)
Summary:
Pull Request resolved: facebookresearch/FBGEMM#331
X-link: pytorch#3233

- Update tests to use int64_t indices and offsets

Reviewed By: jianyuh

Differential Revision: D63807049

fbshipit-source-id: 286b385149a0d8563bd0f6ab250ccb7328573d26
1 parent ab3374d · commit ca63ea1
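The user-visible effect of this stack is that TBE inference modules accept int64 (torch.long) indices and offsets in addition to int32. A minimal usage sketch on CPU, assuming fbgemm_gpu is installed; the embedding spec below is an illustrative placeholder, not taken from this commit:

    import torch
    from fbgemm_gpu.split_embedding_configs import SparseType
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
    from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
        IntNBitTableBatchedEmbeddingBagsCodegen,
    )

    # One INT4-quantized table with 100 rows of dimension 8, held on the host.
    cc = IntNBitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[("t0", 100, 8, SparseType.INT4, EmbeddingLocation.HOST)],
    ).cpu()
    cc.fill_random_weights()

    # int64 indices/offsets now work alongside the previously required int32.
    indices = torch.tensor([3, 7, 9], dtype=torch.int64)
    offsets = torch.tensor([0, 2, 3], dtype=torch.int64)
    out = cc(indices, offsets)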

File tree: 6 files changed, +140 −78 lines

fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp

Lines changed: 54 additions & 50 deletions
@@ -70,62 +70,66 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
   const int32_t B = (offsets.size(0) - 1) / T;
   TORCH_CHECK(B > 0);
 
-  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
-    using uidx_t =
-        std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
-
-    const auto* indices_acc = indices.data_ptr<index_t>();
-    const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
-    const auto* offsets_acc = offsets.data_ptr<index_t>();
-
-    auto hash_table_acc = hash_table.accessor<int64_t, 2>();
-    const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
-
-    for (const auto t : c10::irange(T)) {
-      const auto table_start = hash_table_offsets_acc[t];
-      const auto table_end = hash_table_offsets_acc[t + 1];
-      if (table_start == table_end) {
-        continue;
-      }
-      const auto capacity = table_end - table_start;
-
-      for (const auto b : c10::irange(B)) {
-        const auto indices_start = offsets_acc[t * B + b];
-        const auto indices_end = offsets_acc[t * B + b + 1];
-        const auto L = indices_end - indices_start;
-
-        for (const auto l : c10::irange(L)) {
-          const auto idx = indices_acc[indices_start + l];
-          const auto dense_idx = dense_indices_acc[indices_start + l];
-          if (dense_idx == -1) {
-            // -1 means this row has been pruned, do not insert it.
-            continue;
-          }
+  AT_DISPATCH_INDEX_TYPES(hash_table.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+    using hash_t = index_t;
 
-          auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
-          while (true) {
-            const auto ht_idx = table_start + static_cast<int64_t>(slot);
-            const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
-
-            // Empty slot
-            if (slot_sparse_idx == -1) {
-              hash_table_acc[ht_idx][0] = idx;
-              hash_table_acc[ht_idx][1] = dense_idx;
-              break;
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "pruned_hashmap_insert_{{ wdesc }}_cpu", [&] {
+      using uidx_t =
+          std::conditional_t<std::is_same_v<index_t, int64_t>, uint64_t, uint32_t>;
+
+      const auto* indices_acc = indices.data_ptr<index_t>();
+      const auto* dense_indices_acc = dense_indices.data_ptr<index_t>();
+      const auto* offsets_acc = offsets.data_ptr<index_t>();
+
+      auto hash_table_acc = hash_table.accessor<hash_t, 2>();
+      const auto hash_table_offsets_acc = hash_table_offsets.accessor<int64_t, 1>();
+
+      for (const auto t : c10::irange(T)) {
+        const auto table_start = hash_table_offsets_acc[t];
+        const auto table_end = hash_table_offsets_acc[t + 1];
+        if (table_start == table_end) {
+          continue;
+        }
+        const auto capacity = table_end - table_start;
+
+        for (const auto b : c10::irange(B)) {
+          const auto indices_start = offsets_acc[t * B + b];
+          const auto indices_end = offsets_acc[t * B + b + 1];
+          const auto L = indices_end - indices_start;
+
+          for (const auto l : c10::irange(L)) {
+            const auto idx = indices_acc[indices_start + l];
+            const auto dense_idx = dense_indices_acc[indices_start + l];
+            if (dense_idx == -1) {
+              // -1 means this row has been pruned, do not insert it.
+              continue;
             }
-
-            // Already exists (shouldn't happen in practice)
-            if (slot_sparse_idx == idx) {
-              hash_table_acc[ht_idx][1] = dense_idx;
-              break;
+
+            auto slot = pruned_hash_function(static_cast<uidx_t>(idx)) % capacity;
+            while (true) {
+              const auto ht_idx = table_start + static_cast<int64_t>(slot);
+              const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
+
+              // Empty slot
+              if (slot_sparse_idx == -1) {
+                hash_table_acc[ht_idx][0] = idx;
+                hash_table_acc[ht_idx][1] = dense_idx;
+                break;
+              }
+
+              // Already exists (shouldn't happen in practice)
+              if (slot_sparse_idx == idx) {
+                hash_table_acc[ht_idx][1] = dense_idx;
+                break;
+              }
+
+              // Linear probe
+              slot = (slot + 1) % capacity;
             }
-
-            // Linear probe
-            slot = (slot + 1) % capacity;
           }
         }
       }
-  }
+    });
   });
 
   return;
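For readers unfamiliar with the kernel, the change wraps the existing insertion loop in an outer dispatch on hash_table's dtype (hash_t) so the hash table itself may be int32 or int64, independent of the indices dtype. The insertion logic itself is plain linear probing over a preallocated [capacity][2] table of (sparse index, dense index) pairs, with -1 marking empty slots and pruned rows. A minimal Python sketch of one insertion; hash() here is a stand-in for pruned_hash_function applied to the uidx_t-cast index, and the table is assumed to never be completely full:

    from typing import List


    def hashmap_insert(
        hash_table: List[List[int]],  # [capacity][2], initialized to [-1, -1]
        idx: int,
        dense_idx: int,
    ) -> None:
        """Insert (idx -> dense_idx) with linear probing, mirroring the CPU kernel."""
        capacity = len(hash_table)
        slot = hash(idx) % capacity  # stand-in for pruned_hash_function(uidx_t(idx))
        while True:
            slot_sparse_idx = hash_table[slot][0]
            if slot_sparse_idx == -1:  # empty slot: claim it
                hash_table[slot] = [idx, dense_idx]
                return
            if slot_sparse_idx == idx:  # already present: update the remapped row
                hash_table[slot][1] = dense_idx
                return
            slot = (slot + 1) % capacity  # linear probe


    table = [[-1, -1] for _ in range(8)]
    hashmap_insert(table, idx=42, dense_idx=3)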

fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_lookup.cu

Lines changed: 4 additions & 4 deletions
@@ -100,14 +100,14 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
   }
 }
 
-template <typename index_t, typename hash_t>
+template <typename index_t, typename remap_t>
 __global__
 __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pruned_array_lookup_kernel(
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         offsets,
-    const pta::PackedTensorAccessor32<hash_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<remap_t, 1, at::RestrictPtrTraits>
         index_remappings,
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         index_remappings_offsets,
@@ -231,7 +231,7 @@ Tensor pruned_array_lookup_cuda(
 
   AT_DISPATCH_INDEX_TYPES(
       index_remappings.scalar_type(), "pruned_array_lookup_cuda_0", [&] {
-        using hash_t = index_t;
+        using remap_t = index_t;
 
         AT_DISPATCH_INDEX_TYPES(
             indices.scalar_type(), "pruned_array_lookup_cuda_1", [&] {
@@ -249,7 +249,7 @@ Tensor pruned_array_lookup_cuda(
                   MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
                   MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
                   MAKE_PTA_WITH_NAME(
-                      func_name, index_remappings, hash_t, 1, 32),
+                      func_name, index_remappings, remap_t, 1, 32),
                   MAKE_PTA_WITH_NAME(
                       func_name, index_remappings_offsets, int64_t, 1, 32),
                   B,
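The rename matters because the kernel is dispatched twice: the outer AT_DISPATCH_INDEX_TYPES binds remap_t to the dtype of index_remappings and the inner one binds index_t to the dtype of indices, so the two can vary independently (for example, an int32 remap table with int64 indices). A rough Python reference of what the lookup computes, under the assumptions that each table's remap segment is addressed through index_remappings_offsets and that an empty segment means identity mapping:

    import torch


    def pruned_array_lookup_reference(
        indices: torch.Tensor,                   # 1D, int32 or int64 (index_t)
        offsets: torch.Tensor,                   # [T * B + 1]
        index_remappings: torch.Tensor,          # 1D, int32 or int64 (remap_t)
        index_remappings_offsets: torch.Tensor,  # [T + 1], int64
        T: int,
        B: int,
    ) -> torch.Tensor:
        dense_indices = torch.empty_like(indices)
        for t in range(T):
            remap_start = int(index_remappings_offsets[t])
            remap_end = int(index_remappings_offsets[t + 1])
            start = int(offsets[t * B])
            end = int(offsets[(t + 1) * B])
            if remap_end > remap_start:
                # Remap each sparse index through this table's segment.
                remap = index_remappings[remap_start:remap_end]
                dense_indices[start:end] = remap[indices[start:end].long()].to(indices.dtype)
            else:
                # No remapping stored for this table: pass indices through.
                dense_indices[start:end] = indices[start:end]
        return dense_indices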

fbgemm_gpu/test/tbe/inference/nbit_cache_test.py

Lines changed: 16 additions & 9 deletions
@@ -118,9 +118,12 @@ def test_nbit_cache_update_function(self, L: int, H: int, S: int) -> None:
         self.assertEqual(total_access_count, expected_total_access)
 
     @unittest.skipIf(*gpu_unavailable)
-    @given(N=st.integers(min_value=1, max_value=8))
+    @given(
+        N=st.integers(min_value=1, max_value=8),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
+    )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_nbit_cache_miss_counter(self, N: int) -> None:
+    def test_nbit_cache_miss_counter(self, N: int, indices_dtype: torch.dtype) -> None:
         # Create an abstract split table
         D = 8
         T = 2
@@ -156,7 +159,7 @@ def test_nbit_cache_miss_counter(self, N: int) -> None:
         ):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 cache_miss_forward_count,
                 unique_cache_miss_count,
@@ -173,9 +176,12 @@ def test_nbit_cache_miss_counter(self, N: int) -> None:
     @given(
         N=st.integers(min_value=1, max_value=8),
         dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
     )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
-    def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
+    def test_nbit_uvm_cache_stats(
+        self, N: int, dtype: SparseType, indices_dtype: torch.dtype
+    ) -> None:
         # Create an abstract split table
         D = 8
         T = 2
@@ -215,7 +221,7 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
             for _ in range(N):
                 num_calls_expected = num_calls_expected + 1
                 num_indices_expcted = num_indices_expcted + len(indices)
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 num_calls,
                 num_indices,
@@ -271,7 +277,7 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
         for x, e in zip((indices1, indices2, indices3), expected):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc1(indices.int(), offsets.int())
+                cc1(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 _,
                 _,
@@ -288,10 +294,11 @@ def test_nbit_uvm_cache_stats(self, N: int, dtype: SparseType) -> None:
     @given(
         N=st.integers(min_value=1, max_value=8),
        dtype=st.sampled_from([SparseType.INT8, SparseType.INT4, SparseType.INT2]),
+        indices_dtype=st.sampled_from([torch.int, torch.long]),
     )
     @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
     def test_nbit_direct_mapped_uvm_cache_stats(
-        self, N: int, dtype: SparseType
+        self, N: int, dtype: SparseType, indices_dtype: torch.dtype
     ) -> None:
         # Create an abstract split table
         D = 8
@@ -333,7 +340,7 @@ def test_nbit_direct_mapped_uvm_cache_stats(
             for _ in range(N):
                 num_calls_expected = num_calls_expected + 1
                 num_indices_expcted = num_indices_expcted + len(indices)
-                cc(indices.int(), offsets.int())
+                cc(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 num_calls,
                 num_indices,
@@ -393,7 +400,7 @@ def test_nbit_direct_mapped_uvm_cache_stats(
         for x, e in zip((indices1, indices2, indices3), expected):
             (indices, offsets) = get_table_batched_offsets_from_dense(x, use_cpu=False)
             for _ in range(N):
-                cc1(indices.int(), offsets.int())
+                cc1(indices.to(dtype=indices_dtype), offsets.to(dtype=indices_dtype))
             (
                 _,
                 _,
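The test changes above all follow one pattern: sample the index dtype with Hypothesis and cast both tensors once, instead of hard-coding .int() at each call site. A self-contained sketch of that pattern (the op under test is elided; the assertion just checks the cast):

    import torch
    from hypothesis import given, settings, strategies as st


    @given(indices_dtype=st.sampled_from([torch.int, torch.long]))
    @settings(deadline=None)
    def test_indices_dtype_cast(indices_dtype: torch.dtype) -> None:
        indices = torch.tensor([1, 5, 7])
        offsets = torch.tensor([0, 2, 3])
        # Cast once, then drive the op under test with either dtype.
        indices = indices.to(dtype=indices_dtype)
        offsets = offsets.to(dtype=indices_dtype)
        assert indices.dtype == indices_dtype and offsets.dtype == indices_dtype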

fbgemm_gpu/test/tbe/inference/nbit_forward_autovec_test.py

Lines changed: 11 additions & 4 deletions
@@ -105,6 +105,7 @@ def execute_nbit_forward_(  # noqa C901
         use_array_for_index_remapping: bool,
         do_pruning: bool,
         mixed_weights_ty: bool,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         # NOTE: weighted operation can be done only for SUM.
@@ -311,19 +312,22 @@ def execute_nbit_forward_(  # noqa C901
             fp8_config=fp8_config if has_fp8_weight else None,
         )
 
+        indices = indices.to(dtype=indices_dtype)
+        offsets = offsets.to(dtype=indices_dtype)
+
         if not use_cpu:
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1))
+                else cc(indices, offsets, xw.contiguous().view(-1))
             )
         else:
             cc = cc.cpu()
             indices, offsets = indices.cpu(), offsets.cpu()
             fc2 = (
-                cc(indices.int(), offsets.int())
+                cc(indices, offsets)
                 if not weighted
-                else cc(indices.int(), offsets.int(), xw.contiguous().view(-1).cpu())
+                else cc(indices, offsets, xw.contiguous().view(-1).cpu())
             )
 
         if do_pooling and B == 0:
@@ -373,6 +377,7 @@ def execute_nbit_forward_(  # noqa C901
         pooling_mode=st.sampled_from(
             [PoolingMode.SUM, PoolingMode.MEAN, PoolingMode.NONE]
         ),
+        indices_dtype=st.sampled_from([torch.int32, torch.int64]),
         output_dtype=st.sampled_from(
             [SparseType.FP32, SparseType.FP16, SparseType.BF16]
         ),
@@ -386,6 +391,7 @@ def test_nbit_forward_cpu_autovec(
         self,
         nbit_weights_ty: Optional[SparseType],
         pooling_mode: PoolingMode,
+        indices_dtype: torch.dtype,
         output_dtype: SparseType,
     ) -> None:
         use_cpu = True
@@ -432,6 +438,7 @@ def test_nbit_forward_cpu_autovec(
             False,
             False,
             mixed_weights_ty,
+            indices_dtype,
             output_dtype,
         )
 
437444
