example: make chatglm6b model small (#2271)
anda-ren authored May 24, 2023
1 parent 81f2887 commit 9b74ec9
Showing 2 changed files with 16 additions and 3 deletions.
example/LLM/chatglm6b/download_model.py: 9 additions & 2 deletions
@@ -1,8 +1,15 @@
+from pathlib import Path
+
 from transformers import AutoModel, AutoTokenizer
 
+ROOTDIR = Path(__file__).parent
+print(str(ROOTDIR / "models"))
 tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
 model = (
     AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
 )
-tokenizer.save_pretrained("./models")
-model.save_pretrained("./models")
+print(str(ROOTDIR / "models"))
+tokenizer.save_pretrained(str(ROOTDIR / "models"))
+model.save_pretrained(str(ROOTDIR / "models"))
+del model
+del tokenizer
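
The script now resolves the save path relative to its own location (ROOTDIR) instead of the working directory, and drops the model and tokenizer references once the weights are on disk. A minimal sketch of reading the saved artifacts back, assuming the same directory layout; the load mirrors what evaluator.py does, and trust_remote_code=True is needed because ChatGLM-6B ships custom modeling code with its checkpoint:

from pathlib import Path

from transformers import AutoModel, AutoTokenizer

# Resolve the models directory relative to this file, matching download_model.py.
MODELS_DIR = Path(__file__).parent / "models"

# Load the locally saved artifacts instead of hitting the Hugging Face hub again.
tokenizer = AutoTokenizer.from_pretrained(str(MODELS_DIR), trust_remote_code=True)
model = AutoModel.from_pretrained(str(MODELS_DIR), trust_remote_code=True)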
example/LLM/chatglm6b/evaluator.py: 7 additions & 1 deletion
@@ -27,6 +27,7 @@
 @evaluation.predict(
     log_mode="plain",
     log_dataset_features=["query", "text", "question", "rawquestion", "prompt"],
+    replicas=1,
 )
 def ppl(data: dict, external: dict):
     ds_name = external["dataset_uri"].name
@@ -44,6 +45,8 @@ def ppl(data: dict, external: dict):
         text = data["query"]
     else:
         raise ValueError(f"dataset {ds_name} does not fit this model")
+    if not os.path.exists(ROOTDIR / "models"):
+        import download_model  # noqa: F401
     global tokenizer
     if tokenizer is None:
         tokenizer = AutoTokenizer.from_pretrained(
@@ -54,6 +57,7 @@ def ppl(data: dict, external: dict):
         chatglm = AutoModel.from_pretrained(
             str(ROOTDIR / "models"), trust_remote_code=True
         )
+
         if os.path.exists(ROOTDIR / "models" / "chatglm-6b-lora.pt"):
             chatglm = load_lora_config(chatglm)
             chatglm.load_state_dict(
@@ -267,6 +271,8 @@ def save_tuned_parameters(model, path):
 def fine_tune(
     context: Context,
 ) -> None:
+    if not os.path.exists(ROOTDIR / "models"):
+        import download_model  # noqa: F401
     tokenizer = AutoTokenizer.from_pretrained(
         str(ROOTDIR / "models"), trust_remote_code=True
     )
@@ -278,7 +284,7 @@ def fine_tune(
     )
     sw_dataset = dataset(context.dataset_uris[0], readonly=True, create="forbid")
     sw_dataset = sw_dataset.with_loader_config(
-        field_transformer=ds_key_selectors.get(sw_dataset.name, None)
+        field_transformer=ds_key_selectors.get(sw_dataset._uri.name, None)
     )
     train_dataset = QADataset(
         sw_dataset,
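
Both new guards rely on the same trick: Python runs a module's top-level statements on first import, so "import download_model" fetches and saves the weights as a side effect whenever the models directory is missing, and "# noqa: F401" silences the unused-import warning. A minimal sketch of the pattern with hypothetical names (fetch_assets is not part of this repository):

# fetch_assets.py (hypothetical): top-level code runs on first import.
from pathlib import Path

ASSETS_DIR = Path(__file__).parent / "assets"
ASSETS_DIR.mkdir(parents=True, exist_ok=True)
(ASSETS_DIR / "weights.bin").write_bytes(b"\x00")  # stand-in for a real download

# main.py (hypothetical): import the module only when its output is missing.
import os
from pathlib import Path

ROOTDIR = Path(__file__).parent
if not os.path.exists(ROOTDIR / "assets"):
    import fetch_assets  # noqa: F401  (imported purely for the download side effect)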
