Skip to content

Commit ea763e9

Browse files
Leon Gaofacebook-github-bot
authored andcommitted
int8 output for seq embeddings (#2316)
Summary: * int8 output dtype is a gap for recently fbgemm usage case, setup a reasonable refimplementation first, memcpy based. * for sequence embedding, we first unblock dispatch via simple memcpy, it is a pure bw op(no dequant) so memcpy should be reasonably ok. further optimization like ILP via unrolling, try avx non-temp instruction, rep instruction to be done in future iterations. Differential Revision: D53449813
1 parent c8cb1a2 commit ea763e9

File tree

4 files changed

+201
-67
lines changed

4 files changed

+201
-67
lines changed

fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,20 +166,22 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
166166
}
167167

168168
Tensor output;
169-
const int kINT8QparamsBytes = 8;
170169
SparseType o_dtype = static_cast<SparseType>(output_dtype);
171170
TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16);
172171
bool output_is_bf16 = o_dtype == SparseType::BF16;
172+
bool output_is_int8 = o_dtype == SparseType::INT8;
173173
{% if not nobag %}
174+
const int kINT8QparamsBytes = 8;
174175
int64_t total_adjusted_D = total_D;
175176
if (o_dtype == SparseType::INT8) {
176177
total_adjusted_D += T * kINT8QparamsBytes;
177178
}
178179
output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
179180
{% else %}
181+
const int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout
180182
int64_t adjusted_D = D;
181183
if (o_dtype == SparseType::INT8) {
182-
adjusted_D += T * kINT8QparamsBytes;
184+
adjusted_D += kINT8QparamsBytes;
183185
}
184186
output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
185187

@@ -202,11 +204,15 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
202204

203205
using float16 = uint16_t;
204206
using bfloat16 = uint16_t;
205-
using fbgemm_out_t = typename std::conditional<
207+
using int8 = uint8_t;
208+
using base_fbgemm_out_t = typename std::conditional<
209+
std::is_same<output_t, at::Half>::value,
210+
float16,
211+
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, std::conditional<std::is_same<output_t, float>::value, float, int8>::type> ::type >::type;
212+
using other_fbgemm_out_t = typename std::conditional<
206213
std::is_same<output_t, at::Half>::value,
207214
float16,
208215
std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type >::type;
209-
210216
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
211217
const auto* indices_acc = indices.data_ptr<index_t>();
212218
const auto* offsets_acc = offsets.data_ptr<index_t>();
@@ -224,7 +230,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
224230
const int32_t D_end = D_offsets_acc[t + 1];
225231
const int32_t D = D_end - D_start;
226232
{% else %}
227-
const int32_t D_start = offsets_acc[t * B] * D;
233+
const int32_t D_start = offsets_acc[t * B] * adjusted_D;
228234
{% endif %}
229235

230236
const auto placement = static_cast<PlacementType>(weights_placements_ptr[t]);
@@ -233,6 +239,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
233239
weights_acc = weight_tensor.data_ptr<uint8_t>();
234240
const uint8_t* weights = &weights_acc[weights_offsets_acc[t]];
235241
const auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
242+
if (output_is_int8) {
243+
TORCH_CHECK(weight_ty == SparseType::INT8, "int8 output are only supported for int8 weights");
244+
}
236245
// default to 1 byte alignment for CPU TBE
237246
const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);
238247

@@ -246,6 +255,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
246255
const bool normalize_by_lengths = static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;
247256

248257
const index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;
258+
const int32_t output_stride = {{ "total_D" if not nobag else "adjusted_D" }};
249259

250260
{% if nobag %}
251261
// Create virtual offsets for the nobag case. Lengths are all ones.
@@ -256,6 +266,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
256266
{% endif %}
257267

