@@ -229,9 +229,6 @@ for (const auto t : c10::irange(T)) {
229
229
// default to 1 byte alignment for CPU TBE
230
230
const int32_t D_bytes = nbit::padded_row_size_in_bytes (D, weight_ty, row_alignment);
231
231
232
- // NOTE: currently we only support bf16 output when input is int4 or int2
233
- TORCH_CHECK (o_dtype != SparseType::BF16 || (o_dtype == SparseType::BF16 && (weight_ty == SparseType::INT4 || weight_ty == SparseType::INT2)));
234
-
235
232
int tt;
236
233
for (tt = t + 1 ; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt);
237
234
size_t num_rows = ((tt == T ? weight_tensor.numel () : weights_offsets_acc[tt]) - weights_offsets_acc[t]) / D_bytes;
@@ -268,10 +265,13 @@ for (const auto t : c10::irange(T)) {
268
265
{% endif %}
269
266
/* input_stride=*/ D_bytes / sizeof (float ),
270
267
{% if not nobag %}
271
- /* scale_bias_last=*/ false );
268
+ /* scale_bias_last=*/ false ,
269
+ /* no_bag=*/ false ,
270
+ /* is_bf16_out=*/ output_is_bf16);
272
271
{% else %}
273
272
/* scale_bias_last=*/ false ,
274
- /* no_bag=*/ true );
273
+ /* no_bag=*/ true ,
274
+ /* is_bf16_out=*/ output_is_bf16);
275
275
{% endif %}
276
276
success = kernel (
277
277
{% if not nobag %}
@@ -301,10 +301,13 @@ for (const auto t : c10::irange(T)) {
301
301
{% endif %}
302
302
/* input_stride=*/ D_bytes / sizeof (float16),
303
303
{% if not nobag %}
304
- /* scale_bias_last=*/ false );
304
+ /* scale_bias_last=*/ false ,
305
+ /* no_bag=*/ false ,
306
+ /* is_bf16_out=*/ output_is_bf16);
305
307
{% else %}
306
308
/* scale_bias_last=*/ false ,
307
- /* no_bag=*/ true );
309
+ /* no_bag=*/ true ,
310
+ /* is_bf16_out=*/ output_is_bf16);
308
311
{% endif %}
309
312
success = kernel (
310
313
{% if not nobag %}
@@ -333,7 +336,8 @@ for (const auto t : c10::irange(T)) {
333
336
{% endif %}
334
337
/* input_stride=*/ D_bytes / sizeof (uint8_t ),
335
338
/* exponent_bits=*/ fp8_exponent_bits,
336
- /* exponent_bias=*/ fp8_exponent_bias);
339
+ /* exponent_bias=*/ fp8_exponent_bias,
340
+ /* is_bf16_out=*/ output_is_bf16);
337
341
success = kernel (
338
342
B,
339
343
index_size,
@@ -358,10 +362,13 @@ for (const auto t : c10::irange(T)) {
358
362
{% endif %}
359
363
/* input_stride=*/ D_bytes / sizeof (uint8_t ),
360
364
{% if not nobag %}
361
- /* scale_bias_last=*/ false );
365
+ /* scale_bias_last=*/ false ,
366
+ /* no_bag=*/ false ,
367
+ /* is_bf16_out=*/ output_is_bf16);
362
368
{% else %}
363
369
/* scale_bias_last=*/ false ,
364
- /* no_bag=*/ true );
370
+ /* no_bag=*/ true ,
371
+ /* is_bf16_out=*/ output_is_bf16);
365
372
{% endif %}
366
373
success = kernel (
367
374
{% if not nobag %}
0 commit comments