
refactor inference, warn if model is frozen
winglian committed May 7, 2023
1 parent cb9a887 commit 247825b
Showing 3 changed files with 20 additions and 4 deletions.
scripts/finetune.py (13 additions & 3 deletions)
@@ -6,9 +6,11 @@
 import signal
 import sys
 from pathlib import Path
+from typing import Optional
 
 import fire
 import torch
+import transformers
 import yaml
 from attrdict import AttrDefault
 
@@ -46,6 +48,15 @@ def get_device():
     cfg.device_map = {"": cfg.device}
 
 
+def get_multi_line_input() -> Optional[str]:
+    print("Give me an instruction (Ctrl + Z to finish): ")
+    instruction = ""
+    for line in sys.stdin:
+        instruction += line
+    # instruction = pathlib.Path("/proc/self/fd/0").read_text()
+    return instruction
+
+
 def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     tokenizer.add_special_tokens({"unk_token": "<unk>"})
     tokenizer.add_special_tokens({"bos_token": "<s>"})
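The new helper reads standard input line by line until EOF, which is Ctrl+Z then Enter on Windows and Ctrl+D on Unix. A minimal standalone sketch of the same pattern (read_until_eof is an illustrative name, not part of this commit):

    import sys

    def read_until_eof() -> str:
        # sys.stdin iterates line by line and stops at EOF
        print("Enter text (EOF to finish):")
        buffer = ""
        for line in sys.stdin:
            buffer += line
        return buffer

    if __name__ == "__main__":
        print(f"read {len(read_until_eof())} characters")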
@@ -55,8 +66,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
 
     while True:
         # support for multiline inputs
-        print("Give me an instruction (Ctrl + D to finish): ")
-        instruction = pathlib.Path("/proc/self/fd/0").read_text()
+        instruction = get_multi_line_input()
         if not instruction:
             return
         prompt = prompter_module().build_prompt(instruction=instruction)
@@ -66,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         with torch.no_grad():
             # gc = GenerationConfig()  # TODO swap out and use this
             generated = model.generate(
-                inputs=batch["input_ids"].to("cuda"),
+                inputs=batch["input_ids"].to(cfg.device),
                 do_sample=True,
                 use_cache=True,
                 repetition_penalty=1.1,
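Replacing the hard-coded .to("cuda") with .to(cfg.device) lets inference follow whatever device the config resolved. A rough sketch of the idea, with the device resolution inlined as an assumption about what the get_device() helper computes:

    import torch

    # pick a device the way a get_device() helper plausibly would
    device = "cuda" if torch.cuda.is_available() else "cpu"

    input_ids = torch.tensor([[1, 2, 3]])
    input_ids = input_ids.to(device)  # no hard CUDA dependency; also runs on CPU-only hosts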
src/axolotl/utils/models.py (6 additions & 0 deletions)
@@ -183,6 +183,12 @@ def load_model(
         model.is_parallelizable = True
         model.model_parallel = True
 
+    requires_grad = []
+    for name, param in model.named_parameters(recurse=True):
+        if param.requires_grad:
+            requires_grad.append(f"{name}: {param.requires_grad}")
+    if len(requires_grad) == 0:
+        logging.warning("there are no parameters that require gradient updates")
 
     # TODO resume_from_checkpoint handling
     return model, tokenizer, lora_config
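The warning triggers when no parameter has requires_grad set, for example a quantized base model loaded without trainable LoRA adapters. A self-contained sketch of the same check, using a toy nn.Linear in place of the loaded model:

    import logging
    import torch.nn as nn

    model = nn.Linear(4, 2)
    for param in model.parameters():
        param.requires_grad = False  # simulate a fully frozen model

    requires_grad = [name for name, param in model.named_parameters() if param.requires_grad]
    if len(requires_grad) == 0:
        logging.warning("there are no parameters that require gradient updates")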
src/axolotl/utils/trainer.py (1 addition & 1 deletion)
@@ -105,7 +105,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         run_name=cfg.wandb_run_id if cfg.use_wandb else None,
         optim=cfg.optimizer if cfg.optimizer else None,
         lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
-        weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
+        weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
         **training_arguments_kwargs,
     )
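The is not None guard matters because 0.0 is falsy: a bare truthiness test cannot distinguish an explicit weight_decay of 0.0 from an unset value. A small illustration, using an invented nonzero fallback of 0.1 so the difference is visible:

    configured = 0.0  # user explicitly set weight_decay to zero

    truthy = configured if configured else 0.1                 # 0.0 silently becomes 0.1
    explicit = configured if configured is not None else 0.1   # explicit 0.0 preserved

    print(truthy, explicit)  # 0.1 0.0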
