
Commit 2c1a28a

【FlexCheckpoint】Upgrade some macros and optimize load_state_dict communication (#75282)
* upgrade macros and optimize load_state_dict communication
* fix support for 0-d tensors
* fix balance save
* fix tests
1 parent ebd9578 commit 2c1a28a

File tree

8 files changed, +444 -211 lines changed

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -1343,9 +1343,13 @@ def _create_sharded_weight(
         master_weights = optim_state_dict.pop("master_weights", None)
         optim_state_dict.pop("LR_Scheduler", None)
 
-        static_to_struct = {
-            v.local_tensor.name: k for k, v in model_sharded_state_dict.items()
-        }
+        static_to_struct = {}
+        model_sharded_state_dict = dict(
+            sorted(model_sharded_state_dict.items())
+        )
+        for k, v in model_sharded_state_dict.items():
+            if v.local_tensor.name not in static_to_struct:
+                static_to_struct[v.local_tensor.name] = k
 
         sharded_state = {}
```
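The replaced dict comprehension kept whichever struct name happened to be written last when several entries shared one underlying static tensor name; the new loop sorts the keys and keeps the first match, so the static-to-struct mapping is deterministic. A minimal sketch of the new logic with toy string values (in the real code the values are sharded-weight objects and the collision is on `v.local_tensor.name`, e.g. for tied parameters):

```python
# Toy stand-in for model_sharded_state_dict: struct name -> static tensor name.
# Two struct names ("embed_tokens.weight", "lm_head.weight") share one static
# tensor, as happens with tied/shared parameters.
model_sharded_state_dict = {
    "lm_head.weight": "linear_0.w_0",
    "embed_tokens.weight": "linear_0.w_0",
    "layers.0.w": "linear_1.w_0",
}

static_to_struct = {}
for k, v in sorted(model_sharded_state_dict.items()):
    if v not in static_to_struct:  # first match wins, in sorted key order
        static_to_struct[v] = k

print(static_to_struct)
# {'linear_0.w_0': 'embed_tokens.weight', 'linear_1.w_0': 'layers.0.w'}
```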

python/paddle/distributed/flex_checkpoint/aoa/aoa_engine.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -89,13 +89,13 @@ def get_num_hidden_layers(
         )
         prefix, suffix = name_with_layer_id.split(layer_id_macro_tag, 1)
         pattern = re.compile(rf"{re.escape(prefix)}(\d+){re.escape(suffix)}")
-        max_layer = 0
+        match_layer_id = set()
         for key in self.get_all_dst_state_keys():
             match = pattern.fullmatch(key)
             if match:
                 layer_num = int(match.group(1))
-                max_layer = max(max_layer, layer_num)
-        return max_layer + 1
+                match_layer_id.add(layer_num)
+        return match_layer_id
 
     def get_src_state_shard_num(self, src_state_key: str) -> int:
         if src_state_key not in self.source_state_shard_info:
```
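Note the semantic change: despite its name, `get_num_hidden_layers` no longer returns a count (`max_layer + 1`) but the set of layer ids that actually match a destination state key, which represents non-contiguous ids exactly. A self-contained sketch of the new matching behavior (key names hypothetical):

```python
import re

# Hypothetical destination state keys with a gap in the layer numbering.
dst_keys = ["layers.0.w", "layers.2.w", "layers.7.w"]
pattern = re.compile(r"layers\.(\d+)\.w")

# New behavior: collect only the layer ids that really occur ...
match_layer_id = {int(m.group(1)) for k in dst_keys if (m := pattern.fullmatch(k))}
print(sorted(match_layer_id))  # [0, 2, 7]
# ... where the old code would have reported max_layer + 1 == 8 layers.
```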

python/paddle/distributed/flex_checkpoint/aoa/macros.py

Lines changed: 53 additions & 22 deletions
```diff
@@ -58,6 +58,20 @@ def register_macro(self, name, func, priority):
 ]
 
 
+def extract_axis_and_clean_tokens(tokens):
+    axis = 1
+    for idx, tkn in enumerate(tokens):
+        if tkn.value == "axis" and idx + 2 < len(tokens):
+            axis = int(tokens[idx + 2].value)
+            end_idx = idx + 3
+            if end_idx < len(tokens) - 1:
+                assert tokens[end_idx].value == ","
+                end_idx += 1
+            tokens = tokens[:idx] + tokens[end_idx:]
+            break
+    return axis, tokens
+
+
 # star_macro must be called after layer_id_macro
 @macro(name='star_macro', priority=3)
 def star_macro(tokens, expression, context):
```
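The new helper scans a macro's token stream for an `axis = N` clause, strips it (plus the comma that follows it, when more tokens remain), and returns the parsed axis together with the cleaned tokens; the default stays `axis = 1`, so existing expressions behave as before. A runnable sketch, using a namedtuple as a stand-in for the real AOA token type (only the `.value` attribute is assumed):

```python
from collections import namedtuple

Tok = namedtuple("Tok", ["value"])  # stand-in; real tokens also carry a type

def extract_axis_and_clean_tokens(tokens):
    # Body copied from the diff above.
    axis = 1
    for idx, tkn in enumerate(tokens):
        if tkn.value == "axis" and idx + 2 < len(tokens):
            axis = int(tokens[idx + 2].value)
            end_idx = idx + 3
            if end_idx < len(tokens) - 1:
                assert tokens[end_idx].value == ","
                end_idx += 1
            tokens = tokens[:idx] + tokens[end_idx:]
            break
    return axis, tokens

# "q,k,v -> qkv, axis=0" as a token stream:
tokens = [Tok(v) for v in ["q", ",", "k", ",", "v", "->", "qkv", ",", "axis", "=", "0"]]
axis, cleaned = extract_axis_and_clean_tokens(tokens)
print(axis)                        # 0
print([t.value for t in cleaned])  # ['q', ',', 'k', ',', 'v', '->', 'qkv', ',']
```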
```diff
@@ -119,12 +133,14 @@ def layer_id_macro(tokens, expression, context):
     )
     assert name_with_layer_id, "No $LAYER_ID found in NAME tokens"
 
-    num_layers = context.get_num_hidden_layers(
+    match_layer_id = context.get_num_hidden_layers(
         name_with_layer_id, LAYER_ID_MACRO_TAG
     )
     expanded_expressions = []
 
-    for layer_id in range(num_layers):
+    match_layer_id = sorted(match_layer_id)
+
+    for layer_id in match_layer_id:
         expr = ""
         for token in tokens:
             if token.type == TokenType.IDENTIFIER:
```
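Because the matched ids now arrive as a set, the macro sorts them before expanding; generated expressions stay deterministic, and ids absent from the destination state are simply skipped, where `range(num_layers)` would have expanded every id up to the maximum. For example (ids hypothetical):

```python
# With matched ids {7, 0, 2}, expansion now covers exactly layers 0, 2 and 7;
# the old range(max_layer + 1) would also have emitted layers 1 and 3-6.
for layer_id in sorted({7, 0, 2}):
    print(f"layers.{layer_id}.self_attn.qkv_proj.weight")  # hypothetical name
```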
```diff
@@ -181,6 +197,8 @@ def fused_qkv_old_macro(tokens, expression, context):
     if not any(tkn.value == FUSED_QKV_OLD_TAG for tkn in tokens):
         return expression
 
+    axis, tokens = extract_axis_and_clean_tokens(tokens)
+
     attn_head_num = None
     num_key_value_groups = None
     fused_qkv_old_pos = None
@@ -263,10 +281,14 @@ def gen_expr(tp_degree, num_heads, tp_rank, comp):
                 for c, n in head_config
             ]
             if idx == 0:
-                mapping = f"{qkv_weight_name} -> {','.join(qkv_parts)}, axis=1"
+                mapping = (
+                    f"{qkv_weight_name} -> {','.join(qkv_parts)}, axis={axis}"
+                )
                 results.append(mapping)
             elif qkv_weight_name is not None:
-                mapping = f"{','.join(qkv_parts)} -> {qkv_weight_name}, axis=1"
+                mapping = (
+                    f"{','.join(qkv_parts)} -> {qkv_weight_name}, axis={axis}"
+                )
                 results.append(mapping)
 
         if fused_qkv_old_pos > 4:
```
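With the axis threaded through, every emitted mapping string interpolates the parsed value instead of a hard-coded `axis=1`. A sketch of the resulting string, using hypothetical weight names with the same f-string as the diff above:

```python
axis = 0  # e.g. parsed from an explicit "axis = 0" in the macro expression
qkv_weight_name = "qkv_proj.weight"
qkv_parts = ["q_proj.weight", "k_proj.weight", "v_proj.weight"]

mapping = f"{qkv_weight_name} -> {','.join(qkv_parts)}, axis={axis}"
print(mapping)
# qkv_proj.weight -> q_proj.weight,k_proj.weight,v_proj.weight, axis=0
```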
```diff
@@ -275,7 +297,7 @@ def _generate_expr(prefix, count, target_name):
             elements = ",".join(
                 f"fused_qkv_old_tmp.{prefix}_{i}" for i in range(count)
             )
-            return f"{elements} -> {target_name}, axis=1"
+            return f"{elements} -> {target_name}, axis={axis}"
 
         q_name = tokens[2].value
         k_name = tokens[4].value
@@ -292,7 +314,7 @@ def _generate_expr(prefix, count, target_name):
 
         fused_qkv_tmp_name = f"{q_name}.{k_name}.{v_name}.tmp"
         results.append(
-            f"{q_name},{k_name},{v_name} -> {fused_qkv_tmp_name}, axis=1"
+            f"{q_name},{k_name},{v_name} -> {fused_qkv_tmp_name}, axis={axis}"
         )
         dst_state_shard_num = context.get_dst_state_shard_num(
             dst_qkv_weight_name
@@ -324,9 +346,13 @@ def gen_expr(tp_degree, num_heads, tp_rank, comp):
                 for c, n in head_config
             ]
             if idx == 0:
-                mapping = f"{qkv_weight_name} -> {','.join(qkv_parts)}, axis=1"
+                mapping = (
+                    f"{qkv_weight_name} -> {','.join(qkv_parts)}, axis={axis}"
+                )
             else:
-                mapping = f"{','.join(qkv_parts)} -> {qkv_weight_name}, axis=1"
+                mapping = (
+                    f"{','.join(qkv_parts)} -> {qkv_weight_name}, axis={axis}"
+                )
             results.append(mapping)
     else:
         raise ValueError(
@@ -340,6 +366,9 @@ def fused_ffn_macro(tokens, expression, context):
     FUSED_FFN_TAG = "fused_ffn"
     if not any(tkn.value == FUSED_FFN_TAG for tkn in tokens):
         return expression
+
+    axis, tokens = extract_axis_and_clean_tokens(tokens)
+
     rarrow_pos = None
     fused_ffn_pos = None
     for idx, token in enumerate(tokens):
@@ -388,19 +417,19 @@ def gen_expr(tp_degree, splited_num, tp_rank, comp):
             ]
             if idx == 0:
                 results.append(
-                    f"{ffn_weight_name} -> {','.join(ffn_parts)}, axis=1"
+                    f"{ffn_weight_name} -> {','.join(ffn_parts)}, axis={axis}"
                 )
             elif ffn_weight_name is not None:
                 results.append(
-                    f"{','.join(ffn_parts)} -> {ffn_weight_name}, axis=1"
+                    f"{','.join(ffn_parts)} -> {ffn_weight_name}, axis={axis}"
                 )
         if fused_ffn_pos > 4:
 
             def _generate_expr(prefix, count, target_name):
                 elements = ",".join(
                     f"fused_ffn_tmp.{prefix}_{i}" for i in range(count)
                 )
-                return f"{elements} -> {target_name}, axis=1"
+                return f"{elements} -> {target_name}, axis={axis}"
 
             gate_name = tokens[2].value
             up_name = tokens[4].value
@@ -415,7 +444,7 @@ def _generate_expr(prefix, count, target_name):
 
         fused_gate_up_tmp_name = f"{gate_name}.{up_name}.tmp"
         results.append(
-            f"{gate_name},{up_name} -> {fused_gate_up_tmp_name}, axis=1"
+            f"{gate_name},{up_name} -> {fused_gate_up_tmp_name}, axis={axis}"
         )
         dst_state_shard_num = context.get_dst_state_shard_num(
             dst_ffn_weight_name
@@ -445,11 +474,11 @@ def gen_expr(tp_degree, splited_num, tp_rank, comp):
             ]
             if idx == 0:
                 results.append(
-                    f"{ffn_weight_name} -> {','.join(ffn_parts)}, axis=1"
+                    f"{ffn_weight_name} -> {','.join(ffn_parts)}, axis={axis}"
                 )
             else:
                 results.append(
-                    f"{','.join(ffn_parts)} -> {ffn_weight_name}, axis=1"
+                    f"{','.join(ffn_parts)} -> {ffn_weight_name}, axis={axis}"
                 )
     else:
         raise ValueError(f"Unsupported fused_ffn macro format: {expression}.")
@@ -508,6 +537,8 @@ def fused_qkv(tokens, expression, context):
     if not any(tkn.value == FUSED_QKV_TAG for tkn in tokens):
         return expression
 
+    axis, tokens = extract_axis_and_clean_tokens(tokens)
+
     attn_head_num = num_heads = None
     num_key_value_groups = None
     fused_qkv_pos = None
@@ -566,12 +597,12 @@ def make_names(base, n):
             fused_qkv_order.append(k_names[g])
             fused_qkv_order.append(v_names[g])
         results.append(
-            f"{fused_qkv_var} -> {','.join(fused_qkv_order)}, axis=1"
+            f"{fused_qkv_var} -> {','.join(fused_qkv_order)}, axis={axis}"
        )
 
-        results.append(f"{','.join(q_names)} -> {q_var}, axis=1")
-        results.append(f"{','.join(k_names)} -> {k_var}, axis=1")
-        results.append(f"{','.join(v_names)} -> {v_var}, axis=1")
+        results.append(f"{','.join(q_names)} -> {q_var}, axis={axis}")
+        results.append(f"{','.join(k_names)} -> {k_var}, axis={axis}")
+        results.append(f"{','.join(v_names)} -> {v_var}, axis={axis}")
 
         return results
@@ -585,9 +616,9 @@ def make_names(base, n):
         k_names = make_names(k_var, num_key_value_groups)
         v_names = make_names(v_var, num_key_value_groups)
 
-        results.append(f"{q_var} -> {','.join(q_names)}, axis=1")
-        results.append(f"{k_var} -> {','.join(k_names)}, axis=1")
-        results.append(f"{v_var} -> {','.join(v_names)}, axis=1")
+        results.append(f"{q_var} -> {','.join(q_names)}, axis={axis}")
+        results.append(f"{k_var} -> {','.join(k_names)}, axis={axis}")
+        results.append(f"{v_var} -> {','.join(v_names)}, axis={axis}")
 
         fused_qkv_order = []
         for g in range(num_key_value_groups):
@@ -597,7 +628,7 @@ def make_names(base, n):
             fused_qkv_order.append(k_names[g])
             fused_qkv_order.append(v_names[g])
         results.append(
-            f"{','.join(fused_qkv_order)} -> {fused_qkv_var}, axis=1"
+            f"{','.join(fused_qkv_order)} -> {fused_qkv_var}, axis={axis}"
         )
         return results
```