From effe7c10115a6b0f0a1f090b0218131fdf92d0a8 Mon Sep 17 00:00:00 2001 From: shw Date: Sun, 22 Sep 2024 18:05:37 +0800 Subject: [PATCH] fix name --- torchacc/utils/optim_utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchacc/utils/optim_utils.py b/torchacc/utils/optim_utils.py index 71cf939..35ac13d 100644 --- a/torchacc/utils/optim_utils.py +++ b/torchacc/utils/optim_utils.py @@ -44,10 +44,16 @@ def get_layer_full_info(shard_metadata, model_state_dict): is_sharded = False name_splits = name.split(".") - # if start with 'model', we just skip the 'model' - if name_splits[0] == 'model': - name = ".".join(name_splits[1:]) - name_splits = name.split(".") + model_num = 0 + # if start with 'model', we just skip the model + for name in name_splits: + if name != 'model': + break + else: + n = n + 1 + name_splits = name_splits[n:] + name = ".".join(name_splits) + for idx, sep in enumerate(name_splits): if sep.startswith("_fsdp_shard"): is_sharded = True