Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Phi #132

Merged
merged 12 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions server/lorax_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from lorax_server.models.flash_llama import FlashLlama
from lorax_server.models.flash_gpt2 import FlashGPT2
from lorax_server.models.flash_qwen import FlashQwen
from lorax_server.models.flash_phi import FlashPhi
from lorax_server.models.flash_santacoder import (
FlashSantacoderSharded,
)
Expand All @@ -65,7 +66,9 @@
__all__.append(FlashRWSharded)
__all__.append(FlashSantacoderSharded)
__all__.append(FlashLlama)
__all__.append(FlashGPT2)
__all__.append(FlashQwen)
__all__.append(FlashPhi)

MISTRAL = True
try:
Expand Down Expand Up @@ -325,6 +328,19 @@ def get_model(
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Qwen model requires flash attention v2")

if model_type == "phi-msft" or model_type == "phi":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you think about using model_type in [] instead ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if FLASH_ATTENTION:
return FlashPhi(
model_id,
adapter_id,
adapter_source,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Phi model requires flash attention v2")

if model_type == "opt":
return OPTSharded(
Expand Down
Loading
Loading