258268
const float* indice_weights_ptr = nullptr;
269+
// int8 output only enabled for nobag case with ref impl
270+
const bool nobag_op = {{ "false" if not nobag else "output_is_int8" }};
259271
{% if weighted %}
260272
indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
261273
{% endif %}
@@ -266,6 +278,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
266278
if use_base else ("GenerateEmbeddingSpMDMNBitWithStrides"
267279
if use_nbit else "GenerateEmbeddingSpMDMFP8WithStrides")
268280
%}
281+
using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base else "other_fbgemm_out_t" }};
282+
// TODO: merge nobag int8 path with normal asmjit dispatch
283+
{% if nobag %}
284+
const index_t* offset_ptr = (output_is_int8)? offsets_begin_ptr: offsets_nobag_ptr;
285+
{% else %}
286+
const index_t* offset_ptr = offsets_begin_ptr;
287+
{% endif %}
269288
const auto kernel = fbgemm::{{ kernel_name }}<
270289
{% if use_base %}
271290
{{ weight_type }},
@@ -292,7 +311,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
292311
{% endif %}
293312
/*is_weight_positional=*/false,
294313
/*use_offsets=*/true,
295-
/*output_stride=*/{{ "total_D" if not nobag else "D" }},
314+
/*output_stride=*/output_stride,
296315
/*input_stride=*/D_bytes / sizeof({{ weight_type }}),
297316
{% if use_fp8 %}
298317
/*exponent_bits=*/fp8_exponent_bits,
@@ -302,7 +321,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
302321
/*scale_bias_last=*/false,
303322
{% endif %}
304323
{% if use_base %}
305-
/*no_bag=*/false,
324+
/*no_bag=*/nobag_op,
306325
{% endif %}
307326
/*is_bf16_out=*/output_is_bf16
308327
);
@@ -312,7 +331,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
312331
num_rows,
313332
reinterpret_cast<const {{ weight_type }}*>(weights),
314333
indices_acc + *offsets_begin_ptr,
315-
{{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
334+
offset_ptr,
316335
indice_weights_ptr,
317336
reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
318337
{% endmacro %}

fbgemm_gpu/test/tbe/inference/nbit_forward_test.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import random
1212
import unittest
1313

14-
from typing import Callable, Dict, List, Optional
14+
from typing import Callable, Dict, List, Optional, Tuple
1515

1616
import hypothesis.strategies as st
1717
import numpy as np
@@ -91,6 +91,9 @@ def get_nbit_weights_ty(draw) -> Optional[SparseType]:
9191
"test_faketensor__test_nbit_forward_gpu_no_cache_fp8_2048": [
9292
unittest.skip("Operator not implemented for Meta tensors"),
9393
],
94+
"test_faketensor__test_nbit_forward_cpu_seq_int8": [
95+
unittest.skip("Operator not implemented for Meta tensors"),
96+
],
9497
}
9598

9699

@@ -838,6 +841,100 @@ def test_nbit_forward_uvm_cache(
838841
output_ref = cc_ref(indices, offsets)
839842
torch.testing.assert_close(output, output_ref, equal_nan=True)
840843

844+
@given(
845+
D=st.sampled_from([32, 256, 384, 512, 1024]),
846+
B=st.integers(min_value=8, max_value=32),
847+
T=st.integers(min_value=10, max_value=20),
848+
L=st.integers(min_value=10, max_value=100),
849+
MAXH=st.integers(min_value=50, max_value=100),
850+
)
851+
@settings(
852+
verbosity=VERBOSITY,
853+
max_examples=MAX_EXAMPLES_LONG_RUNNING,
854+
deadline=None,
855+
)
856+
def test_nbit_forward_cpu_seq_int8(
857+
self,
858+
D: int,
859+
B: int,
860+
T: int,
861+
L: int,
862+
MAXH: int,
863+
) -> None:
864+
"""
865+
we init a quant table split embedding bag with int8 weights and scale of 1 and 0 bias
866+
and compare brute force table lookup vs tbe based int8 output lookup.
867+
"""
868+
pooling_mode = PoolingMode.NONE
869+
870+
nbit_weights_ty = SparseType.INT8
871+
D_alignment = (
872+
1
873+
if nbit_weights_ty.bit_rate() % 8 == 0
874+
else int(8 / nbit_weights_ty.bit_rate())
875+
)
876+
D = round_up(D, D_alignment)
877+
T_H = [np.random.randint(low=1, high=MAXH + 1) for _ in range(T)]
878+
quant_cc = IntNBitTableBatchedEmbeddingBagsCodegen(
879+
embedding_specs=[
880+
(
881+
"",
882+
H,
883+
D,
884+
nbit_weights_ty,
885+
EmbeddingLocation.HOST,
886+
)
887+
for H in T_H
888+
],
889+
pooling_mode=pooling_mode,
890+
device="cpu",
891+
output_dtype=nbit_weights_ty,
892+
)
893+
# Initialize the random weights for int nbit table split embedding bag
894+
quant_cc.fill_random_weights()
895+
raw_embedding_weights = quant_cc.split_embedding_weights()
896+
# we mimic 1.0 scale, 0.0 bias for better results comparison
897+
embedding_weights: List[Tuple[torch.Tensor, Optional[torch.Tensor]]] = [
898+
(table_weight, torch.tensor([1, 0], dtype=torch.float16).view(torch.uint8))
899+
for table_weight, _ in raw_embedding_weights
900+
]
901+
# Initialize the random weights for int8 nbit table split embedding bag
902+
quant_cc.assign_embedding_weights(embedding_weights)
903+
lengths_list = [
904+
torch.randint(
905+
1,
906+
L + 1,
907+
(B,),
908+
)
909+
for _ in range(T)
910+
]
911+
indices_list = [
912+
torch.randint(0, H, (int(length.sum().item()),))
913+
for length, H in zip(lengths_list, T_H)
914+
]
915+
indices = torch.cat(indices_list, 0)
916+
lengths = torch.cat(lengths_list, 0)
917+
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
918+
quant_cc_output = quant_cc(indices.int(), offsets.int())
919+
tables_rows = [
920+
T for T, _, _ in quant_cc.split_embedding_weights_with_scale_bias(0)
921+
]
922+
ref_output = torch.cat(
923+
[
924+
table_rows[indice_table]
925+
for indice_table, table_rows in zip(indices_list, tables_rows)
926+
],
927+
dim=0,
928+
)
929+
torch.testing.assert_close(
930+
quant_cc_output.cpu(),
931+
ref_output.cpu(),
932+
rtol=1e-2,
933+
atol=1e-2,
934+
equal_nan=False,
935+
)
936+
937+
841938

842939
if __name__ == "__main__":
843940
unittest.main()

src/EmbeddingSpMDM.cc

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,19 +1540,29 @@ GenerateEmbeddingSpMDMRowWiseSparse(
15401540
#define INSTANTIATE_SPMDMFP8_BASE_float(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
15411541
#define INSTANTIATE_SPMDMFP8_BASE_uint16_t(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
15421542

1543-
#define INSTANTIATE_SPMDM_THREAD_LOCAL( \
1544-
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
1545-
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \
1546-
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \
1547-
INSTANTIATE_SPMDM_NOSTRIDE_BASE( \
1548-
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \
1549-
INSTANTIATE_SPMDM_NOSTRIDE_BASE( \
1550-
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \
1543+
#define INSTANTIATE_SPMDM_BASE_THREAD_LOCAL( \
1544+
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
1545+
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \
1546+
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false)
1547+
1548+
#define INSTANTIATE_SPMDM_NON_BASE_THREAD_LOCAL( \
1549+
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
1550+
INSTANTIATE_SPMDM_NOSTRIDE_BASE( \
1551+
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \
1552+
INSTANTIATE_SPMDM_NOSTRIDE_BASE( \
1553+
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \
15511554
INSTANTIATE_SPMDMFP8_BASE_##IN_TYPE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
15521555

1553-
#define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \
1554-
INSTANTIATE_SPMDM_THREAD_LOCAL(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \
1555-
INSTANTIATE_SPMDM_THREAD_LOCAL(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, uint16_t) \
1556+
#define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \
1557+
INSTANTIATE_SPMDM_BASE_THREAD_LOCAL(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \
1558+
INSTANTIATE_SPMDM_BASE_THREAD_LOCAL( \
1559+
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, uint16_t) \
1560+
INSTANTIATE_SPMDM_BASE_THREAD_LOCAL( \
1561+
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, uint8_t) \
1562+
INSTANTIATE_SPMDM_NON_BASE_THREAD_LOCAL( \
1563+
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \
1564+
INSTANTIATE_SPMDM_NON_BASE_THREAD_LOCAL( \
1565+
IN_TYPE, INDEX_TYPE, OFFSET_TYPE, uint16_t) \
15561566
INSTANTIATE_SPMDM_ROWWISE_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE)
15571567

15581568
#define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \

0 commit comments

Comments
 (0)