Commit 33e3f0f

refine code
1 parent 4f8bff0 commit 33e3f0f

File tree

1 file changed: +14 -14 lines changed

torchao/csrc/cpu/aten_kernels/scaled_embedding_bag.cpp

Lines changed: 14 additions & 14 deletions
@@ -6,7 +6,7 @@
 #include <c10/util/Unroll.h>
 #include <torch/all.h>
 
-#define QTYPE_DISPATCH(TYPE, NAME, ...) \
+#define QTYPE_DISPATCH(TYPE, ...) \
   [&]() { \
     switch (TYPE) { \
       case c10::ScalarType::Float8_e4m3fn: { \
@@ -18,12 +18,11 @@
         return __VA_ARGS__(); \
       } \
       default: \
-        std::cerr << NAME << " unsupported qtype " << int(TYPE) << std::endl; \
-        exit(1); \
+        TORCH_CHECK(false, "scaled_embeding_bag: unsupport qtype"); \
     } \
   }()
 
-#define OUTTYPE_DISPATCH(TYPE, NAME, ...) \
+#define OUTTYPE_DISPATCH(TYPE, ...) \
   [&]() { \
     switch (TYPE) { \
       case c10::ScalarType::Float: { \
@@ -35,8 +34,7 @@
         return __VA_ARGS__(); \
       } \
       default: \
-        std::cerr << NAME << " unsupported out_type " << int(TYPE) << std::endl; \
-        exit(1); \
+        TORCH_CHECK(false, "scaled_embeding_bag: unsupport output type"); \
     } \
   }()
 
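Note: QTYPE_DISPATCH / OUTTYPE_DISPATCH follow the same shape as PyTorch's AT_DISPATCH_* helpers: the caller passes a lambda, and the macro expands to an immediately invoked lambda that switches on the runtime dtype, binds a type alias (data_t / output_t), and pastes the callback body into that case so the alias is in scope. Below is a minimal standalone sketch of that pattern, not the torchao source; it uses a toy ScalarType enum and substitutes std::runtime_error for TORCH_CHECK so it builds without the torch headers.

#include <stdexcept>

// Toy stand-in for c10::ScalarType so the sketch is self-contained.
enum class ScalarType { Float, Float8_e4m3fn };

// Same shape as OUTTYPE_DISPATCH above: switch on the runtime dtype, bind a
// compile-time alias, then invoke the callback. Because the callback is a
// macro argument, its body is pasted inside the case block, where `output_t`
// is in scope.
#define OUTTYPE_DISPATCH(TYPE, ...) \
  [&]() { \
    switch (TYPE) { \
      case ScalarType::Float: { \
        using output_t = float; \
        return __VA_ARGS__(); \
      } \
      default: \
        /* stand-in for TORCH_CHECK(false, ...) */ \
        throw std::runtime_error("unsupported output type"); \
    } \
  }()

int main() {
  ScalarType output_dtype = ScalarType::Float;
  OUTTYPE_DISPATCH(output_dtype, [&] {
    output_t acc = 0.0f;  // `output_t` was bound by the dispatch macro
    (void)acc;
  });
  return 0;
}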
@@ -88,7 +86,7 @@ static inline CHUNK load_chunk(const int8_t *x) {
   return {x0, x1, x2, x3, x4, x5, x6, x7};
 }
 
-static inline void save_chunk(float *output, CHUNK chunk) {
+static inline void store_chunk(float *output, CHUNK chunk) {
   __m512 x0, x1, x2, x3, x4, x5, x6, x7;
   std::tie(x0, x1, x2, x3, x4, x5, x6, x7) = chunk;
   _mm512_store_ps(output, x0);
@@ -101,7 +99,7 @@ static inline void save_chunk(float *output, CHUNK chunk) {
   _mm512_store_ps(output + 112, x7);
 }
 
-static inline void save_chunk(int8_t *output, CHUNK chunk) {
+static inline void store_chunk(int8_t *output, CHUNK chunk) {
   __m512i x00, x64;
   __m512i y0, y1, y2, y3, y4, y5, y6, y7;
   __m512 f0, f1, f2, f3, f4, f5, f6, f7;
@@ -135,9 +133,11 @@ static inline void save_chunk(int8_t *output, CHUNK chunk) {
 }
 #endif
 
-static inline void save_elem(float &out, float input) { out = input; }
+static inline void store_elem(float &out, float input) {
+  out = input;
+}
 
-static inline void save_elem(int8_t &out, float input) {
+static inline void store_elem(int8_t &out, float input) {
   float rounded = std::round(input);
   float clamped = std::max(-128.0f, std::min(127.0f, rounded));
   int32_t int32_value = static_cast<int32_t>(clamped);
@@ -189,7 +189,7 @@ inline void _scaled_embedding_bag_krnl(
       x6 = _mm512_mul_ps(x6, scale_v);
       x7 = _mm512_mul_ps(x7, scale_v);
       // store
-      save_chunk(block_result, {x0, x1, x2, x3, x4, x5, x6, x7});
+      store_chunk(block_result, {x0, x1, x2, x3, x4, x5, x6, x7});
     }
     result += num_emb * emb_dim;
   }
@@ -209,7 +209,7 @@ inline void _scaled_embedding_bag_krnl(
         value += float(weight[idx + d]);
       }
       value = value * scale;
-      save_elem(result[d], value);
+      store_elem(result[d], value);
     }
     result += num_emb * emb_dim;
   }
@@ -288,8 +288,8 @@ at::Tensor _scaled_embedding_bag_impl(
 
   at::Tensor output =
       at::empty({batch_size, emb_dim}, qweight.options().dtype(output_dtype));
-  OUTTYPE_DISPATCH(output_dtype, "_scaled_embedding_bag", [&] {
-    QTYPE_DISPATCH(qtype, "_scaled_embedding_bag", [&] {
+  OUTTYPE_DISPATCH(output_dtype, [&] {
+    QTYPE_DISPATCH(qtype, [&] {
       AT_DISPATCH_INDEX_TYPES(
           indices.scalar_type(), "_scaled_embedding_bag", [&] {
             _scaled_embedding_bag_dispatch_dtype<index_t, data_t, output_t>(
