Commit b7f32af

bug fix
DesmonDay committed Jul 25, 2024
1 parent e7d96a0 commit b7f32af
Showing 3 changed files with 140 additions and 160 deletions.
12 changes: 6 additions & 6 deletions paddlenlp/experimental/transformers/chatglm/modeling.py
@@ -389,20 +389,20 @@ def set_state_dict(self, state_dict, use_structured_name=True):
head_dim = embed_dim // config.num_attention_heads

for k, v in state_dict.items():
if k.startswith("transformer.word_embeddings.weight"):
if k.startswith("chatglm.transformer.word_embeddings.weight"):

Check warning on line 392 in paddlenlp/experimental/transformers/chatglm/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/chatglm/modeling.py#L392

Added line #L392 was not covered by tests
self.word_embeddings.weight.set_value(v.astype(dtype))
continue
elif k.startswith("transformer.final_layernorm.weight"):
elif k.startswith("chatglm.transformer.final_layernorm.weight"):

Check warning on line 395 in paddlenlp/experimental/transformers/chatglm/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/chatglm/modeling.py#L395

Added line #L395 was not covered by tests
self.transformer_block.ffn_ln_scales[config.num_hidden_layers - 1].set_value(v.astype("float32"))
continue
elif k.startswith("transformer.final_layernorm.bias"):
elif k.startswith("chatglm.transformer.final_layernorm.bias"):

Check warning on line 398 in paddlenlp/experimental/transformers/chatglm/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/chatglm/modeling.py#L398

Added line #L398 was not covered by tests
self.transformer_block.ffn_ln_biases[config.num_hidden_layers - 1].set_value(v.astype("float32"))
continue
elif k.startswith("lm_head.weight"):
continue
elif k.endswith("rotary_embeddings.inv_freq") or k.endswith("rotary_emb.inv_freq"):
continue
idx = int(k.split(".")[2])
idx = int(k.split(".")[3])

Check warning on line 405 in paddlenlp/experimental/transformers/chatglm/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/chatglm/modeling.py#L405

Added line #L405 was not covered by tests
if k.endswith("input_layernorm.weight"):
if idx == 0:
self.input_layernorm.weight.set_value(v.astype(dtype))
@@ -584,7 +584,7 @@ def __init__(self, config: ChatGLMConfig):

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs)
return infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs, return_numpy=False)

@classmethod
def get_cache_kvs_shape(
@@ -745,6 +745,6 @@ def forward(
@paddle.no_grad()
def set_state_dict(self, state_dict):
self.lm_head.weight.set_value(
state_dict["transformer.word_embeddings.weight"].astype(self.lm_head.weight.dtype)
state_dict["chatglm.transformer.word_embeddings.weight"].astype(self.lm_head.weight.dtype)
)
self.model.transformer.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
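
A note on the hunk above: the renamed checkpoint keys are the reason the layer-index parsing changes from k.split(".")[2] to k.split(".")[3]; the extra "chatglm." prefix shifts every dotted component one position to the right. A minimal sketch of that shift (the concrete key below is an illustrative assumption, not copied from a real checkpoint):

# Illustrative only: the layer index moves from dotted position 2 to position 3
# once state-dict keys gain the "chatglm." prefix.
old_key = "transformer.layers.0.input_layernorm.weight"
new_key = "chatglm." + old_key

print(int(old_key.split(".")[2]))  # 0, parsed with the old split position
print(int(new_key.split(".")[3]))  # 0, same layer index under the new prefix
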
153 changes: 4 additions & 149 deletions paddlenlp/experimental/transformers/utils.py
@@ -13,168 +13,23 @@
# limitations under the License.
from __future__ import annotations

import json
import os
from functools import partial

import numpy as np
import paddle
from tqdm import tqdm

from paddlenlp.transformers import AutoConfig
from paddlenlp.transformers.model_utils import (
_add_variant,
dtype_guard,
load_state_dict,
load_tp_checkpoint,
no_init_weights,
)
from paddlenlp.transformers.utils import (
ContextManagers,
is_paddle_support_lazy_init,
is_safetensors_available,
paddlenlp_load,
)
from paddlenlp.utils.env import (
PADDLE_WEIGHTS_INDEX_NAME,
SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_PEFT_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_INDEX_NAME,
)

try:
from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
from paddlenlp.utils.safetensors import fast_safe_open as safe_open
except:
from safetensors import safe_open
from safetensors.numpy import load_file as safe_load_file


def load_sharded_checkpoint(folder, variant=None, return_numpy=False):
"""
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
loaded in the model.
Args:
folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
variant (`str`): The model variant.
return_numpy (`bool`): Whether to return numpy array instead of paddle tensor.
"""
# Load the index
pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
if os.path.isfile(pdparams_file):
return paddle.load(pdparams_file, return_numpy=return_numpy)
if os.path.isfile(lora_pdparams_file):
return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
if os.path.isfile(safetensors_file):
state_dict = safe_load_file(safetensors_file)
if not return_numpy:
for key in list(state_dict.keys()):
if isinstance(state_dict[key], np.ndarray):
state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True)
return state_dict

index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))

index_present = os.path.isfile(index_file)
safe_index_present = os.path.isfile(safe_index_file)
safe_master_present = os.path.isfile(safe_master_file)
safe_peft_present = os.path.isfile(safe_peft_file)

load_safe = False
load_index = None
if safe_index_present:
load_safe = True # load safe due to preference
load_index = safe_index_file
elif safe_master_present:
load_safe = True
load_index = safe_master_file
elif index_present:
load_index = index_file
elif safe_peft_present:
load_safe = True
load_index = safe_peft_file
else:
raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")

with open(load_index, "r", encoding="utf-8") as f:
index = json.load(f)

shard_files = list(set(index["weight_map"].values()))
loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="np" if return_numpy else "cpu")

ret = {}
for shard_file in tqdm(shard_files):
state_dict = loader(os.path.join(folder, shard_file))
ret.update(state_dict)

if not return_numpy:
for key in list(ret.keys()):
if isinstance(ret[key], np.ndarray):
ret[key] = paddle.Tensor(ret.pop(key), zero_copy=True)

return ret


def load_tp_checkpoint(folder, cls, config, return_numpy=False):
"""
This load is performed efficiently: Load tp checkpoint only from cpu, no need to init the model.
Args:
folder (`str` or `os.PathLike`): A path to a folder containing the model checkpoint.
cls (`str`): The model class.
config (`AutoConfig`): The model config.
return_numpy (bool): Whether load the tp checkpoint as numpy.
"""

config = AutoConfig.from_pretrained(folder)
if config.tensor_parallel_degree == 1 or config.tensor_parallel_degree == -1:
return load_sharded_checkpoint(folder, return_numpy=return_numpy)
else:
rank_model_path = os.path.join(folder, f"model_state.tp0{config.tensor_parallel_rank}.pdparams")
model_path = os.path.join(folder, "model_state.pdparams")
safe_model_path = os.path.join(folder, "model.safetensors")
if os.path.exists(rank_model_path):
return paddle.load(rank_model_path, return_numpy=return_numpy)
elif os.path.exists(model_path):
state_dict = cls.convert_tensor_parallel(model_path, config)
elif os.path.exists(safe_model_path):
with safe_open(safe_model_path, framework="np", device="cpu") as f:
loaded_keys = f.keys()
tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
state_dict = load_state_dict(safe_model_path, tp_actions)
else: # shard files safetensors
resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
pretrained_model_name_or_path=folder,
use_safetensors=True,
)
if len(resolved_sharded_files) > 1:
resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards")
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True)
state_dict = {}
for shard_file in resolved_sharded_files:
shard_state_dict = load_state_dict(
shard_file,
tp_actions,
loaded_state_dict_keys,
)
state_dict.update(shard_state_dict)
if return_numpy:
for k in list(state_dict.keys()):
if not isinstance(state_dict[k], np.ndarray):
state_dict[k] = state_dict.pop(k).cpu().numpy()
return state_dict


def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs):
def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args, kwargs, return_numpy=True):
r"""
Instantiate a pretrained model configuration from a pre-trained model name or path.
"""
@@ -203,7 +58,7 @@ def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args,
with ContextManagers(init_contexts):
model = cls(config)

resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
resolved_archive_file, _, _, _ = cls._resolve_model_file_path(
pretrained_model_name_or_path,
cache_dir=cache_dir,
subfolder=subfolder,
@@ -216,7 +71,7 @@ def infererence_model_from_pretrained(cls, pretrained_model_name_or_path, args,
)

model_path = os.path.dirname(resolved_archive_file)
state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=True)
state_dict = load_tp_checkpoint(model_path, cls, config, return_numpy=return_numpy)
model.set_state_dict(state_dict)

return model

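For context on the new return_numpy keyword added to infererence_model_from_pretrained, the sketch below mirrors the from_pretrained hunk in the ChatGLM file above. The wrapper class is a hypothetical placeholder; only the function, its module, and the return_numpy flag come from this diff.

# A minimal sketch, assuming a concrete inference model class stands behind the placeholder.
from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained


class MyInferenceModel:  # hypothetical stand-in for a concrete inference model class
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # ChatGLM passes return_numpy=False so its set_state_dict receives paddle tensors;
        # other callers keep the default return_numpy=True and get numpy arrays instead.
        return infererence_model_from_pretrained(
            cls, pretrained_model_name_or_path, args, kwargs, return_numpy=False
        )
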
135 changes: 130 additions & 5 deletions paddlenlp/transformers/model_utils.py
@@ -20,7 +20,6 @@
import json
import os
import re
import sys
import tempfile
import warnings
from contextlib import contextmanager
@@ -59,6 +58,8 @@
PADDLE_WEIGHTS_NAME,
PYTORCH_WEIGHTS_INDEX_NAME,
PYTORCH_WEIGHTS_NAME,
SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_PEFT_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
)
@@ -109,13 +110,14 @@ def unwrap_optimizer(optimizer, optimizer_instances=()):


if is_safetensors_available():
from safetensors.numpy import load_file as safe_load_file
from safetensors.numpy import save_file as safe_save_file

if sys.platform.startswith("win"):
from safetensors import safe_open
else:
try:
from paddlenlp.utils.safetensors import fast_load_file as safe_load_file
from paddlenlp.utils.safetensors import fast_safe_open as safe_open
except:
from safetensors import safe_open
from safetensors.numpy import load_file as safe_load_file


def prune_linear_layer(layer: nn.Linear, index: paddle.Tensor, dim: int = 0) -> nn.Linear:
@@ -2665,3 +2667,126 @@ def set_state_dict(self, state_dict, *args, **kwargs):

ret = super().set_state_dict(state_dict, *args, **kwargs)
return ret


def load_sharded_checkpoint_new(folder, variant=None, return_numpy=False):
"""
This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
loaded in the model.
Args:
folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
variant (`str`): The model variant.
return_numpy (`bool`): Whether to return numpy array instead of paddle tensor.
"""
# Load the index
pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant))
lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant))
safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant))
if os.path.isfile(pdparams_file):
return paddle.load(pdparams_file, return_numpy=return_numpy)
if os.path.isfile(lora_pdparams_file):
return paddle.load(lora_pdparams_file, return_numpy=return_numpy)
if os.path.isfile(safetensors_file):
state_dict = safe_load_file(safetensors_file)
if not return_numpy:
for key in list(state_dict.keys()):
if isinstance(state_dict[key], np.ndarray):
state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True)
return state_dict

index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant))
safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant))
safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant))

index_present = os.path.isfile(index_file)
safe_index_present = os.path.isfile(safe_index_file)
safe_master_present = os.path.isfile(safe_master_file)
safe_peft_present = os.path.isfile(safe_peft_file)

load_safe = False
load_index = None
if safe_index_present:
load_safe = True # load safe due to preference
load_index = safe_index_file
elif safe_master_present:
load_safe = True
load_index = safe_master_file
elif index_present:
load_index = index_file
elif safe_peft_present:
load_safe = True
load_index = safe_peft_file
else:
raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}")

with open(load_index, "r", encoding="utf-8") as f:
index = json.load(f)

shard_files = list(set(index["weight_map"].values()))
loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="np" if return_numpy else "cpu")

ret = {}
for shard_file in tqdm(shard_files):
state_dict = loader(os.path.join(folder, shard_file))
ret.update(state_dict)

if not return_numpy:
for key in list(ret.keys()):
if isinstance(ret[key], np.ndarray):
ret[key] = paddle.Tensor(ret.pop(key), zero_copy=True)

return ret


def load_tp_checkpoint(folder, cls, config, return_numpy=False):
"""
This load is performed efficiently: Load tp checkpoint only from cpu, no need to init the model.
Args:
folder (`str` or `os.PathLike`): A path to a folder containing the model checkpoint.
cls (`str`): The model class.
config (`AutoConfig`): The model config.
return_numpy (bool): Whether load the tp checkpoint as numpy.
"""
if config.tensor_parallel_degree == 1 or config.tensor_parallel_degree == -1:
return load_sharded_checkpoint_new(folder, return_numpy=return_numpy)
else:
rank_model_path = os.path.join(folder, f"model_state.tp0{config.tensor_parallel_rank}.pdparams")
model_path = os.path.join(folder, "model_state.pdparams")
safe_model_path = os.path.join(folder, "model.safetensors")
if os.path.exists(rank_model_path):
return paddle.load(rank_model_path, return_numpy=return_numpy)
elif os.path.exists(model_path):
state_dict = cls.convert_tensor_parallel(model_path, config)
elif os.path.exists(safe_model_path):
with safe_open(safe_model_path, framework="np", device="cpu") as f:
loaded_keys = f.keys()
tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_keys)
state_dict = load_state_dict(safe_model_path, tp_actions)
else: # shard files safetensors
resolved_archive_file, resolved_sharded_files, sharded_metadata, is_sharded = cls._resolve_model_file_path(
pretrained_model_name_or_path=folder,
use_safetensors=True,
)
if len(resolved_sharded_files) > 1:
resolved_sharded_files = tqdm(resolved_sharded_files, desc="Loading checkpoint shards")
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
tp_actions = cls.get_tensor_parallel_convert_actions(config, loaded_state_dict_keys, ignore_error=True)
state_dict = {}
for shard_file in resolved_sharded_files:
shard_state_dict = load_state_dict(
shard_file,
tp_actions,
loaded_state_dict_keys,
)
state_dict.update(shard_state_dict)
if return_numpy:
for k in list(state_dict.keys()):
if not isinstance(state_dict[k], np.ndarray):
state_dict[k] = state_dict.pop(k).cpu().numpy()
return state_dict

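The two helpers added above now live in paddlenlp.transformers.model_utils instead of the experimental utils module. A minimal usage sketch for the single-card path follows; the checkpoint directory is a hypothetical assumption, while the function name, module, and signature come from the hunk above.

from paddlenlp.transformers.model_utils import load_sharded_checkpoint_new

# Hypothetical local directory containing model.safetensors or a sharded weights index.
folder = "./my-checkpoint"

# Per the branch order above, a plain pdparams / lora pdparams / safetensors file is loaded
# directly; otherwise a safetensors index is preferred over the pdparams index.
# return_numpy=False converts any numpy arrays to paddle tensors zero-copy.
state_dict = load_sharded_checkpoint_new(folder, return_numpy=False)
print(f"loaded {len(state_dict)} tensors")

For tensor-parallel checkpoints, load_tp_checkpoint(folder, cls, config, return_numpy=...) wraps the same logic: when config.tensor_parallel_degree is 1 or -1 it simply delegates to load_sharded_checkpoint_new, so cls is only consulted on the tensor-parallel branch.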
