example(finetune): add chatglm3 evaluation and finetune example (#3025)
1 parent 8cd4de5 · commit 85c2f87
Showing 11 changed files with 376 additions and 3 deletions.
@@ -0,0 +1,2 @@

```
pretrain/
.cache/
```
@@ -0,0 +1 @@

```
.cache/
```
@@ -0,0 +1,59 @@

ChatGLM3 Finetune with Starwhale
======

- 🍬 Parameters: 6b
- 🔆 Github: https://github.com/THUDM/ChatGLM3
- 🥦 Author: THUDM.
- 📝 License: Unknown
- 🐱 Starwhale Example: https://github.com/star-whale/starwhale/tree/main/example/llm-finetune/models/chatglm3
- 🌽 Introduction: ChatGLM3 is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series, and it keeps many excellent features of the first two generations, such as smooth dialogue and a low deployment threshold.

In this example, we use 4-bit quantization to reduce GPU memory usage, so a single T4/A10/A100 GPU card is enough for both evaluation and finetuning.
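
The 4-bit path relies on the `quantize(4)` method shipped with the ChatGLM3 checkpoint's remote code (see `evaluation.py` and `finetune.py` below), so no extra quantization library is needed.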

Build Starwhale Model
------

```bash
python3 download.py
swcli model build .
```

Run Online Evaluation in the Standalone instance
------

```bash
# for source code
swcli -vvv model serve --workdir . --host 0.0.0.0 --port 10878

# for model package with runtime
swcli -vvv model serve --uri chatglm3-6b --host 0.0.0.0 --port 10878 --runtime llm-finetune
```
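
Once the service is up, the chat page should be reachable at http://127.0.0.1:10878 (adjust to the `--host`/`--port` flags above); it is backed by the `chatbot` handler in `evaluation.py`.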

Run Starwhale Model for evaluation in the Standalone instance
------

```bash
# download evaluation dataset
swcli dataset cp https://cloud.starwhale.cn/projects/401/datasets/161/versions/223/ .

# for source code
swcli -vvv model run -w . -m evaluation --handler evaluation:copilot_predict --dataset z-bench-common --dataset-head 3

# for model package
swcli -vvv model run --uri chatglm3-6b --handler evaluation:copilot_predict --dataset z-bench-common --dataset-head 3 --runtime llm-finetune
```
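
Generation can be tuned without code changes: `copilot_predict` reads `MAX_LENGTH`, `TOP_P`, and `TEMPERATURE` from the environment (see `evaluation.py` below). A sketch with illustrative values:

```bash
MAX_LENGTH=512 TOP_P=0.9 TEMPERATURE=1.2 \
    swcli -vvv model run -w . -m evaluation \
    --handler evaluation:copilot_predict --dataset z-bench-common --dataset-head 3
```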

Finetune base model
------

```bash
# build the finetune dataset from Baichuan2's BELLE data
swcli dataset build --json https://raw.githubusercontent.com/baichuan-inc/Baichuan2/main/fine-tune/data/belle_chat_ramdon_10k.json --name belle_chat_random_10k

# for source code
swcli -vvv model run -w . -m finetune --dataset belle_chat_random_10k --handler finetune:p_tuning_v2_finetune

# for model package
swcli -vvv model run -u chatglm3-6b --dataset belle_chat_random_10k --handler finetune:p_tuning_v2_finetune --runtime llm-finetune
```
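
The training loop exposes a few knobs through environment variables (`PT_PRE_SEQ_LEN`, `MAX_STEPS`, `NUM_TRAIN_EPOCHS`, `MAX_SEQ_LEN`; see `finetune.py` below). A hypothetical override:

```bash
MAX_STEPS=50 NUM_TRAIN_EPOCHS=2 MAX_SEQ_LEN=2048 PT_PRE_SEQ_LEN=128 \
    swcli -vvv model run -w . -m finetune \
    --dataset belle_chat_random_10k --handler finetune:p_tuning_v2_finetune
```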
@@ -0,0 +1,6 @@

```python
from pathlib import Path

# all paths resolve relative to this example directory
ROOT_DIR = Path(__file__).parent
PRETRAINED_MODELS_DIR = ROOT_DIR / "pretrained_models"
BASE_MODEL_DIR = PRETRAINED_MODELS_DIR / "base"
PT_DIR = PRETRAINED_MODELS_DIR / "p_tuning_v2"
```
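
Both `evaluation.py` and `finetune.py` import these paths: `download.py` fills `BASE_MODEL_DIR` with the base checkpoint, and the finetune handler writes its p-tuning weights into `PT_DIR`.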
@@ -0,0 +1,9 @@

```python
from huggingface_hub import snapshot_download

try:
    from .consts import BASE_MODEL_DIR
except ImportError:
    from consts import BASE_MODEL_DIR

if __name__ == "__main__":
    snapshot_download(repo_id="THUDM/chatglm3-6b", local_dir=BASE_MODEL_DIR)
```
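
`snapshot_download` mirrors every file of the THUDM/chatglm3-6b repository into `BASE_MODEL_DIR` (roughly 12 GB of fp16 weights); if huggingface.co is unreachable, `huggingface_hub` honors the `HF_ENDPOINT` environment variable, which can point at a mirror.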
@@ -0,0 +1,96 @@

```python
from __future__ import annotations

import os
import typing as t

import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

from starwhale import evaluation
from starwhale.api.service import api, LLMChat

try:
    from .consts import PT_DIR, BASE_MODEL_DIR
except ImportError:
    from consts import PT_DIR, BASE_MODEL_DIR

# lazily initialized singletons shared by the predict and chat handlers
_g_model = None
_g_tokenizer = None


def _load_model_and_tokenizer() -> t.Tuple:
    global _g_model, _g_tokenizer

    if _g_model is None:
        # TODO: after starwhale supports parameters, we can remove os environ.
        config = AutoConfig.from_pretrained(
            BASE_MODEL_DIR,
            trust_remote_code=True,
            pre_seq_len=int(os.environ.get("PT_PRE_SEQ_LEN", "128")),
        )
        _g_model = (
            AutoModel.from_pretrained(
                BASE_MODEL_DIR,
                config=config,
                device_map="cuda:0",
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            .quantize(4)  # ChatGLM3's built-in 4-bit quantization
            .cuda()
            .eval()
        )

        # if a p-tuned prefix encoder was produced by finetune.py, graft it on
        ptuning_path = PT_DIR / "pytorch_model.bin"
        if ptuning_path.exists():
            print(f"load p-tuning model: {ptuning_path}")
            prefix_state_dict = torch.load(ptuning_path)
            new_prefix_state_dict = {}
            for k, v in prefix_state_dict.items():
                if k.startswith("transformer.prefix_encoder."):
                    new_prefix_state_dict[k[len("transformer.prefix_encoder.") :]] = v
            _g_model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

    if _g_tokenizer is None:
        _g_tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL_DIR, use_fast=False, trust_remote_code=True
        )

    return _g_model, _g_tokenizer


@evaluation.predict(
    resources={"nvidia.com/gpu": 1},
    replicas=1,
    log_mode="plain",
)
def copilot_predict(data: dict) -> str:
    model, tokenizer = _load_model_and_tokenizer()
    print(data["prompt"])
    response, _ = model.chat(
        tokenizer,
        data["prompt"],
        history=[],
        max_length=int(os.environ.get("MAX_LENGTH", "512")),
        top_p=float(os.environ.get("TOP_P", "0.9")),
        temperature=float(os.environ.get("TEMPERATURE", "1.2")),
    )
    return response


@api(
    inference_type=LLMChat(
        args={"user_input", "history", "temperature", "top_p", "max_new_tokens"}
    )
)
def chatbot(user_input, history, temperature, top_p, max_new_tokens):
    model, tokenizer = _load_model_and_tokenizer()
    response, history = model.chat(
        tokenizer,
        user_input,
        history=history,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        temperature=temperature,
    )
    return history
```
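
A minimal local smoke test, assuming a CUDA GPU and the downloaded weights under `pretrained_models/base` (this snippet is illustrative and not part of the commit):

```python
# hypothetical quick check, run from the example directory
from evaluation import _load_model_and_tokenizer

model, tokenizer = _load_model_and_tokenizer()
response, _ = model.chat(tokenizer, "What is machine learning?", history=[])
print(response)
```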
@@ -0,0 +1,195 @@

```python
from __future__ import annotations

import os
import typing as t
from typing import Any
from dataclasses import dataclass

import torch
from transformers import (
    Trainer,
    AutoModel,
    AutoConfig,
    AutoTokenizer,
    PreTrainedTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
)
from transformers.modeling_utils import unwrap_model, PreTrainedModel

from starwhale import Dataset, finetune

try:
    from .consts import PT_DIR, BASE_MODEL_DIR
except ImportError:
    from consts import PT_DIR, BASE_MODEL_DIR

# forked from https://github.com/THUDM/ChatGLM3/tree/main/finetune_chatmodel_demo


@finetune(
    resources={"nvidia.com/gpu": 1},
    require_train_datasets=True,
    model_modules=["evaluation", "finetune"],
)
def p_tuning_v2_finetune(train_datasets: t.List[Dataset]) -> None:
    # TODO: support multi train datasets
    train_dataset = train_datasets[0]

    config = AutoConfig.from_pretrained(
        BASE_MODEL_DIR,
        trust_remote_code=True,
        pre_seq_len=int(os.environ.get("PT_PRE_SEQ_LEN", "128")),
        prefix_projection=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        BASE_MODEL_DIR, config=config, trust_remote_code=True
    )

    # support finetuning from a previously p-tuned model
    pt_bin_path = PT_DIR / PrefixTrainer.WEIGHTS_NAME
    if pt_bin_path.exists():
        print(f"load p-tuning model: {pt_bin_path}")
        prefix_state_dict = torch.load(pt_bin_path)
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder.") :]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

    print("Quantized to 4bit...")
    model = model.quantize(4).half()
    # keep the trainable prefix encoder in fp32 for numerical stability
    model.transformer.prefix_encoder.float()
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    trainer = PrefixTrainer(
        model=model,
        tokenizer=tokenizer,
        args=Seq2SeqTrainingArguments(
            output_dir=str(PT_DIR),
            report_to="none",
            logging_steps=10,
            per_device_train_batch_size=1,  # a larger batch size will cause OOM
            gradient_accumulation_steps=16,
            save_strategy="no",  # no need to save checkpoints during finetune
            learning_rate=2e-2,
            max_steps=int(os.environ.get("MAX_STEPS", 18)),
            num_train_epochs=int(os.environ.get("NUM_TRAIN_EPOCHS", 2)),
            gradient_checkpointing=False,
            remove_unused_columns=False,
        ),
        train_dataset=train_dataset.to_pytorch(
            transform=MultiTurnDataTransform(
                tokenizer=tokenizer,
                max_seq_len=int(os.environ.get("MAX_SEQ_LEN", 2048)),
            )
        ),
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            label_pad_token_id=-100,
            pad_to_multiple_of=None,
            padding=False,
        ),
        save_changed=True,  # only persist the trainable prefix encoder weights
    )

    print("start model training...")
    train_result = trainer.train(resume_from_checkpoint=None)
    print(train_result.metrics)
    trainer.save_state()
    trainer.save_model()


class PrefixTrainer(Trainer):
    WEIGHTS_NAME = "pytorch_model.bin"
    TRAINING_ARGS_NAME = "training_args.bin"

    def __init__(self, *args, save_changed=False, **kwargs):
        self.save_changed = save_changed
        super().__init__(*args, **kwargs)

    def _save(self, output_dir: t.Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        print(f"Saving model checkpoint to {output_dir}")
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            if isinstance(unwrap_model(self.model), PreTrainedModel):
                if state_dict is None:
                    state_dict = self.model.state_dict()
                unwrap_model(self.model).save_pretrained(
                    output_dir, state_dict=state_dict
                )
            else:
                print(
                    "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
                )
                if state_dict is None:
                    state_dict = self.model.state_dict()
                torch.save(state_dict, os.path.join(output_dir, self.WEIGHTS_NAME))
        else:
            if self.save_changed:
                print("Saving PrefixEncoder")
                state_dict = self.model.state_dict()
                filtered_state_dict = {}
                for k, v in self.model.named_parameters():
                    if v.requires_grad:
                        filtered_state_dict[k] = state_dict[k]
                self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
            else:
                print("Saving the whole model")
                self.model.save_pretrained(output_dir, state_dict=state_dict)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, self.TRAINING_ARGS_NAME))


@dataclass
class MultiTurnDataTransform:
    tokenizer: PreTrainedTokenizer
    max_seq_len: int

    ignore_index = -100

    def __call__(self, example: t.Dict) -> Any:
        # belle_chat_random_10k dataset: https://cloud.starwhale.cn/projects/401/datasets/164/versions/226/files
        tokens = [
            self.tokenizer.get_command("[gMASK]"),
            self.tokenizer.get_command("sop"),
        ]
        loss_masks = [0, 0]

        for message in example["conversations"]:
            # belle roles: human and gpt
            # ChatGLM3 roles: user and assistant
            _role = "user" if message["from"] == "human" else "assistant"
            _message_tokens = self.tokenizer.build_single_message(
                _role, "", message["value"]
            )
            tokens.extend(_message_tokens)
            # only assistant tokens carry loss; the committed code extended the mask
            # with zeros for every role, which would leave no training signal at all
            loss_masks.extend(
                [1 if _role == "assistant" else 0] * len(_message_tokens)
            )

        tokens.extend([self.tokenizer.eos_token_id])
        loss_masks.extend([0])

        # labels are used inside the model; shifting the mask right by one makes the
        # token after each assistant position (including the final eos) a target
        target_based_loss_mask = [False] + loss_masks[:-1]
        labels = [
            (t if m else self.ignore_index)
            for t, m in zip(tokens, target_based_loss_mask)
        ]

        # truncate to max_seq_len, then right-pad inputs and labels to a fixed length
        tokens = tokens[: self.max_seq_len]
        labels = labels[: self.max_seq_len]
        tokens += [self.tokenizer.pad_token_id] * (self.max_seq_len - len(tokens))
        labels += [self.ignore_index] * (self.max_seq_len - len(labels))

        return {"input_ids": tokens, "labels": labels}
```
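
A hypothetical worked example of the transform (the record layout mirrors the belle dataset; running it needs the real ChatGLM3 tokenizer):

```python
# illustrative only: one belle-style record, transformed by hand
record = {
    "conversations": [
        {"from": "human", "value": "Translate 'hello' to French."},
        {"from": "gpt", "value": "Bonjour."},
    ]
}

# transform = MultiTurnDataTransform(tokenizer=tokenizer, max_seq_len=2048)
# features = transform(record)
# features["input_ids"]: [gMASK], sop, user turn, assistant turn, eos, then pad ids
# features["labels"]:    -100 everywhere except the assistant reply (shifted by one) and eos
# len(features["input_ids"]) == len(features["labels"]) == 2048
```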
@@ -0,0 +1,5 @@

```yaml
name: chatglm3-6b
run:
  modules:
    - evaluation
    - finetune
```
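
`model.yaml` names the model package and lists the Python modules that `swcli model build .` scans for handlers, which is how `copilot_predict`, `chatbot`, and `p_tuning_v2_finetune` end up in the built Starwhale model.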