Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 28, 2024
1 parent 1292734 commit 592fa66
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
15 changes: 7 additions & 8 deletions src/lightning/app/utilities/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from packaging import version


ADAPTER_CONFIG_NAME = "adapter_config.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
Expand All @@ -22,9 +21,8 @@ def find_adapter_config_file(
subfolder: str = "",
_commit_hash: Optional[str] = None,
) -> Optional[str]:
r"""
Simply checks if the model stored on the Hub or locally is an adapter model or not, return the path of the adapter
config file if it is, None otherwise.
r"""Simply checks if the model stored on the Hub or locally is an adapter model or not, return the path of the
adapter config file if it is, None otherwise.
Args:
model_id (`str`):
Expand Down Expand Up @@ -59,29 +57,30 @@ def find_adapter_config_file(
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
"""
adapter_cached_filename = None
if model_id is None:
return None
elif os.path.isdir(model_id):
if os.path.isdir(model_id):
list_remote_files = os.listdir(model_id)
if ADAPTER_CONFIG_NAME in list_remote_files:
adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME)
return adapter_cached_filename


def check_peft_version(min_version: str) -> None:
r"""
Checks if the version of PEFT is compatible.
r"""Checks if the version of PEFT is compatible.
Args:
version (`str`):
The version of PEFT to check against.
"""
is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version)

if not is_peft_version_compatible:
raise ValueError(
f"The version of PEFT you are using is not compatible, please use a version that is greater"
f" than {min_version}"
)
)
2 changes: 1 addition & 1 deletion src/lightning/pytorch/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _load_from_checkpoint(
**adapter_kwargs,
)
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
with open(_adapter_model_path, "r", encoding="utf-8") as f:
with open(_adapter_model_path, encoding="utf-8") as f:
_adapter_model_path = pretrained_model_name_or_path
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
Expand Down

0 comments on commit 592fa66

Please sign in to comment.