Skip to content

Commit ebde42d

Browse files
excelle08 authored and facebook-github-bot committed
Enable bf16 output in TBE CPU kernel for other input types (pytorch#1851)
Summary: Pull Request resolved: pytorch#1851 Enable bf16 output support in TBE CPU kernel when the input weight type is int8/fp8/fp16/fp32 Differential Revision: D47028021 fbshipit-source-id: 73802deb5e17d4da4b84c7d75c6a5c3c415b46d6
1 parent cc40430 commit ebde42d

File tree

7 files changed

+160
-99
lines changed

7 files changed

+160
-99
lines changed

fbgemm_gpu/codegen/embedding_forward_quantized_cpu_template.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,6 @@ for (const auto t : c10::irange(T)) {
229229
// default to 1 byte alignment for CPU TBE
230230
const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);
231231

232-
// NOTE: currently we only support bf16 output when input is int4 or int2
233-
TORCH_CHECK(o_dtype != SparseType::BF16 || (o_dtype == SparseType::BF16 && (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2)));
234-
235232
int tt;
236233
for (tt = t + 1; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt);
237234
size_t num_rows = ((tt == T ? weight_tensor.numel() : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
@@ -268,10 +265,13 @@ for (const auto t : c10::irange(T)) {
268265
{% endif %}
269266
/*input_stride=*/D_bytes / sizeof(float),
270267
{% if not nobag %}
271-
/*scale_bias_last=*/false);
268+
/*scale_bias_last=*/false,
269+
/*no_bag=*/false,
270+
/*is_bf16_out=*/output_is_bf16);
272271
{% else %}
273272
/*scale_bias_last=*/false,
274-
/*no_bag=*/true);
273+
/*no_bag=*/true,
274+
/*is_bf16_out=*/output_is_bf16);
275275
{% endif %}
276276
success = kernel(
277277
{% if not nobag %}
@@ -301,10 +301,13 @@ for (const auto t : c10::irange(T)) {
301301
{% endif %}
302302
/*input_stride=*/D_bytes / sizeof(float16),
303303
{% if not nobag %}
304-
/*scale_bias_last=*/false);
304+
/*scale_bias_last=*/false,
305+
/*no_bag=*/false,
306+
/*is_bf16_out=*/output_is_bf16);
305307
{% else %}
306308
/*scale_bias_last=*/false,
307-
/*no_bag=*/true);
309+
/*no_bag=*/true,
310+
/*is_bf16_out=*/output_is_bf16);
308311
{% endif %}
309312
success = kernel(
310313
{% if not nobag %}
@@ -333,7 +336,8 @@ for (const auto t : c10::irange(T)) {
333336
{% endif %}
334337
/*input_stride=*/D_bytes / sizeof(uint8_t),
335338
/*exponent_bits=*/fp8_exponent_bits,
336-
/*exponent_bias=*/fp8_exponent_bias);
339+
/*exponent_bias=*/fp8_exponent_bias,
340+
/*is_bf16_out=*/output_is_bf16);
337341
success = kernel(
338342
B,
339343
index_size,
@@ -358,10 +362,13 @@ for (const auto t : c10::irange(T)) {
358362
{% endif %}
359363
/*input_stride=*/D_bytes / sizeof(uint8_t),
360364
{% if not nobag %}
361-
/*scale_bias_last=*/false);
365+
/*scale_bias_last=*/false,
366+
/*no_bag=*/false,
367+
/*is_bf16_out=*/output_is_bf16);
362368
{% else %}
363369
/*scale_bias_last=*/false,
364-
/*no_bag=*/true);
370+
/*no_bag=*/true,
371+
/*is_bf16_out=*/output_is_bf16);
365372
{% endif %}
366373
success = kernel(
367374
{% if not nobag %}

fbgemm_gpu/test/split_table_batched_embeddings_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4357,7 +4357,7 @@ def test_nbit_forward_cpu(
43574357
)
43584358

43594359
@given(
4360-
nbit_weights_ty=st.sampled_from([SparseType.INT4, SparseType.INT2]),
4360+
nbit_weights_ty=get_nbit_weights_ty(),
43614361
use_array_for_index_remapping=st.booleans(),
43624362
do_pruning=st.booleans(),
43634363
)

include/fbgemm/FbgemmEmbedding.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ GenerateEmbeddingSpMDM(
8080
int prefetch = 16,
8181
bool is_weight_positional = false,
8282
bool use_offsets = true,
83-
bool isbf16 = false);
83+
bool is_bf16_out = false,
84+
bool is_bf16_in = false);
8485

8586
/**
8687
* @param output_stride If -1, output_stride is same as block_size
@@ -112,7 +113,8 @@ GenerateEmbeddingSpMDMWithStrides(
112113
std::int64_t input_stride = -1,
113114
bool scale_bias_last = true,
114115
bool no_bag = false,
115-
bool isbf16 = false);
116+
bool is_bf16_out = false,
117+
bool is_bf16_in = false);
116118

117119
/**
118120
* @tparam IndexType can be int32_t or int64_t
@@ -195,7 +197,8 @@ GenerateEmbeddingSpMDMFP8WithStrides(
195197
std::int64_t output_stride = -1,
196198
std::int64_t input_stride = -1,
197199
int exponent_bits = 4,
198-
int exponent_bias = 7);
200+
int exponent_bias = 7,
201+
bool is_bf16_out = false);
199202

200203
template <
201204
typename InType,

0 commit comments

Comments (0)