@@ -58,6 +58,20 @@ def register_macro(self, name, func, priority):
5858]
5959
6060
def extract_axis_and_clean_tokens(tokens):
    """Extract an ``axis <sep> <int>`` clause from a macro token list.

    Scans ``tokens`` for the first token whose ``value`` is ``"axis"`` and
    that has at least two more tokens after it. The token two positions
    later is parsed with ``int()`` as the axis, and the three-token clause
    is dropped from the returned list. The token directly after ``"axis"``
    is not inspected (presumably it is ``"="`` -- confirm against the
    tokenizer).

    If more than one token follows the clause, the first of them must be a
    ``","`` separator and is dropped as well. When exactly one token follows
    the clause it is left untouched and unchecked.
    NOTE(review): this looks like it accommodates a single trailing
    terminator token -- confirm against the tokenizer.

    Args:
        tokens: List of token objects exposing a ``value`` attribute.

    Returns:
        A ``(axis, tokens)`` tuple. ``axis`` is ``1`` when no usable clause
        is found, in which case ``tokens`` is returned unchanged; otherwise
        ``tokens`` is a new list with the clause removed.

    Raises:
        ValueError: If the clause is followed by more than one token and
            the first of them is not ``","``. ``int()`` may also raise if
            the axis value is not an integer literal.
    """
    axis = 1  # default when the expression carries no axis clause
    for idx, tkn in enumerate(tokens):
        # Need "axis", a separator and a value: skip a dangling "axis".
        if tkn.value != "axis" or idx + 2 >= len(tokens):
            continue

        axis = int(tokens[idx + 2].value)
        end_idx = idx + 3
        if end_idx < len(tokens) - 1:
            # An explicit raise instead of `assert`, so the check still
            # runs when Python is started with -O.
            separator = tokens[end_idx].value
            if separator != ",":
                raise ValueError(
                    "Expected ',' after the axis clause, "
                    f"got {separator!r}."
                )
            end_idx += 1

        # Only the first clause is handled.
        return axis, tokens[:idx] + tokens[end_idx:]

    return axis, tokens
