From 247825bd5727b31127bd0f4175df13e73ba01e00 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 7 May 2023 01:53:30 -0400 Subject: [PATCH] refactor inference, warn if model is frozen --- scripts/finetune.py | 16 +++++++++++++--- src/axolotl/utils/models.py | 6 ++++++ src/axolotl/utils/trainer.py | 2 +- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index a8cfe2a03e..67e23f35c2 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -6,9 +6,11 @@ import signal import sys from pathlib import Path +from typing import Optional import fire import torch +import transformers import yaml from attrdict import AttrDefault @@ -46,6 +48,15 @@ def get_device(): cfg.device_map = {"": cfg.device} +def get_multi_line_input() -> Optional[str]: + print("Give me an instruction (Ctrl + Z to finish): ") + instruction = "" + for line in sys.stdin: + instruction += line + # instruction = pathlib.Path("/proc/self/fd/0").read_text() + return instruction + + def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): tokenizer.add_special_tokens({"unk_token": "<unk>"}) tokenizer.add_special_tokens({"bos_token": "<s>"}) @@ -55,8 +66,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): while True: # support for multiline inputs - print("Give me an instruction (Ctrl + D to finish): ") - instruction = pathlib.Path("/proc/self/fd/0").read_text() + instruction = get_multi_line_input() if not instruction: return prompt = prompter_module().build_prompt(instruction=instruction) @@ -66,7 +76,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): with torch.no_grad(): # gc = GenerationConfig() # TODO swap out and use this generated = model.generate( - inputs=batch["input_ids"].to("cuda"), + inputs=batch["input_ids"].to(cfg.device), do_sample=True, use_cache=True, repetition_penalty=1.1, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 1fc47a87fb..60476d8970 100644 ---
a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -183,6 +183,12 @@ def load_model( model.is_parallelizable = True model.model_parallel = True + requires_grad = [] + for name, param in model.named_parameters(recurse=True): + if param.requires_grad: + requires_grad.append(f"{name}: {param.requires_grad}") + if len(requires_grad) == 0: + logging.warning("there are no parameters that require gradient updates") # TODO resume_from_checkpoint handling return model, tokenizer, lora_config diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 3c6aca1791..12fe93fe42 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -105,7 +105,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): run_name=cfg.wandb_run_id if cfg.use_wandb else None, optim=cfg.optimizer if cfg.optimizer else None, lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine", - weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0, + weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, **training_arguments_kwargs, )