Skip to content

Commit

Permalink
Merge pull request #3 from retrieva/add-peft
Browse files Browse the repository at this point in the history
Add peft for lora
  • Loading branch information
Katsumata420 authored Feb 16, 2024
2 parents 97c1b7a + bb50304 commit 59b1fa6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
5 changes: 5 additions & 0 deletions sentence_transformers/models/Transformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from torch import nn
from transformers import AutoModel, AutoTokenizer, AutoConfig, T5Config, MT5Config
from peft import PeftConfig, get_peft_model
import json
from typing import List, Dict, Optional, Union, Tuple
import os
Expand Down Expand Up @@ -27,6 +28,7 @@ def __init__(
tokenizer_args: Dict = {},
do_lower_case: bool = False,
tokenizer_name_or_path: str = None,
peft_config: Optional[PeftConfig] = None,
):
super(Transformer, self).__init__()
self.config_keys = ["max_seq_length", "do_lower_case"]
Expand All @@ -35,6 +37,9 @@ def __init__(
config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir)
self._load_model(model_name_or_path, config, cache_dir, **model_args)

if peft_config is not None:
self.auto_model = get_peft_model(self.auto_model, peft_config)

self.tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
cache_dir=cache_dir,
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"sentencepiece",
"huggingface-hub>=0.15.1",
"Pillow",
"peft",
],
classifiers=[
"Development Status :: 5 - Production/Stable",
Expand Down

0 comments on commit 59b1fa6

Please sign in to comment.