66#include < c10/util/Unroll.h>
77#include < torch/all.h>
88
9- #define QTYPE_DISPATCH (TYPE, NAME, ...) \
9+ #define QTYPE_DISPATCH (TYPE, ...) \
1010 [&]() { \
1111 switch (TYPE) { \
1212 case c10::ScalarType::Float8_e4m3fn: { \
1818 return __VA_ARGS__ (); \
1919 } \
2020 default : \
21- std::cerr << NAME << " unsupported qtype " << int (TYPE) << std::endl; \
22- exit (1 ); \
21+ TORCH_CHECK (false , " scaled_embeding_bag: unsupport qtype" ); \
2322 } \
2423 }()
2524
26- #define OUTTYPE_DISPATCH (TYPE, NAME, ...) \
25+ #define OUTTYPE_DISPATCH (TYPE, ...) \
2726 [&]() { \
2827 switch (TYPE) { \
2928 case c10::ScalarType::Float: { \
3534 return __VA_ARGS__ (); \
3635 } \
3736 default : \
38- std::cerr << NAME << " unsupported out_type " << int (TYPE) << std::endl; \
39- exit (1 ); \
37+ TORCH_CHECK (false , " scaled_embeding_bag: unsupport output type" ); \
4038 } \
4139 }()
4240
@@ -88,7 +86,7 @@ static inline CHUNK load_chunk(const int8_t *x) {
8886 return {x0, x1, x2, x3, x4, x5, x6, x7};
8987}
9088
91- static inline void save_chunk (float *output, CHUNK chunk) {
89+ static inline void store_chunk (float *output, CHUNK chunk) {
9290 __m512 x0, x1, x2, x3, x4, x5, x6, x7;
9391 std::tie (x0, x1, x2, x3, x4, x5, x6, x7) = chunk;
9492 _mm512_store_ps (output, x0);
@@ -101,7 +99,7 @@ static inline void save_chunk(float *output, CHUNK chunk) {
10199 _mm512_store_ps (output + 112 , x7);
102100}
103101
104- static inline void save_chunk (int8_t *output, CHUNK chunk) {
102+ static inline void store_chunk (int8_t *output, CHUNK chunk) {
105103 __m512i x00, x64;
106104 __m512i y0, y1, y2, y3, y4, y5, y6, y7;
107105 __m512 f0, f1, f2, f3, f4, f5, f6, f7;
@@ -135,9 +133,11 @@ static inline void save_chunk(int8_t *output, CHUNK chunk) {
135133}
136134#endif
137135
138- static inline void save_elem (float &out, float input) { out = input; }
136+ static inline void store_elem (float &out, float input) {
137+ out = input;
138+ }
139139
140- static inline void save_elem (int8_t &out, float input) {
140+ static inline void store_elem (int8_t &out, float input) {
141141 float rounded = std::round (input);
142142 float clamped = std::max (-128 .0f , std::min (127 .0f , rounded));
143143 int32_t int32_value = static_cast <int32_t >(clamped);
@@ -189,7 +189,7 @@ inline void _scaled_embedding_bag_krnl(
189189 x6 = _mm512_mul_ps (x6, scale_v);
190190 x7 = _mm512_mul_ps (x7, scale_v);
191191 // store
192- save_chunk (block_result, {x0, x1, x2, x3, x4, x5, x6, x7});
192+ store_chunk (block_result, {x0, x1, x2, x3, x4, x5, x6, x7});
193193 }
194194 result += num_emb * emb_dim;
195195 }
@@ -209,7 +209,7 @@ inline void _scaled_embedding_bag_krnl(
209209 value += float (weight[idx + d]);
210210 }
211211 value = value * scale;
212- save_elem (result[d], value);
212+ store_elem (result[d], value);
213213 }
214214 result += num_emb * emb_dim;
215215 }
@@ -288,8 +288,8 @@ at::Tensor _scaled_embedding_bag_impl(
288288
289289 at::Tensor output =
290290 at::empty ({batch_size, emb_dim}, qweight.options ().dtype (output_dtype));
291- OUTTYPE_DISPATCH (output_dtype, " _scaled_embedding_bag " , [&] {
292- QTYPE_DISPATCH (qtype, " _scaled_embedding_bag " , [&] {
291+ OUTTYPE_DISPATCH (output_dtype, [&] {
292+ QTYPE_DISPATCH (qtype, [&] {
293293 AT_DISPATCH_INDEX_TYPES (
294294 indices.scalar_type (), " _scaled_embedding_bag" , [&] {
295295 _scaled_embedding_bag_dispatch_dtype<index_t , data_t , output_t >(
0 commit comments