Added Phi (#132)
tgaddair authored Dec 15, 2023
1 parent f10e0bb commit 549bbb8
Showing 3 changed files with 585 additions and 0 deletions.
16 changes: 16 additions & 0 deletions server/lorax_server/models/__init__.py
@@ -52,6 +52,7 @@
from lorax_server.models.flash_llama import FlashLlama
from lorax_server.models.flash_gpt2 import FlashGPT2
from lorax_server.models.flash_qwen import FlashQwen
from lorax_server.models.flash_phi import FlashPhi
from lorax_server.models.flash_santacoder import (
    FlashSantacoderSharded,
)
@@ -65,7 +66,9 @@
    __all__.append(FlashRWSharded)
    __all__.append(FlashSantacoderSharded)
    __all__.append(FlashLlama)
    __all__.append(FlashGPT2)
    __all__.append(FlashQwen)
    __all__.append(FlashPhi)

MISTRAL = True
try:
@@ -325,6 +328,19 @@ def get_model(
            trust_remote_code=trust_remote_code,
        )
        raise NotImplementedError("Qwen model requires flash attention v2")

    if model_type in ["phi-msft", "phi"]:
        if FLASH_ATTENTION:
            return FlashPhi(
                model_id,
                adapter_id,
                adapter_source,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        raise NotImplementedError("Phi model requires flash attention v2")

    if model_type == "opt":
        return OPTSharded(
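
For reference, the new branch dispatches on the model_type field of a checkpoint's config.json. Below is a minimal sketch (not part of this commit) of how that field can be inspected with transformers; the model id is illustrative only, and FlashPhi is returned only when flash attention v2 is available, as the hunk shows.

# Minimal sketch, assuming an illustrative Phi checkpoint such as "microsoft/phi-2".
from transformers import AutoConfig

config = AutoConfig.from_pretrained("microsoft/phi-2", trust_remote_code=True)
print(config.model_type)  # "phi" (older checkpoints report "phi-msft"); get_model
                          # routes either value to FlashPhi when flash attention v2
                          # is available, otherwise it raises NotImplementedError.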
