@@ -166,20 +166,22 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
  }

  Tensor output;
-  const int kINT8QparamsBytes = 8;
  SparseType o_dtype = static_cast<SparseType>(output_dtype);
  TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::INT8 || o_dtype == SparseType::BF16);
  bool output_is_bf16 = o_dtype == SparseType::BF16;
+  bool output_is_int8 = o_dtype == SparseType::INT8;
  {% if not nobag %}
+  const int kINT8QparamsBytes = 8;
  int64_t total_adjusted_D = total_D;
  if (o_dtype == SparseType::INT8) {
    total_adjusted_D += T * kINT8QparamsBytes;
  }
  output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));
  {% else %}
+  const int kINT8QparamsBytes = 4; // no bag int8 output aligns with fbgemm weights storage size and layout
  int64_t adjusted_D = D;
  if (o_dtype == SparseType::INT8) {
-    adjusted_D += T * kINT8QparamsBytes;
+    adjusted_D += kINT8QparamsBytes;
  }
  output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)).pinned_memory(pinned_memory));

@@ -202,11 +204,15 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{

  using float16 = uint16_t;
  using bfloat16 = uint16_t;
-  using fbgemm_out_t = typename std::conditional<
+  using int8 = uint8_t;
+  using base_fbgemm_out_t = typename std::conditional<
+      std::is_same<output_t, at::Half>::value,
+      float16,
+      std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, std::conditional<std::is_same<output_t, float>::value, float, int8>::type>::type>::type;
+  using other_fbgemm_out_t = typename std::conditional<
      std::is_same<output_t, at::Half>::value,
      float16,
      std::conditional<std::is_same<output_t, at::BFloat16>::value, bfloat16, float>::type>::type;
-
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
    const auto* indices_acc = indices.data_ptr<index_t>();
    const auto* offsets_acc = offsets.data_ptr<index_t>();
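Aside (not part of the diff): a minimal standalone sketch of the output-type mapping the new base_fbgemm_out_t alias encodes, assuming stand-in Half/BFloat16 tag types in place of the real at::Half/at::BFloat16. The alias names and the conditional nesting mirror the hunk above; everything else here is illustrative only.

// Standalone sketch of the base_fbgemm_out_t mapping from the hunk above.
// sketch::Half and sketch::BFloat16 are stand-ins for at::Half / at::BFloat16.
#include <cstdint>
#include <type_traits>

namespace sketch {
struct Half {};      // stand-in for at::Half
struct BFloat16 {};  // stand-in for at::BFloat16

using float16 = uint16_t;
using bfloat16 = uint16_t;
using int8 = uint8_t;

template <typename output_t>
using base_fbgemm_out_t = typename std::conditional<
    std::is_same<output_t, Half>::value,
    float16,
    typename std::conditional<
        std::is_same<output_t, BFloat16>::value,
        bfloat16,
        typename std::conditional<
            std::is_same<output_t, float>::value,
            float,
            int8>::type>::type>::type;

// FP16/BF16 outputs map to raw uint16_t storage, FP32 stays float,
// and the remaining case (int8 output) falls through to uint8_t.
static_assert(std::is_same<base_fbgemm_out_t<Half>, uint16_t>::value, "fp16 -> uint16_t");
static_assert(std::is_same<base_fbgemm_out_t<float>, float>::value, "fp32 -> float");
static_assert(std::is_same<base_fbgemm_out_t<uint8_t>, uint8_t>::value, "int8 -> uint8_t");
}  // namespace sketch

In other words, base_fbgemm_out_t extends the existing mapping (other_fbgemm_out_t) with an int8 fallback, which is what lets the int8-output path reuse the same kernel plumbing.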
@@ -224,7 +230,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
    const int32_t D_end = D_offsets_acc[t + 1];
    const int32_t D = D_end - D_start;
    {% else %}
-    const int32_t D_start = offsets_acc[t * B] * D;
+    const int32_t D_start = offsets_acc[t * B] * adjusted_D;
    {% endif %}

    const auto placement = static_cast<PlacementType>(weights_placements_ptr[t]);
@@ -233,6 +239,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
      weights_acc = weight_tensor.data_ptr<uint8_t>();
    const uint8_t* weights = &weights_acc[weights_offsets_acc[t]];
    const auto weight_ty = static_cast<SparseType>(weights_tys_acc[t]);
+    if (output_is_int8) {
+      TORCH_CHECK(weight_ty == SparseType::INT8, "int8 output is only supported for int8 weights");
+    }
    // default to 1 byte alignment for CPU TBE
    const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);

@@ -246,6 +255,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
    const bool normalize_by_lengths = static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;

    const index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;
+    const int32_t output_stride = {{ "total_D" if not nobag else "adjusted_D" }};

    {% if nobag %}
    // Create virtual offsets for the nobag case. Lengths are all ones.
@@ -256,6 +266,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
    {% endif %}

    const float* indice_weights_ptr = nullptr;
+    // int8 output only enabled for nobag case with ref impl
+    const bool nobag_op = {{ "false" if not nobag else "output_is_int8" }};
    {% if weighted %}
    indice_weights_ptr = indice_weights_acc + *offsets_begin_ptr;
    {% endif %}
@@ -266,6 +278,13 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
        if use_base else ("GenerateEmbeddingSpMDMNBitWithStrides"
        if use_nbit else "GenerateEmbeddingSpMDMFP8WithStrides")
    %}
+    using fbgemm_out_t = {{ "base_fbgemm_out_t" if use_base else "other_fbgemm_out_t" }};
+    // TODO: merge nobag int8 path with normal asmjit dispatch
+    {% if nobag %}
+    const index_t* offset_ptr = (output_is_int8) ? offsets_begin_ptr : offsets_nobag_ptr;
+    {% else %}
+    const index_t* offset_ptr = offsets_begin_ptr;
+    {% endif %}
    const auto kernel = fbgemm::{{ kernel_name }}<
    {% if use_base %}
        {{ weight_type }},
@@ -292,7 +311,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
    {% endif %}
        /*is_weight_positional=*/false,
        /*use_offsets=*/true,
-        /*output_stride=*/{{ "total_D" if not nobag else "D" }},
+        /*output_stride=*/output_stride,
        /*input_stride=*/D_bytes / sizeof({{ weight_type }}),
    {% if use_fp8 %}
        /*exponent_bits=*/fp8_exponent_bits,
@@ -302,7 +321,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
        /*scale_bias_last=*/false,
    {% endif %}
    {% if use_base %}
-        /*no_bag=*/false,
+        /*no_bag=*/nobag_op,
    {% endif %}
        /*is_bf16_out=*/output_is_bf16
    );
@@ -312,7 +331,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
        num_rows,
        reinterpret_cast<const {{ weight_type }}*>(weights),
        indices_acc + *offsets_begin_ptr,
-        {{ "offsets_begin_ptr" if not nobag else "offsets_nobag_ptr" }},
+        offset_ptr,
        indice_weights_ptr,
        reinterpret_cast<fbgemm_out_t*>(output_acc + D_start));
{% endmacro %}
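For context on the sizing changes above: a minimal standalone sketch (not part of the template) of how the int8 output shapes come out in the two branches. The batch size, table dims, and offsets below are made-up illustration values; only the +8 / +4 bookkeeping restates kINT8QparamsBytes from the pooled and nobag paths of the diff.

// Standalone sketch: int8 output sizing in the pooled vs. nobag branches.
// All shapes and offsets here are illustrative; the +8 / +4 constants
// mirror kINT8QparamsBytes in the diff above.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Pooled path: output is [B, total_D + T * 8] for int8 output,
  // i.e. 8 extra quant-param bytes reserved per table.
  const int64_t B = 2;                        // batch size (illustrative)
  const std::vector<int64_t> Ds = {64, 128};  // per-table dims (illustrative)
  int64_t total_D = 0;
  for (const auto dim : Ds) total_D += dim;
  const int64_t T = static_cast<int64_t>(Ds.size());
  const int64_t total_adjusted_D = total_D + T * 8;  // kINT8QparamsBytes = 8
  std::cout << "pooled int8 output: [" << B << ", " << total_adjusted_D << "]\n";

  // Nobag path: output is [total_L, D + 4], one row per index, with 4 extra
  // bytes per row; a table's rows start offsets[t * B] rows into the output,
  // which is why D_start becomes offsets_acc[t * B] * adjusted_D in the diff.
  const int64_t D = 64;                        // single nobag dim (illustrative)
  const int64_t adjusted_D = D + 4;            // kINT8QparamsBytes = 4
  const std::vector<int64_t> offsets = {0, 3, 5};  // per-sample offsets (illustrative)
  const int64_t total_L = offsets.back();
  std::cout << "nobag int8 output: [" << total_L << ", " << adjusted_D << "]\n";
  std::cout << "flat element offset of the row block at offsets[1]: "
            << offsets[1] * adjusted_D << "\n";  // elements are 1 byte for int8
  return 0;
}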