Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new mistral #7425

Merged
merged 27 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions llm/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

放到utils里面

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

# You may obtain a copy of the License at
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认过训练的loss正常?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
没问题,这是 8 卡 sft

#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM


def get_convert_example(model):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mistral有chat_template吗。确认支持吗?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if isinstance(model, LoRAModel) or isinstance(model, PrefixModelForCausalLM):
base_model_prefix = model.model.base_model_prefix
else:
base_model_prefix = model.base_model_prefix

if base_model_prefix == "chatglm":
return convert_example_chatglm
elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mistral"]:
return convert_example_common
else:
raise ValueError(
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama."
)


class DataFormatError(ValueError):
pass


def tokenize_example(tokenizer, example, data_args):
if "src" in example and "tgt" in example:
source = example["src"][0] if isinstance(example["src"], list) else example["src"]
target = example["tgt"][0] if isinstance(example["tgt"], list) else example["tgt"]
else:
raise DataFormatError(
f"Example format is wrong, please check: {example} or rewrite tokenize_example in data.py "
)
tokenized_source = tokenizer(
source,
max_length=data_args.src_length,
truncation=True,
truncation_side="left",
add_special_tokens=True,
)
tgt_max_length = data_args.max_length - len(tokenized_source["input_ids"])
tokenized_target = tokenizer(
target,
max_length=tgt_max_length,
truncation=True,
truncation_side="right",
add_special_tokens=False,
)

tokenized_target_input_ids = tokenized_target["input_ids"]
# Add eos_token_id at the end of sequence if the sentence is not truncated.
# Attention! In some cases(ex. ChatGLMv2), tokenized eos_token is not equal to eos_token_id.
if len(tokenized_target_input_ids) < tgt_max_length:
tokenized_target_input_ids += [tokenizer.eos_token_id]

return tokenized_source, tokenized_target_input_ids


def convert_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)

if is_test:
return {
**tokenized_source,
"labels": tokenized_target_input_ids,
}
else:
input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
source_length = len(tokenized_source["input_ids"])
labels = [-100] * source_length + input_ids[source_length:]
# shift input_ids and labels
input_ids, labels = input_ids[:-1], labels[1:]
seq_length = len(input_ids)
features = {"input_ids": input_ids, "labels": labels}
if "position_ids" in tokenized_source:
features["position_ids"] = list(range(seq_length))
if intokens:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

return features


def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intokens=False):

tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)
if is_test:
return {
**tokenized_source,
"labels": tokenized_target_input_ids,
}
else:
input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
bos_position = len(tokenized_source["input_ids"]) - 1
labels = [-100] * bos_position + input_ids[bos_position:]
# shift input_ids and labels
input_ids, labels = input_ids[:-1], labels[1:]
features = {
"input_ids": input_ids,
"labels": labels,
}

if intokens:
seq_length = len(input_ids)
# attention_mask
attention_mask = np.tri(seq_length, seq_length, dtype=bool)
attention_mask[:, :bos_position] = 1
features["attention_mask"] = attention_mask
# 2d position_ids
position_ids = np.arange(seq_length, dtype=np.int64)
block_position_ids = np.concatenate(
[
np.zeros(bos_position, dtype=np.int64),
np.arange(1, seq_length - bos_position + 1, dtype=np.int64),
]
)
features["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)

return features
30 changes: 30 additions & 0 deletions llm/mistral/lora_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

放到config文件里面,config目录里新增mistral包含json和readme。顺带更新llm目录下的readme

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"model_name_or_path": "mistralai/Mistral-7B-v0.1",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/mistral_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-04,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"fp16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"recompute": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"lora": true
}
30 changes: 30 additions & 0 deletions llm/mistral/pt_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"model_name_or_path": "mistralai/Mistral-7B-v0.1",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/mistral_pt_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-02,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"fp16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"prefix_tuning": true
}
29 changes: 29 additions & 0 deletions llm/mistral/sft_argument.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"model_name_or_path": "mistralai/Mistral-7B-v0.1",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/mistral_sft_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 3e-05,
"warmup_steps": 30,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 4,
"pipeline_parallel_degree": 1
}
1 change: 1 addition & 0 deletions llm/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def get_convert_example(model):
"opt",
"qwen",
"mixtral",
"mistral",
"gemma",
"qwen2",
"qwen2_moe",
Expand Down
20 changes: 20 additions & 0 deletions llm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def get_prefix_tuning_params(model):
hidden_size = model.config.hidden_size
postprocess_past_key_value = llama_postprocess_past_key_value
multi_query_group_num = None
elif model.base_model_prefix == "mistral":
from paddlenlp.peft.prefix import mistral_postprocess_past_key_value

