Skip to content

Commit

Permalink
merge google/flan based adapters: T5Adapter, CodeT5pAdapter, FlanAdapter (#2411)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhen263 authored Sep 18, 2023
1 parent 24acac1 commit 30a6ffc
Showing 1 changed file with 8 additions and 32 deletions.
40 changes: 8 additions & 32 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,14 @@
AutoTokenizer,
LlamaTokenizer,
LlamaForCausalLM,
T5Tokenizer,
)

from fastchat.constants import CPU_ISA
from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
from fastchat.modules.awq import AWQConfig, load_awq_quantized
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.compression import load_compress_model
from fastchat.model.llama_condense_monkey_patch import (
replace_llama_with_condense,
)
from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
from fastchat.model.model_chatglm import generate_stream_chatglm
from fastchat.model.model_codet5p import generate_stream_codet5p
from fastchat.model.model_falcon import generate_stream_falcon
Expand Down Expand Up @@ -635,11 +632,14 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("vicuna_v1.1")


class CodeT5pAdapter(BaseModelAdapter):
"""The model adapter for Salesforce/codet5p-6b"""
class GoogleFlanAdapter(BaseModelAdapter):
"""The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2"""

def match(self, model_path: str):
return "codet5p" in model_path.lower()
return any(
model_str in model_path.lower()
for model_str in ["flan-", "fastchat-t5", "codet5p"]
)

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
Expand All @@ -653,28 +653,6 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer


class T5Adapter(BaseModelAdapter):
    """The model adapter for lmsys/fastchat-t5-3b-v1.0

    Selected for any model path whose lowercased form contains ``"t5"``.
    """

    def match(self, model_path: str):
        """Return True when the lowercased ``model_path`` contains ``"t5"``."""
        lowered_path = model_path.lower()
        return "t5" in lowered_path

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        """Load the tokenizer and the seq2seq model for ``model_path``.

        The tokenizer is loaded at the ``"revision"`` entry of
        ``from_pretrained_kwargs`` (``"main"`` when the key is absent); the
        model receives every entry of ``from_pretrained_kwargs`` unchanged.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        # Only the tokenizer needs the revision passed explicitly; the model
        # call picks it up (if present) from the forwarded kwargs.
        tok_revision = from_pretrained_kwargs.get("revision", "main")
        t5_tokenizer = T5Tokenizer.from_pretrained(model_path, revision=tok_revision)
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **from_pretrained_kwargs,
        )
        return seq2seq_model, t5_tokenizer


class FlanAdapter(T5Adapter):
    """The model adapter for flan-t5-*, flan-ul2

    Only ``match`` is overridden here; everything else comes from
    ``T5Adapter``.
    """

    def match(self, model_path: str):
        """Return True when the lowercased ``model_path`` contains ``"flan"``."""
        lowered_path = model_path.lower()
        return "flan" in lowered_path


class KoalaAdapter(BaseModelAdapter):
"""The model adapter for Koala"""

Expand Down Expand Up @@ -1636,9 +1614,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(VicunaAdapter)
register_model_adapter(AiroborosAdapter)
register_model_adapter(LongChatAdapter)
register_model_adapter(CodeT5pAdapter)
register_model_adapter(T5Adapter)
register_model_adapter(FlanAdapter)
register_model_adapter(GoogleFlanAdapter)
register_model_adapter(KoalaAdapter)
register_model_adapter(AlpacaAdapter)
register_model_adapter(ChatGLMAdapter)
Expand Down

0 comments on commit 30a6ffc

Please sign in to comment.