73+
74+
6175# star_macro must be called after layer_id_macro
6276@macro (name = 'star_macro' , priority = 3 )
6377def star_macro (tokens , expression , context ):
@@ -119,12 +133,14 @@ def layer_id_macro(tokens, expression, context):
119133 )
120134 assert name_with_layer_id , "No $LAYER_ID found in NAME tokens"
121135
122- num_layers = context .get_num_hidden_layers (
136+ match_layer_id = context .get_num_hidden_layers (
123137 name_with_layer_id , LAYER_ID_MACRO_TAG
124138 )
125139 expanded_expressions = []
126140
127- for layer_id in range (num_layers ):
141+ match_layer_id = sorted (match_layer_id )
142+
143+ for layer_id in match_layer_id :
128144 expr = ""
129145 for token in tokens :
130146 if token .type == TokenType .IDENTIFIER :
@@ -181,6 +197,8 @@ def fused_qkv_old_macro(tokens, expression, context):
181197 if not any (tkn .value == FUSED_QKV_OLD_TAG for tkn in tokens ):
182198 return expression
183199
200+ axis , tokens = extract_axis_and_clean_tokens (tokens )
201+
184202 attn_head_num = None
185203 num_key_value_groups = None
186204 fused_qkv_old_pos = None
@@ -263,10 +281,14 @@ def gen_expr(tp_degree, num_heads, tp_rank, comp):
263281 for c , n in head_config
264282 ]
265283 if idx == 0 :
266- mapping = f"{ qkv_weight_name } -> { ',' .join (qkv_parts )} , axis=1"
284+ mapping = (
285+ f"{ qkv_weight_name } -> { ',' .join (qkv_parts )} , axis={ axis } "
286+ )
267287 results .append (mapping )
268288 elif qkv_weight_name is not None :
269- mapping = f"{ ',' .join (qkv_parts )} -> { qkv_weight_name } , axis=1"
289+ mapping = (
290+ f"{ ',' .join (qkv_parts )} -> { qkv_weight_name } , axis={ axis } "
291+ )
270292 results .append (mapping )
271293
272294 if fused_qkv_old_pos > 4 :
@@ -275,7 +297,7 @@ def _generate_expr(prefix, count, target_name):
275297 elements = "," .join (
276298 f"fused_qkv_old_tmp.{ prefix } _{ i } " for i in range (count )
277299 )
278- return f"{ elements } -> { target_name } , axis=1 "
300+ return f"{ elements } -> { target_name } , axis={ axis } "
279301
280302 q_name = tokens [2 ].value
281303 k_name = tokens [4 ].value
@@ -292,7 +314,7 @@ def _generate_expr(prefix, count, target_name):
292314
293315 fused_qkv_tmp_name = f"{ q_name } .{ k_name } .{ v_name } .tmp"
294316 results .append (
295- f"{ q_name } ,{ k_name } ,{ v_name } -> { fused_qkv_tmp_name } , axis=1 "
317+ f"{ q_name } ,{ k_name } ,{ v_name } -> { fused_qkv_tmp_name } , axis={ axis } "
296318 )
297319 dst_state_shard_num = context .get_dst_state_shard_num (
298320 dst_qkv_weight_name
@@ -324,9 +346,13 @@ def gen_expr(tp_degree, num_heads, tp_rank, comp):
324346 for c , n in head_config
325347 ]
326348 if idx == 0 :
327- mapping = f"{ qkv_weight_name } -> { ',' .join (qkv_parts )} , axis=1"
349+ mapping = (
350+ f"{ qkv_weight_name } -> { ',' .join (qkv_parts )} , axis={ axis } "
351+ )
328352 else :
329- mapping = f"{ ',' .join (qkv_parts )} -> { qkv_weight_name } , axis=1"
353+ mapping = (
354+ f"{ ',' .join (qkv_parts )} -> { qkv_weight_name } , axis={ axis } "
355+ )
330356 results .append (mapping )
331357 else :
332358 raise ValueError (
@@ -340,6 +366,9 @@ def fused_ffn_macro(tokens, expression, context):
340366 FUSED_FFN_TAG = "fused_ffn"
341367 if not any (tkn .value == FUSED_FFN_TAG for tkn in tokens ):
342368 return expression
369+
370+ axis , tokens = extract_axis_and_clean_tokens (tokens )
371+
343372 rarrow_pos = None
344373 fused_ffn_pos = None
345374 for idx , token in enumerate (tokens ):
@@ -388,19 +417,19 @@ def gen_expr(tp_degree, splited_num, tp_rank, comp):
388417 ]
389418 if idx == 0 :
390419 results .append (
391- f"{ ffn_weight_name } -> { ',' .join (ffn_parts )} , axis=1 "
420+ f"{ ffn_weight_name } -> { ',' .join (ffn_parts )} , axis={ axis } "
392421 )
393422 elif ffn_weight_name is not None :
394423 results .append (
395- f"{ ',' .join (ffn_parts )} -> { ffn_weight_name } , axis=1 "
424+ f"{ ',' .join (ffn_parts )} -> { ffn_weight_name } , axis={ axis } "
396425 )
397426 if fused_ffn_pos > 4 :
398427
399428 def _generate_expr (prefix , count , target_name ):
400429 elements = "," .join (
401430 f"fused_ffn_tmp.{ prefix } _{ i } " for i in range (count )
402431 )
403- return f"{ elements } -> { target_name } , axis=1 "
432+ return f"{ elements } -> { target_name } , axis={ axis } "
404433
405434 gate_name = tokens [2 ].value
406435 up_name = tokens [4 ].value
@@ -415,7 +444,7 @@ def _generate_expr(prefix, count, target_name):
415444
416445 fused_gate_up_tmp_name = f"{ gate_name } .{ up_name } .tmp"
417446 results .append (
418- f"{ gate_name } ,{ up_name } -> { fused_gate_up_tmp_name } , axis=1 "
447+ f"{ gate_name } ,{ up_name } -> { fused_gate_up_tmp_name } , axis={ axis } "
419448 )
420449 dst_state_shard_num = context .get_dst_state_shard_num (
421450 dst_ffn_weight_name
@@ -445,11 +474,11 @@ def gen_expr(tp_degree, splited_num, tp_rank, comp):
445474 ]
446475 if idx == 0 :
447476 results .append (
448- f"{ ffn_weight_name } -> { ',' .join (ffn_parts )} , axis=1 "
477+ f"{ ffn_weight_name } -> { ',' .join (ffn_parts )} , axis={ axis } "
449478 )
450479 else :
451480 results .append (
452- f"{ ',' .join (ffn_parts )} -> { ffn_weight_name } , axis=1 "
481+ f"{ ',' .join (ffn_parts )} -> { ffn_weight_name } , axis={ axis } "
453482 )
454483 else :
455484 raise ValueError (f"Unsupported fused_ffn macro format: { expression } ." )
@@ -508,6 +537,8 @@ def fused_qkv(tokens, expression, context):
508537 if not any (tkn .value == FUSED_QKV_TAG for tkn in tokens ):
509538 return expression
510539
540+ axis , tokens = extract_axis_and_clean_tokens (tokens )
541+
511542 attn_head_num = num_heads = None
512543 num_key_value_groups = None
513544 fused_qkv_pos = None
@@ -566,12 +597,12 @@ def make_names(base, n):
566597 fused_qkv_order .append (k_names [g ])
567598 fused_qkv_order .append (v_names [g ])
568599 results .append (
569- f"{ fused_qkv_var } -> { ',' .join (fused_qkv_order )} , axis=1 "
600+ f"{ fused_qkv_var } -> { ',' .join (fused_qkv_order )} , axis={ axis } "
570601 )
571602
572- results .append (f"{ ',' .join (q_names )} -> { q_var } , axis=1 " )
573- results .append (f"{ ',' .join (k_names )} -> { k_var } , axis=1 " )
574- results .append (f"{ ',' .join (v_names )} -> { v_var } , axis=1 " )
603+ results .append (f"{ ',' .join (q_names )} -> { q_var } , axis={ axis } " )
604+ results .append (f"{ ',' .join (k_names )} -> { k_var } , axis={ axis } " )
605+ results .append (f"{ ',' .join (v_names )} -> { v_var } , axis={ axis } " )
575606
576607 return results
577608
@@ -585,9 +616,9 @@ def make_names(base, n):
585616 k_names = make_names (k_var , num_key_value_groups )
586617 v_names = make_names (v_var , num_key_value_groups )
587618
588- results .append (f"{ q_var } -> { ',' .join (q_names )} , axis=1 " )
589- results .append (f"{ k_var } -> { ',' .join (k_names )} , axis=1 " )
590- results .append (f"{ v_var } -> { ',' .join (v_names )} , axis=1 " )
619+ results .append (f"{ q_var } -> { ',' .join (q_names )} , axis={ axis } " )
620+ results .append (f"{ k_var } -> { ',' .join (k_names )} , axis={ axis } " )
621+ results .append (f"{ v_var } -> { ',' .join (v_names )} , axis={ axis } " )
591622
592623 fused_qkv_order = []
593624 for g in range (num_key_value_groups ):
@@ -597,7 +628,7 @@ def make_names(base, n):
597628 fused_qkv_order .append (k_names [g ])
598629 fused_qkv_order .append (v_names [g ])
599630 results .append (
600- f"{ ',' .join (fused_qkv_order )} -> { fused_qkv_var } , axis=1 "
631+ f"{ ',' .join (fused_qkv_order )} -> { fused_qkv_var } , axis={ axis } "
601632 )
602633 return results
603634
0 commit comments