example(finetune): add chatglm3 evaluation and finetune example (#3025)
tianweidut authored Nov 28, 2023
1 parent 8cd4de5 commit 85c2f87
Showing 11 changed files with 376 additions and 3 deletions.
1 change: 0 additions & 1 deletion example/llm-finetune/models/baichuan2/README.md
@@ -54,4 +54,3 @@ swcli dataset build --json https://raw.githubusercontent.com/baichuan-inc/Baichu

swcli -vvv model run -w . -m finetune --dataset belle_chat_random_10k --handler finetune:lora_finetune
```

2 changes: 2 additions & 0 deletions example/llm-finetune/models/chatglm3/.gitignore
@@ -0,0 +1,2 @@
pretrain/
.cache/
1 change: 1 addition & 0 deletions example/llm-finetune/models/chatglm3/.swignore
@@ -0,0 +1 @@
.cache/
59 changes: 59 additions & 0 deletions example/llm-finetune/models/chatglm3/README.md
@@ -0,0 +1,59 @@
ChatGLM3 Finetune with Starwhale
======

- 🍬 Parameters: 6b
- 🔆 Github: https://github.com/THUDM/ChatGLM3
- 🥦 Author: THUDM
- 📝 License: Unknown
- 🐱 Starwhale Example: https://github.com/star-whale/starwhale/tree/main/example/llm-finetune/models/chatglm3
- 🌽 Introduction: ChatGLM3 is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series, retaining many excellent features of the first two generations, such as smooth dialogue and a low deployment threshold.

In this example, we use 4-bit quantization to reduce GPU memory usage, so a single T4/A10/A100 GPU is enough for both evaluation and finetuning.
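
The quantization happens at model load time. The snippet below is a minimal sketch of the loading path used in `evaluation.py`; note that `quantize(4)` is a method provided by the ChatGLM3 model code pulled in via `trust_remote_code=True`, not a standard `transformers` API, and the local path is the one populated by `download.py`:

```python
import torch
from transformers import AutoModel, AutoTokenizer

# load the base weights in fp16, then quantize to int4 via ChatGLM3's
# custom quantize() method (available because trust_remote_code=True)
model = (
    AutoModel.from_pretrained(
        "pretrained_models/base",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    .quantize(4)
    .cuda()
    .eval()
)
tokenizer = AutoTokenizer.from_pretrained(
    "pretrained_models/base", use_fast=False, trust_remote_code=True
)
```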

Build Starwhale Model
------

```bash
python3 download.py
swcli model build .
```

Run Online Evaluation in the Standalone instance
------

```bash
# for source code
swcli -vvv model serve --workdir . --host 0.0.0.0 --port 10878

# for model package with runtime
swcli -vvv model serve --uri chatglm3-6b --host 0.0.0.0 --port 10878 --runtime llm-finetune
```

Run Starwhale Model Evaluation in the Standalone instance
------

```bash
# download evaluation dataset
swcli dataset cp https://cloud.starwhale.cn/projects/401/datasets/161/versions/223/ .

# for source code
swcli -vvv model run -w . -m evaluation --handler evaluation:copilot_predict --dataset z-bench-common --dataset-head 3

# for model package
swcli -vvv model run --uri chatglm3-6b --handler evaluation:copilot_predict --dataset z-bench-common --dataset-head 3 --runtime llm-finetune
```


Finetune base model
------

```bash
# build finetune dataset from baichuan2
swcli dataset build --json https://raw.githubusercontent.com/baichuan-inc/Baichuan2/main/fine-tune/data/belle_chat_ramdon_10k.json --name belle_chat_random_10k

# for source code
swcli -vvv model run -w . -m finetune --dataset belle_chat_random_10k --handler finetune:p_tuning_v2_finetune

# for model package
swcli -vvv model run -u chatglm3-6b --dataset belle_chat_random_10k --handler finetune:p_tuning_v2_finetune --runtime llm-finetune
```
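
`finetune.py` reads its tuning knobs from environment variables (the names below come from this example's source); exporting them before `swcli model run` is a reasonable way to adjust a run, assuming `swcli` inherits the caller's environment:

```bash
# optional knobs read by finetune.py; assumes swcli inherits the caller's env
export PT_PRE_SEQ_LEN=128   # P-Tuning v2 prefix length
export MAX_STEPS=18         # max training steps
export MAX_SEQ_LEN=2048     # max tokenized sequence length per example
swcli -vvv model run -w . -m finetune --dataset belle_chat_random_10k --handler finetune:p_tuning_v2_finetune
```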
6 changes: 6 additions & 0 deletions example/llm-finetune/models/chatglm3/consts.py
@@ -0,0 +1,6 @@
from pathlib import Path

ROOT_DIR = Path(__file__).parent
PRETRAINED_MODELS_DIR = ROOT_DIR / "pretrained_models"
BASE_MODEL_DIR = PRETRAINED_MODELS_DIR / "base"
PT_DIR = PRETRAINED_MODELS_DIR / "p_tuning_v2"
9 changes: 9 additions & 0 deletions example/llm-finetune/models/chatglm3/download.py
@@ -0,0 +1,9 @@
from huggingface_hub import snapshot_download

try:
from .consts import BASE_MODEL_DIR
except ImportError:
from consts import BASE_MODEL_DIR

if __name__ == "__main__":
snapshot_download(repo_id="THUDM/chatglm3-6b", local_dir=BASE_MODEL_DIR)
96 changes: 96 additions & 0 deletions example/llm-finetune/models/chatglm3/evaluation.py
@@ -0,0 +1,96 @@
from __future__ import annotations

import os
import typing as t

import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

from starwhale import evaluation
from starwhale.api.service import api, LLMChat

try:
from .consts import PT_DIR, BASE_MODEL_DIR
except ImportError:
from consts import PT_DIR, BASE_MODEL_DIR

_g_model = None
_g_tokenizer = None


def _load_model_and_tokenizer() -> t.Tuple:
global _g_model, _g_tokenizer

if _g_model is None:
        # TODO: once Starwhale supports handler parameters, drop these os.environ lookups.
config = AutoConfig.from_pretrained(
BASE_MODEL_DIR,
trust_remote_code=True,
pre_seq_len=int(os.environ.get("PT_PRE_SEQ_LEN", "128")),
)
_g_model = (
AutoModel.from_pretrained(
BASE_MODEL_DIR,
config=config,
device_map="cuda:0",
torch_dtype=torch.float16,
trust_remote_code=True,
)
.quantize(4)
.cuda()
.eval()
)

ptuning_path = PT_DIR / "pytorch_model.bin"
if ptuning_path.exists():
print(f"load p-tuning model: {ptuning_path}")
prefix_state_dict = torch.load(ptuning_path)
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder.") :]] = v
_g_model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

if _g_tokenizer is None:
_g_tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL_DIR, use_fast=False, trust_remote_code=True
)

return _g_model, _g_tokenizer


@evaluation.predict(
resources={"nvidia.com/gpu": 1},
replicas=1,
log_mode="plain",
)
def copilot_predict(data: dict) -> str:
model, tokenizer = _load_model_and_tokenizer()
print(data["prompt"])
response, _ = model.chat(
tokenizer,
data["prompt"],
history=[],
max_length=int(os.environ.get("MAX_LENGTH", "512")),
top_p=float(os.environ.get("TOP_P", "0.9")),
temperature=float(os.environ.get("TEMPERATURE", "1.2")),
)
return response


@api(
inference_type=LLMChat(
args={"user_input", "history", "temperature", "top_p", "max_new_tokens"}
)
)
def chatbot(user_input, history, temperature, top_p, max_new_tokens):
model, tokenizer = _load_model_and_tokenizer()
response, history = model.chat(
tokenizer,
user_input,
history=history,
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
)
return history
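
As a quick local smoke test outside of `swcli`, the module's loader can be reused directly; this hypothetical snippet assumes a CUDA device and that `download.py` has already populated the base weights:

```python
# hypothetical smoke test, reusing the loader defined in evaluation.py;
# assumes a CUDA device and downloaded base weights
from evaluation import _load_model_and_tokenizer

model, tokenizer = _load_model_and_tokenizer()
response, _ = model.chat(tokenizer, "What is P-Tuning v2, in one sentence?", history=[])
print(response)
```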
195 changes: 195 additions & 0 deletions example/llm-finetune/models/chatglm3/finetune.py
@@ -0,0 +1,195 @@
from __future__ import annotations

import os
import typing as t
from typing import Any
from dataclasses import dataclass

import torch
from transformers import (
Trainer,
AutoModel,
AutoConfig,
AutoTokenizer,
PreTrainedTokenizer,
DataCollatorForSeq2Seq,
Seq2SeqTrainingArguments,
)
from transformers.modeling_utils import unwrap_model, PreTrainedModel

from starwhale import Dataset, finetune

try:
from .consts import PT_DIR, BASE_MODEL_DIR
except ImportError:
from consts import PT_DIR, BASE_MODEL_DIR

# fork from https://github.com/THUDM/ChatGLM3/tree/main/finetune_chatmodel_demo


@finetune(
resources={"nvidia.com/gpu": 1},
require_train_datasets=True,
model_modules=["evaluation", "finetune"],
)
def p_tuning_v2_finetune(train_datasets: t.List[Dataset]) -> None:
    # TODO: support multiple training datasets
train_dataset = train_datasets[0]

config = AutoConfig.from_pretrained(
BASE_MODEL_DIR,
trust_remote_code=True,
pre_seq_len=int(os.environ.get("PT_PRE_SEQ_LEN", "128")),
prefix_projection=False,
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR, trust_remote_code=True)
model = AutoModel.from_pretrained(
BASE_MODEL_DIR, config=config, trust_remote_code=True
)

# support finetuning from p-tuned model
pt_bin_path = PT_DIR / PrefixTrainer.WEIGHTS_NAME
if pt_bin_path.exists():
print(f"load p-tuning model: {pt_bin_path}")
prefix_state_dict = torch.load(pt_bin_path)
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder.") :]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

print("Quantized to 4bit...")
model = model.quantize(4).half()
model.transformer.prefix_encoder.float()
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

trainer = PrefixTrainer(
model=model,
tokenizer=tokenizer,
args=Seq2SeqTrainingArguments(
output_dir=str(PT_DIR),
report_to="none",
logging_steps=10,
            per_device_train_batch_size=1,  # larger batch sizes cause OOM
            gradient_accumulation_steps=16,
            save_strategy="no",  # no intermediate checkpoints needed for this finetune
learning_rate=2e-2,
max_steps=int(os.environ.get("MAX_STEPS", 18)),
num_train_epochs=int(os.environ.get("NUM_TRAIN_EPOCHS", 2)),
gradient_checkpointing=False,
remove_unused_columns=False,
),
train_dataset=train_dataset.to_pytorch(
transform=MultiTurnDataTransform(
tokenizer=tokenizer,
max_seq_len=int(os.environ.get("MAX_SEQ_LEN", 2048)),
)
),
data_collator=DataCollatorForSeq2Seq(
tokenizer=tokenizer,
model=model,
label_pad_token_id=-100,
pad_to_multiple_of=None,
padding=False,
),
save_changed=True,
)

print("start model training...")
train_result = trainer.train(resume_from_checkpoint=None)
print(train_result.metrics)
trainer.save_state()
trainer.save_model()


class PrefixTrainer(Trainer):
WEIGHTS_NAME = "pytorch_model.bin"
TRAINING_ARGS_NAME = "training_args.bin"

def __init__(self, *args, save_changed=False, **kwargs):
self.save_changed = save_changed
super().__init__(*args, **kwargs)

def _save(self, output_dir: t.Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
print(f"Saving model checkpoint to {output_dir}")
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
if isinstance(unwrap_model(self.model), PreTrainedModel):
if state_dict is None:
state_dict = self.model.state_dict()
unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict
)
else:
print(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
if state_dict is None:
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, self.WEIGHTS_NAME))
else:
if self.save_changed:
print("Saving PrefixEncoder")
state_dict = self.model.state_dict()
filtered_state_dict = {}
for k, v in self.model.named_parameters():
if v.requires_grad:
filtered_state_dict[k] = state_dict[k]
self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
else:
print("Saving the whole model")
self.model.save_pretrained(output_dir, state_dict=state_dict)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, self.TRAINING_ARGS_NAME))


@dataclass
class MultiTurnDataTransform:
tokenizer: PreTrainedTokenizer
max_seq_len: int

ignore_index = -100

def __call__(self, example: t.Dict) -> Any:
# belle_chat_random_10k dataset: https://cloud.starwhale.cn/projects/401/datasets/164/versions/226/files
tokens = [
self.tokenizer.get_command("[gMASK]"),
self.tokenizer.get_command("sop"),
]
loss_masks = [0, 0]

        for message in example["conversations"]:
            # belle roles: human and gpt
            # ChatGLM3 roles: user and assistant
            _role = "user" if message["from"] == "human" else "assistant"
            _message_tokens = self.tokenizer.build_single_message(
                _role, "", message["value"]
            )
            tokens.extend(_message_tokens)
            # only assistant turns contribute to the loss
            loss_masks.extend(
                [1 if _role == "assistant" else 0] * len(_message_tokens)
            )

        tokens.append(self.tokenizer.eos_token_id)
        # learn to emit eos at the end of the final assistant turn
        loss_masks.append(1)

        # build labels for next-token prediction; non-target positions get ignore_index
target_based_loss_mask = [False] + loss_masks[:-1]
labels = [
(t if m else self.ignore_index)
for t, m in zip(tokens, target_based_loss_mask)
]

tokens = tokens[: self.max_seq_len]
labels = labels[: self.max_seq_len]
tokens += [self.tokenizer.pad_token_id] * (self.max_seq_len - len(tokens))
labels += [self.ignore_index] * (self.max_seq_len - len(labels))

return {"input_ids": tokens, "labels": labels}
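
For reference, this hypothetical snippet illustrates the record shape `MultiTurnDataTransform` expects (the belle dataset's `human`/`gpt` roles are mapped to ChatGLM3's `user`/`assistant`); it assumes the ChatGLM3 tokenizer, which provides `get_command` and `build_single_message`, loaded from the path populated by `download.py`:

```python
from transformers import AutoTokenizer

# assumes the base weights/tokenizer live under pretrained_models/base
tokenizer = AutoTokenizer.from_pretrained(
    "pretrained_models/base", trust_remote_code=True
)
transform = MultiTurnDataTransform(tokenizer=tokenizer, max_seq_len=2048)

example = {
    "conversations": [
        {"from": "human", "value": "Hello!"},
        {"from": "gpt", "value": "Hi! How can I help you?"},
    ]
}
features = transform(example)
# both sequences are truncated/padded to max_seq_len
assert len(features["input_ids"]) == len(features["labels"]) == 2048
```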
5 changes: 5 additions & 0 deletions example/llm-finetune/models/chatglm3/model.yaml
@@ -0,0 +1,5 @@
name: chatglm3-6b
run:
modules:
- evaluation
- finetune
3 changes: 2 additions & 1 deletion example/llm-finetune/runtime/requirements.txt
@@ -11,4 +11,5 @@ bitsandbytes
colorama
tokenizers
sentencepiece
git+https://github.com/star-whale/starwhale.git@bebd503#subdirectory=client&egg=starwhale
astunparse
git+https://github.com/star-whale/starwhale.git@c57144a#subdirectory=client&egg=starwhale