num_attention_heads = model.config.num_attention_heads
num_hidden_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size
postprocess_past_key_value = mistral_postprocess_past_key_value
multi_query_group_num = model.config.num_key_value_heads
elif model.base_model_prefix == "qwen":
from paddlenlp.peft.prefix import qwen_postprocess_past_key_value

Expand Down Expand Up @@ -190,6 +198,17 @@ def get_lora_target_modules(model):
".*w2.*",
".*w3.*",
]
elif model.base_model_prefix == "mistral":
target_modules = [
".*q_proj.*",
".*k_proj.*",
".*v_proj.*",
".*o_proj.*",
".*gate.*",
".*w1.*",
".*w2.*",
".*w3.*",
]
elif model.base_model_prefix == "qwen2_moe":
target_modules = [
".*q_proj.*",
Expand Down Expand Up @@ -279,6 +298,7 @@ def prediction_step(
)[0]
all_preds = []
for pred_tokens in generated_tokens:
pred_tokens = pred_tokens.numpy()
pred_tokens = pred_tokens[pred_tokens != self.tokenizer.pad_token_id].tolist()
all_preds.append(pred_tokens)
max_pred_length = max([len(x) for x in all_preds])
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/peft/prefix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@
bloom_postprocess_past_key_value,
chatglm_postprocess_past_key_value,
llama_postprocess_past_key_value,
mistral_postprocess_past_key_value,
qwen_postprocess_past_key_value,
)
7 changes: 7 additions & 0 deletions paddlenlp/peft/prefix/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
return tuple(zip(keys, values))


def mistral_postprocess_past_key_value(past_key_values):
# (layer_num, bs, head_num/tensor_parallel_degree, prefixlen, head_dim)*2
keys, values = paddle.transpose(past_key_values, perm=[2, 0, 3, 1, 4]).split(2)

Check warning on line 43 in paddlenlp/peft/prefix/utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/prefix/utils.py#L43

Added line #L43 was not covered by tests

return tuple(zip(keys, values))

Check warning on line 45 in paddlenlp/peft/prefix/utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/peft/prefix/utils.py#L45

Added line #L45 was not covered by tests


def qwen_postprocess_past_key_value(past_key_values):
# (layer_num, bs, prefixlen, head_num/tensor_parallel_degree, head_dim)*2
keys, values = paddle.transpose(past_key_values, perm=[2, 0, 1, 3, 4]).split(2)
Expand Down
2 changes: 2 additions & 0 deletions paddlenlp/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@
from .rw.modeling import *
from .rw.configuration import *
from .rw.tokenizer import *
from .mistral.modeling import *
from .mistral.configuration import *
from .qwen import *
from .mixtral.modeling import *
from .mixtral.configuration import *
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/transformers/auto/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
("Blip", "blip"),
("Bloom", "bloom"),
("QWen", "qwen"),
("Mistral", "mistral"),
("Mixtral", "mixtral"),
("Qwen2", "qwen2"),
("Qwen2Moe", "qwen2_moe"),
Expand Down
15 changes: 15 additions & 0 deletions paddlenlp/transformers/mistral/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration import MistralConfig
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mistral为什么没有tokenizer文件

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from .modeling import MistralForCausalLM
69 changes: 69 additions & 0 deletions paddlenlp/transformers/mistral/configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mistral model configuration"""

from ..configuration_utils import PretrainedConfig


class MistralConfig(PretrainedConfig):
model_type = "mistral"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

Check warning on line 54 in paddlenlp/transformers/mistral/configuration.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/mistral/configuration.py#L54

Added line #L54 was not covered by tests

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
Loading
Loading