merge google/flan based adapters: T5Adapter, CodeT5pAdapter, FlanAdapter #2411

Merged
9 commits merged on Sep 18, 2023
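Summary of the change: the three separate adapters (CodeT5pAdapter, T5Adapter, FlanAdapter) are collapsed into a single GoogleFlanAdapter whose match() rule covers all three model families. Below is a minimal, standalone sketch of the new dispatch check as it appears in the diff; the helper name matches_google_flan and the constant FLAN_FAMILY_KEYS are illustrative only, not part of the codebase.

# Path substrings that now all route to the merged GoogleFlanAdapter,
# taken from the new match() in the diff below.
FLAN_FAMILY_KEYS = ["flan-", "fastchat-t5", "codet5p"]

def matches_google_flan(model_path: str) -> bool:
    # Same check as GoogleFlanAdapter.match(): any key appearing
    # case-insensitively in the model path selects the adapter.
    return any(key in model_path.lower() for key in FLAN_FAMILY_KEYS)

# Examples of paths the merged adapter now handles:
assert matches_google_flan("google/flan-ul2")
assert matches_google_flan("lmsys/fastchat-t5-3b-v1.0")
assert matches_google_flan("Salesforce/codet5p-6b")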
40 changes: 8 additions & 32 deletions fastchat/model/model_adapter.py
@@ -23,17 +23,14 @@
    AutoTokenizer,
    LlamaTokenizer,
    LlamaForCausalLM,
-    T5Tokenizer,
)

from fastchat.constants import CPU_ISA
from fastchat.modules.gptq import GptqConfig, load_gptq_quantized
from fastchat.modules.awq import AWQConfig, load_awq_quantized
from fastchat.conversation import Conversation, get_conv_template
from fastchat.model.compression import load_compress_model
-from fastchat.model.llama_condense_monkey_patch import (
-    replace_llama_with_condense,
-)
+from fastchat.model.llama_condense_monkey_patch import replace_llama_with_condense
from fastchat.model.model_chatglm import generate_stream_chatglm
from fastchat.model.model_codet5p import generate_stream_codet5p
from fastchat.model.model_falcon import generate_stream_falcon
@@ -616,11 +613,14 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("vicuna_v1.1")


-class CodeT5pAdapter(BaseModelAdapter):
-    """The model adapter for Salesforce/codet5p-6b"""
+class GoogleFlanAdapter(BaseModelAdapter):
+    """The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2"""

    def match(self, model_path: str):
-        return "codet5p" in model_path.lower()
+        return any(
+            model_str in model_path.lower()
+            for model_str in ["flan-", "fastchat-t5", "codet5p"]
+        )

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        revision = from_pretrained_kwargs.get("revision", "main")
@@ -634,28 +634,6 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        return model, tokenizer


-class T5Adapter(BaseModelAdapter):
-    """The model adapter for lmsys/fastchat-t5-3b-v1.0"""
-
-    def match(self, model_path: str):
-        return "t5" in model_path.lower()
-
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        revision = from_pretrained_kwargs.get("revision", "main")
-        tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision)
-        model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
-        )
-        return model, tokenizer
-
-
-class FlanAdapter(T5Adapter):
-    """The model adapter for flan-t5-*, flan-ul2"""
-
-    def match(self, model_path: str):
-        return "flan" in model_path.lower()
-
-
class KoalaAdapter(BaseModelAdapter):
    """The model adapter for koala"""

@@ -1599,9 +1577,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(VicunaAdapter)
register_model_adapter(AiroborosAdapter)
register_model_adapter(LongChatAdapter)
-register_model_adapter(CodeT5pAdapter)
-register_model_adapter(T5Adapter)
-register_model_adapter(FlanAdapter)
+register_model_adapter(GoogleFlanAdapter)
register_model_adapter(KoalaAdapter)
register_model_adapter(AlpacaAdapter)
register_model_adapter(ChatGLMAdapter)
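For reference, a hedged usage sketch of how the single registered adapter is picked up at runtime. It assumes the module's existing get_model_adapter() lookup helper and that no earlier-registered adapter matches these paths; the printed class name reflects the merged adapter introduced in this PR.

from fastchat.model.model_adapter import get_model_adapter

# After this change, all three model families resolve to one merged adapter.
for path in [
    "google/flan-t5-xl",
    "lmsys/fastchat-t5-3b-v1.0",
    "Salesforce/codet5p-6b",
]:
    adapter = get_model_adapter(path)
    print(path, "->", type(adapter).__name__)  # expected: GoogleFlanAdapter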