Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/csrc/pytorch_extension_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ using namespace flashinfer;
} \
}()

#define _DISPATCH_CASE(case_expr, var, ...) \
case case_expr: { \
constexpr auto var = case_expr; \
return __VA_ARGS__(); \
#define _DISPATCH_CASE(case_expr, case_var, ...) \
case case_expr: { \
constexpr auto case_var = case_expr; \
return __VA_ARGS__(); \
}

#define DISPATCH_group_size(expr, const_expr, ...) \
Expand Down
28 changes: 13 additions & 15 deletions python/generate_dispatch_inc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,79 +27,77 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
for _ in args.head_dims
]
)
dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(...) \\
dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(const_var, ...) \\
{dispatch_head_dims_entries}
// EOL
"""
# group sizes
dispatch_group_sizes_entries = "\n".join(
[
" _DISPATCH_CASE({}, GROUP_SIZE, __VA_ARGS__) \\".format(_)
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
for _ in args.group_sizes
]
)
dispatch_group_sizes_str = f"""#define _DISPATCH_CASES_group_size(...) \\
dispatch_group_sizes_str = f"""#define _DISPATCH_CASES_group_size(case_var, ...) \\
{dispatch_group_sizes_entries}
// EOL
"""
# page sizes
dispatch_page_sizes_entries = "\n".join(
[
" _DISPATCH_CASE({}, PAGE_SIZE, __VA_ARGS__) \\".format(_)
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(_)
for _ in args.page_sizes
]
)
dispatch_page_sizes_str = f"""#define _DISPATCH_CASES_page_size(...) \\
dispatch_page_sizes_str = f"""#define _DISPATCH_CASES_page_size(case_var, ...) \\
{dispatch_page_sizes_entries}
// EOL
"""
# kv layouts
dispatch_kv_layouts_entries = "\n".join(
[
" _DISPATCH_CASE({}, KV_LAYOUT, __VA_ARGS__) \\".format(
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(
kv_layout_literal[_]
)
for _ in args.kv_layouts
]
)
dispatch_kv_layouts_str = f"""#define _DISPATCH_CASES_kv_layout(...) \\
dispatch_kv_layouts_str = f"""#define _DISPATCH_CASES_kv_layout(case_var, ...) \\
{dispatch_kv_layouts_entries}
// EOL
"""
# positional encoding modes
dispatch_pos_encoding_modes_entries = "\n".join(
[
" _DISPATCH_CASE({}, POS_ENCODING_MODE, __VA_ARGS__) \\".format(
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(
pos_encoding_mode_literal[_]
)
for _ in args.pos_encoding_modes
]
)
dispatch_pos_encoding_modes_str = f"""#define _DISPATCH_CASES_pos_encoding_mode(...) \\
dispatch_pos_encoding_modes_str = f"""#define _DISPATCH_CASES_pos_encoding_mode(case_var, ...) \\
{dispatch_pos_encoding_modes_entries}
// EOL
"""
# allow fp16 qk reductions
dispatch_allow_fp16_qk_reduction_entries = "\n".join(
[
" _DISPATCH_CASE({}, ALLOW_FP16_QK_REDUCTION, __VA_ARGS__) \\".format(
bool_literal[_]
)
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(bool_literal[_])
for _ in args.allow_fp16_qk_reductions
]
)
dispatch_allow_fp16_qk_reductions_str = f"""#define _DISPATCH_CASES_allow_fp16_qk_reduction(...) \\
dispatch_allow_fp16_qk_reductions_str = f"""#define _DISPATCH_CASES_allow_fp16_qk_reduction(case_var, ...) \\
{dispatch_allow_fp16_qk_reduction_entries}
// EOL
"""
# causal
dispatch_causal_entries = "\n".join(
[
" _DISPATCH_CASE({}, CAUSAL, __VA_ARGS__) \\".format(bool_literal[_])
" _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\".format(bool_literal[_])
for _ in args.causals
]
)
dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(...) \\
dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(case_var, ...) \\
{dispatch_causal_entries}
// EOL
"""
Expand Down