example(LLM): fine-tune bigscience/mt0 (#2314)
anda-ren authored Jun 9, 2023
1 parent b7c5abc commit 6cc26ec
Showing 2 changed files with 161 additions and 6 deletions.
11 changes: 7 additions & 4 deletions example/LLM/bloom/download_model.py
@@ -1,11 +1,11 @@
+import sys
 from pathlib import Path
 
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
-def download():
+def download(checkpoint: str = "bigscience/mt0-xxl"):
     ROOTDIR = Path(__file__).parent
-    checkpoint = "bigscience/mt0-xxl"
 
     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
     model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -18,5 +18,8 @@ def download():
     del tokenizer
 
 
-if __name__ == "main":
-    download()
+if __name__ == "__main__":
+    if len(sys.argv) == 2:
+        download(sys.argv[1])
+    else:
+        download()
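For a quick sanity check of the new optional-checkpoint signature, a minimal sketch (assuming it is run from example/LLM/bloom; the smaller bigscience/mt0-small checkpoint is used here only to keep the download light, while the default stays bigscience/mt0-xxl):

# hypothetical usage of the updated download() entry point
from download_model import download

# same effect as: python download_model.py bigscience/mt0-small
download("bigscience/mt0-small")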
156 changes: 154 additions & 2 deletions example/LLM/bloom/mt0xxl.py
@@ -1,10 +1,20 @@
 # pip install -q transformers accelerate starwhale
 import os
+from typing import Any
 from pathlib import Path
 
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import numpy as np
+import torch
+from datasets import Dataset
+from transformers import (
+    AdamW,
+    AutoTokenizer,
+    AutoModelForSeq2SeqLM,
+    get_linear_schedule_with_warmup,
+)
 
-from starwhale import evaluation
+from starwhale import Context, dataset, fine_tune, evaluation, pass_context
+from starwhale.api import model as swmp
 
 ROOTDIR = Path(__file__).parent
 
@@ -23,6 +33,7 @@ def ppl(data: dict, external: dict):
         from download_model import download
 
         download()
+
     global tokenizer
     if tokenizer is None:
         tokenizer = AutoTokenizer.from_pretrained(checkpoint)
@@ -49,3 +60,144 @@ def ppl(data: dict, external: dict):
     inputs = tokenizer.encode(text, return_tensors="pt").to("cuda")
     outputs = model.generate(inputs)
     return tokenizer.decode(outputs[0])
+
+
+# Map per-dataset field names onto the common "instruction"/"output" keys used below.
+ds_key_selectors = {
+    "webqsp": {
+        "rawquestion": "instruction",
+        "parses[0].Answers[0].EntityName": "output",
+    },
+    "grailqav1": {"question": "instruction", "answer[0].entity_name": "output"},
+    "graph_questions_testing": {"question": "instruction", "answer[0]": "output"},
+    "z_bench_common": {"prompt": "instruction", "gpt4": "output"},
+    "mkqa": {"query": "instruction", "answers.en[0].text": "output"},
+}
+
+
+@pass_context
+@fine_tune()
+def ft(context: Context) -> None:
+    ft_inner(context=context)
+
+
+def ft_inner(
+    context: Context = None,
+    swds: str = "mkqa/version/latest",
+) -> None:
+    checkpoint = str(ROOTDIR / "models")
+    if not os.path.exists(checkpoint):
+        from download_model import download
+
+        download()
+
+    tokeniser = AutoTokenizer.from_pretrained(
+        str(ROOTDIR / "models"), trust_remote_code=True
+    )
+    model = AutoModelForSeq2SeqLM.from_pretrained(
+        checkpoint, torch_dtype="auto", device_map="auto"
+    )
+    max_length = 100
+    swds_name = context.dataset_uris[0] if context else swds
+    sw_dataset = dataset(swds_name, readonly=True, create="forbid")
+    sw_dataset = sw_dataset.with_loader_config(
+        field_transformer=ds_key_selectors.get(sw_dataset._uri.name, None)
+    )
+    hgds = swds2hgds(sw_dataset)
+    hgds = (
+        hgds.shuffle()
+        .map(
+            lambda elem: {
+                "input_ids": tokeniser.encode(
+                    elem.get("instruction", "") or "",
+                    padding="max_length",
+                    truncation=True,
+                    max_length=max_length,
+                ),
+                "labels": tokeniser.encode(
+                    elem.get("output", "") or "",
+                    padding="max_length",
+                    truncation=True,
+                    max_length=max_length,
+                ),
+                # "label": elem["output"],
+            }
+        )
+        .train_test_split(test_size=0.1)
+    )
+    batch_size = int(os.getenv("MT0_TRAIN_BATCH_SIZE") or 16)
+
+    hgds = hgds["train"]
+
+    def ds_gen():
+        # yield (input_ids, labels) batches as CUDA tensors
+        current_item = 0
+        while True:
+            start = current_item
+            current_item += batch_size
+            if current_item >= len(hgds):
+                break
+            datas = hgds[start:current_item]
+            yield torch.tensor(datas["input_ids"]).cuda(), torch.tensor(
+                datas["labels"]
+            ).cuda()
+
+    n_epochs = int(os.getenv("MT0_TRAIN_EPOCHS") or 8)
print(f"epochs is {n_epochs}")
# batch_size = 16
print_freq = 50
lr = 5e-4
n_batches = int(np.ceil(len(hgds) / batch_size))
total_steps = n_epochs * n_batches
n_warmup_steps = int(total_steps * 0.01)
# Optimizer
optimizer = AdamW(model.parameters(), lr=lr)
scheduler = get_linear_schedule_with_warmup(optimizer, n_warmup_steps, total_steps)
losses = []

+    for epoch_idx in range(n_epochs):
+        # iterate over batches; the dataset was already shuffled once above
+
+        for batch_idx, (input_batch, label_batch) in enumerate(ds_gen()):
+            optimizer.zero_grad()
+
+            # Forward pass
+            model_out = model.forward(input_ids=input_batch, labels=label_batch)
+
+            # Calculate loss and update weights
+            loss = model_out.loss
+            losses.append(loss.item())
+            loss.backward()
+            optimizer.step()
+            scheduler.step()
+
+            # Print training update info
+            if (batch_idx + 1) % print_freq == 0:
+                avg_loss = np.mean(losses[-print_freq:])
+                print(
+                    "Epoch: {} | Step: {} | Avg. loss: {:.3f} | lr: {}".format(
+                        epoch_idx + 1,
+                        batch_idx + 1,
+                        avg_loss,
+                        scheduler.get_last_lr()[0],
+                    )
+                )
+
+    torch.save(model.state_dict(), str(ROOTDIR / "models" / "pytorch_model.bin"))
+    swmp.build(
+        workdir=ROOTDIR,
+        name="mt0",
+        modules=[ft, ppl],
+    )
+
+
+def swds2hgds(swds) -> Any:
+    """Wrap a Starwhale dataset as a Hugging Face Dataset via a generator."""
+    sw_ds = swds
+
+    def my_gen():
+        for item in sw_ds:
+            yield item.features
+
+    return Dataset.from_generator(my_gen)
+
+
+if __name__ == "__main__":
+    ft_inner()
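For local experimentation, a minimal sketch of driving ft_inner directly, assuming it is run from the example directory, that a Starwhale dataset such as mkqa has already been built, and that the two MT0_TRAIN_* environment variables are read exactly as above (the values here are illustrative only):

import os

os.environ["MT0_TRAIN_EPOCHS"] = "2"       # picked up via os.getenv inside ft_inner
os.environ["MT0_TRAIN_BATCH_SIZE"] = "8"   # smaller batches for a single GPU

from mt0xxl import ft_inner

ft_inner(swds="mkqa/version/latest")       # the dataset URI used when no Context is passed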
