Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge google/flan based adapters: T5Adapter, CodeT5pAdapter, FlanAdapter #2411

Merged
merged 9 commits
Sep 18, 2023
35 changes: 6 additions & 29 deletions fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
AutoTokenizer,
LlamaTokenizer,
LlamaForCausalLM,
T5Tokenizer,
)

from fastchat.constants import CPU_ISA
Expand Down Expand Up @@ -616,11 +615,13 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("vicuna_v1.1")


class CodeT5pAdapter(BaseModelAdapter):
"""The model adapter for Salesforce/codet5p-6b"""
class GoogleFlanAdapter(BaseModelAdapter):
"""The model adapter for google/Flan based models, such as Salesforce/codet5p-6b, lmsys/fastchat-t5-3b-v1.0, flan-t5-*, flan-ul2"""

def match(self, model_path: str):
    """Return True when *model_path* is a Flan/T5-family checkpoint.

    Covers flan-t5-*, flan-ul2, lmsys/fastchat-t5-* and
    Salesforce/codet5p-* — the models this merged adapter now serves.
    """
    # The substring test must be `model_str in model_path`, not the
    # reverse: the original asked whether the whole model path is
    # contained in a short keyword like "flan-", which is effectively
    # never true.  Also lowercase the path, matching the behavior of the
    # adapters this class replaces (e.g. `"codet5p" in model_path.lower()`).
    return any(
        model_str in model_path.lower()
        for model_str in ["flan-", "fastchat-t5", "codet5p"]
    )

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
revision = from_pretrained_kwargs.get("revision", "main")
Expand All @@ -634,28 +635,6 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer


class T5Adapter(BaseModelAdapter):
    """The model adapter for lmsys/fastchat-t5-3b-v1.0"""

    def match(self, model_path: str):
        # Any checkpoint whose path mentions "t5" (case-insensitive)
        # is routed through this adapter.
        path = model_path.lower()
        return "t5" in path

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        # Honor a caller-pinned revision; otherwise load from "main".
        revision = from_pretrained_kwargs.get("revision", "main")
        tokenizer = T5Tokenizer.from_pretrained(model_path, revision=revision)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **from_pretrained_kwargs,
        )
        return model, tokenizer


class FlanAdapter(T5Adapter):
    """The model adapter for flan-t5-*, flan-ul2"""

    def match(self, model_path: str):
        # Flan checkpoints reuse T5Adapter's loading logic; only the
        # path keyword used for routing differs.
        lowered = model_path.lower()
        return "flan" in lowered


class KoalaAdapter(BaseModelAdapter):
"""The model adapter for koala"""

Expand Down Expand Up @@ -1599,9 +1578,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(VicunaAdapter)
register_model_adapter(AiroborosAdapter)
register_model_adapter(LongChatAdapter)
register_model_adapter(CodeT5pAdapter)
register_model_adapter(T5Adapter)
register_model_adapter(FlanAdapter)
register_model_adapter(GoogleFlanAdapter)
register_model_adapter(KoalaAdapter)
register_model_adapter(AlpacaAdapter)
register_model_adapter(ChatGLMAdapter)
Expand Down