Skip to content

Commit 94bcf6f

Browse files
authored
fix: fix macro to suppress compilation warning (#231)
There are some mistakes in our macro definitions which results in lots of warnings and potential bugs. This PR fixes the issue.
1 parent 11ca502 commit 94bcf6f

File tree

2 files changed

+17
-19
lines changed

2 files changed

+17
-19
lines changed

python/csrc/pytorch_extension_utils.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@ using namespace flashinfer;
108108
} \
109109
}()
110110

111-
#define _DISPATCH_CASE(case_expr, var, ...) \
112-
case case_expr: { \
113-
constexpr auto var = case_expr; \
114-
return __VA_ARGS__(); \
111+
#define _DISPATCH_CASE(case_expr, case_var, ...) \
112+
case case_expr: { \
113+
constexpr auto case_var = case_expr; \
114+
return __VA_ARGS__(); \
115115
}
116116

117117
#define DISPATCH_group_size(expr, const_expr, ...) \

python/generate_dispatch_inc.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,79 +27,77 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
2727
for _ in args.head_dims
2828
]
2929
)
30-
dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(...) \\
30+
dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(const_var, ...) \\
3131
{dispatch_head_dims_entries}
3232
// EOL
3333
"""
3434
# group sizes
3535
dispatch_group_sizes_entries = "\n".join(
3636
[
37-
" _DISPATCH_CASE({}, GROUP_SIZE, __VA_ARGS__) \\".format(_)
37+
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
3838
for _ in args.group_sizes
3939
]
4040
)
41-
dispatch_group_sizes_str = f"""#define _DISPATCH_CASES_group_size(...) \\
41+
dispatch_group_sizes_str = f"""#define _DISPATCH_CASES_group_size(case_var, ...) \\
4242
{dispatch_group_sizes_entries}
4343
// EOL
4444
"""
4545
# page sizes
4646
dispatch_page_sizes_entries = "\n".join(
4747
[
48-
" _DISPATCH_CASE({}, PAGE_SIZE, __VA_ARGS__) \\".format(_)
48+
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
4949
for _ in args.page_sizes
5050
]
5151
)
52-
dispatch_page_sizes_str = f"""#define _DISPATCH_CASES_page_size(...) \\
52+
dispatch_page_sizes_str = f"""#define _DISPATCH_CASES_page_size(case_var, ...) \\
5353
{dispatch_page_sizes_entries}
5454
// EOL
5555
"""
5656
# kv layouts
5757
dispatch_kv_layouts_entries = "\n".join(
5858
[
59-
" _DISPATCH_CASE({}, KV_LAYOUT, __VA_ARGS__) \\".format(
59+
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(
6060
kv_layout_literal[_]
6161
)
6262
for _ in args.kv_layouts
6363
]
6464
)
65-
dispatch_kv_layouts_str = f"""#define _DISPATCH_CASES_kv_layout(...) \\
65+
dispatch_kv_layouts_str = f"""#define _DISPATCH_CASES_kv_layout(case_var, ...) \\
6666
{dispatch_kv_layouts_entries}
6767
// EOL
6868
"""
6969
# positional encoding modes
7070
dispatch_pos_encoding_modes_entries = "\n".join(
7171
[
72-
" _DISPATCH_CASE({}, POS_ENCODING_MODE, __VA_ARGS__) \\".format(
72+
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(
7373
pos_encoding_mode_literal[_]
7474
)
7575
for _ in args.pos_encoding_modes
7676
]
7777
)
78-
dispatch_pos_encoding_modes_str = f"""#define _DISPATCH_CASES_pos_encoding_mode(...) \\
78+
dispatch_pos_encoding_modes_str = f"""#define _DISPATCH_CASES_pos_encoding_mode(case_var, ...) \\
7979
{dispatch_pos_encoding_modes_entries}
8080
// EOL
8181
"""
8282
# allow fp16 qk reductions
8383
dispatch_allow_fp16_qk_reduction_entries = "\n".join(
8484
[
85-
" _DISPATCH_CASE({}, ALLOW_FP16_QK_REDUCTION, __VA_ARGS__) \\".format(
86-
bool_literal[_]
87-
)
85+
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(bool_literal[_])
8886
for _ in args.allow_fp16_qk_reductions
8987
]
9088
)
91-
dispatch_allow_fp16_qk_reductions_str = f"""#define _DISPATCH_CASES_allow_fp16_qk_reduction(...) \\
89+
dispatch_allow_fp16_qk_reductions_str = f"""#define _DISPATCH_CASES_allow_fp16_qk_reduction(case_var, ...) \\
9290
{dispatch_allow_fp16_qk_reduction_entries}
9391
// EOL
9492
"""
9593
# causal
9694
dispatch_causal_entries = "\n".join(
9795
[
98-
" _DISPATCH_CASE({}, CAUSAL, __VA_ARGS__) \\".format(bool_literal[_])
96+
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(bool_literal[_])
9997
for _ in args.causals
10098
]
10199
)
102-
dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(...) \\
100+
dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(case_var, ...) \\
103101
{dispatch_causal_entries}
104102
// EOL
105103
"""

0 commit comments

Comments
 (0)