add option to disable version check
hiyouga committed Feb 10, 2024
1 parent a754f6e commit 91d09a0
Showing 4 changed files with 47 additions and 26 deletions.
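
For context, a minimal usage sketch (not part of the commit) of the option this change introduces: with the new `disable_version_checking` field, the dependency checks are skipped when parsing arguments. Everything below other than `disable_version_checking` is an illustrative, assumed configuration (model path, dataset, template, LoRA target, output directory), not something defined by this diff.

# Hypothetical sketch: parse training arguments while skipping the
# dependency version check performed by _check_dependencies().
from llmtuner.hparams.parser import get_train_args

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args({
    "stage": "sft",                                    # illustrative values; any valid
    "do_train": True,                                  # training configuration works
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",
    "dataset": "alpaca_gpt4_en",
    "template": "default",
    "finetuning_type": "lora",
    "lora_target": "q_proj,v_proj",
    "output_dir": "saves/llama2-7b-lora",
    "disable_version_checking": True,                  # new flag added by this commit
})
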
26 changes: 15 additions & 11 deletions src/llmtuner/extras/packages.py
@@ -2,48 +2,52 @@
 import importlib.util


-def is_package_available(name: str) -> bool:
+def _is_package_available(name: str) -> bool:
     return importlib.util.find_spec(name) is not None


-def get_package_version(name: str) -> str:
+def _get_package_version(name: str) -> str:
     try:
         return importlib.metadata.version(name)
     except Exception:
         return "0.0.0"


 def is_fastapi_availble():
-    return is_package_available("fastapi")
+    return _is_package_available("fastapi")


 def is_flash_attn2_available():
-    return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
+    return _is_package_available("flash_attn") and _get_package_version("flash_attn").startswith("2")


 def is_jieba_available():
-    return is_package_available("jieba")
+    return _is_package_available("jieba")


 def is_matplotlib_available():
-    return is_package_available("matplotlib")
+    return _is_package_available("matplotlib")


 def is_nltk_available():
-    return is_package_available("nltk")
+    return _is_package_available("nltk")


 def is_requests_available():
-    return is_package_available("requests")
+    return _is_package_available("requests")


 def is_rouge_available():
-    return is_package_available("rouge_chinese")
+    return _is_package_available("rouge_chinese")


 def is_starlette_available():
-    return is_package_available("sse_starlette")
+    return _is_package_available("sse_starlette")


+def is_unsloth_available():
+    return _is_package_available("unsloth")
+
+
 def is_uvicorn_available():
-    return is_package_available("uvicorn")
+    return _is_package_available("uvicorn")

3 changes: 3 additions & 0 deletions src/llmtuner/hparams/finetuning_args.py
@@ -132,6 +132,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
     finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
         default="lora", metadata={"help": "Which fine-tuning method to use."}
     )
+    disable_version_checking: Optional[bool] = field(
+        default=False, metadata={"help": "Whether or not to disable version checking."}
+    )
     plot_loss: Optional[bool] = field(
         default=False, metadata={"help": "Whether or not to save the training loss curves."}
     )

35 changes: 29 additions & 6 deletions src/llmtuner/hparams/parser.py
@@ -8,8 +8,10 @@
 import transformers
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
 from transformers.trainer_utils import get_last_checkpoint
+from transformers.utils.versions import require_version

 from ..extras.logging import get_logger
+from ..extras.packages import is_unsloth_available
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -28,6 +30,14 @@
 _EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]


+def _check_dependencies():
+    require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
+    require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
+    require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
+    require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
+    require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
+
+
 def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
@@ -123,8 +133,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
         raise ValueError("Please specify `lora_target` in LoRA training.")

+    if training_args.do_train and model_args.use_unsloth and not is_unsloth_available:
+        raise ValueError("Install Unsloth: https://github.com/unslothai/unsloth")
+
     _verify_model_args(model_args, finetuning_args)

+    if not finetuning_args.disable_version_checking:
+        _check_dependencies()
+
     if (
         training_args.do_train
         and finetuning_args.finetuning_type == "lora"
@@ -145,7 +161,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
         logger.warning("Specify `ref_model` for computing rewards at evaluation.")

-    # postprocess training_args
+    # Post-process training arguments
     if (
         training_args.local_rank != -1
         and training_args.ddp_find_unused_parameters is None
@@ -158,7 +174,9 @@

     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
-        training_args.resume_from_checkpoint = None
+        if training_args.resume_from_checkpoint is not None:
+            logger.warning("Cannot resume from checkpoint in current stage.")
+            training_args.resume_from_checkpoint = None
     else:
         can_resume_from_checkpoint = True

@@ -194,7 +212,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         )
     )

-    # postprocess model_args
+    # Post-process model arguments
     model_args.compute_dtype = (
         torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     )
@@ -212,32 +230,37 @@
     )
     logger.info(f"Training/evaluation parameters {training_args}")

     # Set seed before initializing model.
     transformers.set_seed(training_args.seed)

     return model_args, data_args, training_args, finetuning_args, generating_args


 def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)

     _set_transformers_logging()
-    _verify_model_args(model_args, finetuning_args)

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

+    _verify_model_args(model_args, finetuning_args)
+    if not finetuning_args.disable_version_checking:
+        _check_dependencies()

     return model_args, data_args, finetuning_args, generating_args


 def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)

     _set_transformers_logging()
-    _verify_model_args(model_args, finetuning_args)

     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")

+    _verify_model_args(model_args, finetuning_args)
+    if not finetuning_args.disable_version_checking:
+        _check_dependencies()

     transformers.set_seed(eval_args.seed)

9 changes: 0 additions & 9 deletions src/llmtuner/model/loader.py
@@ -2,7 +2,6 @@

 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.integrations import is_deepspeed_zero3_enabled
-from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead

 from ..extras.logging import get_logger
@@ -21,13 +20,6 @@
 logger = get_logger(__name__)


-require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
-require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
-require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
-require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
-require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
-
-
 def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
@@ -63,7 +55,6 @@ def load_model_and_tokenizer(

     model = None
     if is_trainable and model_args.use_unsloth:
-        require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
         from unsloth import FastLlamaModel, FastMistralModel  # type: ignore

         unsloth_kwargs = {
