@@ -27,79 +27,77 @@ def get_dispatch_inc_str(args: argparse.Namespace) -> str:
2727 for _ in args .head_dims
2828 ]
2929 )
30- dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(...) \\
30+ dispatch_head_dims_str = f"""#define _DISPATCH_CASES_head_dim(const_var, ...) \\
3131{ dispatch_head_dims_entries }
3232// EOL
3333"""
3434 # group sizes
3535 dispatch_group_sizes_entries = "\n " .join (
3636 [
37- " _DISPATCH_CASE({}, GROUP_SIZE , __VA_ARGS__) \\ " .format (_ )
37+ " _DISPATCH_CASE({}, case_var , __VA_ARGS__) \\ " .format (_ )
3838 for _ in args .group_sizes
3939 ]
4040 )
41- dispatch_group_sizes_str = f"""#define _DISPATCH_CASES_group_size(...) \\
41+ dispatch_group_sizes_str = f"""#define _DISPATCH_CASES_group_size(case_var, ...) \\
4242{ dispatch_group_sizes_entries }
4343// EOL
4444"""
4545 # page sizes
4646 dispatch_page_sizes_entries = "\n " .join (
4747 [
48- " _DISPATCH_CASE({}, PAGE_SIZE , __VA_ARGS__) \\ " .format (_ )
48+ " _DISPATCH_CASE({}, case_var , __VA_ARGS__) \\ " .format (_ )
4949 for _ in args .page_sizes
5050 ]
5151 )
52- dispatch_page_sizes_str = f"""#define _DISPATCH_CASES_page_size(...) \\
52+ dispatch_page_sizes_str = f"""#define _DISPATCH_CASES_page_size(case_var, ...) \\
5353{ dispatch_page_sizes_entries }
5454// EOL
5555"""
5656 # kv layouts
5757 dispatch_kv_layouts_entries = "\n " .join (
5858 [
59- " _DISPATCH_CASE({}, KV_LAYOUT , __VA_ARGS__) \\ " .format (
59+ " _DISPATCH_CASE({}, case_var , __VA_ARGS__) \\ " .format (
6060 kv_layout_literal [_ ]
6161 )
6262 for _ in args .kv_layouts
6363 ]
6464 )
65- dispatch_kv_layouts_str = f"""#define _DISPATCH_CASES_kv_layout(...) \\
65+ dispatch_kv_layouts_str = f"""#define _DISPATCH_CASES_kv_layout(case_var, ...) \\
6666{ dispatch_kv_layouts_entries }
6767// EOL
6868"""
6969 # positional encoding modes
7070 dispatch_pos_encoding_modes_entries = "\n " .join (
7171 [
72- " _DISPATCH_CASE({}, POS_ENCODING_MODE , __VA_ARGS__) \\ " .format (
72+ " _DISPATCH_CASE({}, case_var , __VA_ARGS__) \\ " .format (
7373 pos_encoding_mode_literal [_ ]
7474 )
7575 for _ in args .pos_encoding_modes
7676 ]
7777 )
78- dispatch_pos_encoding_modes_str = f"""#define _DISPATCH_CASES_pos_encoding_mode(...) \\
78+ dispatch_pos_encoding_modes_str = f"""#define _DISPATCH_CASES_pos_encoding_mode(case_var, ...) \\
7979{ dispatch_pos_encoding_modes_entries }
8080// EOL
8181"""
8282 # allow fp16 qk reductions
8383 dispatch_allow_fp16_qk_reduction_entries = "\n " .join (
8484 [
85- " _DISPATCH_CASE({}, ALLOW_FP16_QK_REDUCTION, __VA_ARGS__) \\ " .format (
86- bool_literal [_ ]
87- )
85+ " _DISPATCH_CASE({}, case_var, __VA_ARGS__) \\ " .format (bool_literal [_ ])
8886 for _ in args .allow_fp16_qk_reductions
8987 ]
9088 )
91- dispatch_allow_fp16_qk_reductions_str = f"""#define _DISPATCH_CASES_allow_fp16_qk_reduction(...) \\
89+ dispatch_allow_fp16_qk_reductions_str = f"""#define _DISPATCH_CASES_allow_fp16_qk_reduction(case_var, ...) \\
9290{ dispatch_allow_fp16_qk_reduction_entries }
9391// EOL
9492"""
9593 # causal
9694 dispatch_causal_entries = "\n " .join (
9795 [
98- " _DISPATCH_CASE({}, CAUSAL , __VA_ARGS__) \\ " .format (bool_literal [_ ])
96+ " _DISPATCH_CASE({}, case_var , __VA_ARGS__) \\ " .format (bool_literal [_ ])
9997 for _ in args .causals
10098 ]
10199 )
102- dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(...) \\
100+ dispatch_causal_str = f"""#define _DISPATCH_CASES_causal(case_var, ...) \\
103101{ dispatch_causal_entries }
104102// EOL
105103"""
0 commit comments