Skip to content

refactor the code and add warp_in_hpu_graph to corner case #625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 54 additions & 87 deletions backends/python/server/text_embeddings_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
__all__ = ["Model"]

TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [
"true",
"1",
]
# Disable gradients
torch.set_grad_enabled(False)

Expand All @@ -32,6 +36,29 @@
__all__.append(FlashBert)


def wrap_model_if_hpu(model_handle, device):
"""Wrap the model in HPU graph if the device is HPU."""
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model_handle.model = wrap_in_hpu_graph(
model_handle.model, disable_tensor_cache=DISABLE_TENSOR_CACHE
)
return model_handle


def create_model(model_class, model_path, device, datatype, pool="cls"):
"""Create a model instance and wrap it if needed."""
model_handle = model_class(
model_path,
device,
datatype,
pool,
trust_remote=TRUST_REMOTE_CODE,
)
return wrap_model_if_hpu(model_handle, device)


def get_model(model_path: Path, dtype: Optional[str], pool: str):
if dtype == "float32":
datatype = torch.float32
Expand All @@ -46,6 +73,7 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
logger.info(f"backend device: {device}")

config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

if (
hasattr(config, "auto_map")
and isinstance(config.auto_map, dict)
Expand All @@ -54,8 +82,9 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
== "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
):
# Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository
return FlashJinaBert(model_path, config, device, datatype, pool)
elif config.model_type == "bert":
return create_model(FlashJinaBert, model_path, device, datatype)

if config.model_type == "bert":
config: BertConfig
if (
use_ipex()
Expand All @@ -66,98 +95,36 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
):
if pool != "cls":
if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
return MaskedLanguageModel(
model_path,
device,
datatype,
trust_remote=TRUST_REMOTE_CODE,
return create_model(
MaskedLanguageModel, model_path, device, datatype, pool
)
return DefaultModel(
model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
)
return create_model(DefaultModel, model_path, device, datatype, pool)

try:
return FlashBert(model_path, device, datatype)
except FileNotFoundError as e:
return create_model(FlashBert, model_path, device, datatype)
except FileNotFoundError:
logger.info(
"Do not have safetensors file for this model, use default transformers model path instead"
)
return DefaultModel(
model_path, device, datatype, pool, trust_remote=TRUST_REMOTE_CODE
)
return create_model(DefaultModel, model_path, device, datatype, pool)

if config.architectures[0].endswith("Classification"):
return ClassificationModel(
model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
)
return create_model(ClassificationModel, model_path, device, datatype)
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
return MaskedLanguageModel(
model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
)
return create_model(MaskedLanguageModel, model_path, device, datatype)
else:
return DefaultModel(
model_path,
device,
datatype,
pool,
trust_remote=TRUST_REMOTE_CODE,
)
elif config.model_type == "mistral" and device.type == "hpu":
return create_model(DefaultModel, model_path, device, datatype, pool)

if config.model_type == "mistral" and device.type == "hpu":
try:
return FlashMistral(
model_path,
device,
datatype,
pool,
)
except FileNotFoundError as e:
return DefaultModel(
model_path,
device,
datatype,
pool,
trust_remote=TRUST_REMOTE_CODE,
)
return create_model(FlashMistral, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)

# Default case
if config.architectures[0].endswith("Classification"):
return create_model(ClassificationModel, model_path, device, datatype)
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
return create_model(MaskedLanguageModel, model_path, device, datatype)
else:
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

if config.architectures[0].endswith("Classification"):
model_handle = ClassificationModel(
model_path,
device,
datatype,
trust_remote=TRUST_REMOTE_CODE,
)
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
model_handle = MaskedLanguageModel(
model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
)
else:
model_handle = DefaultModel(
model_path,
device,
datatype,
pool,
trust_remote=TRUST_REMOTE_CODE,
)
model_handle.model = wrap_in_hpu_graph(model_handle.model)
return model_handle
elif use_ipex():
if config.architectures[0].endswith("Classification"):
return ClassificationModel(
model_path,
device,
datatype,
trust_remote=TRUST_REMOTE_CODE,
)
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
return MaskedLanguageModel(
model_path, device, datatype, trust_remote=TRUST_REMOTE_CODE
)
else:
return DefaultModel(
model_path,
device,
datatype,
pool,
trust_remote=TRUST_REMOTE_CODE,
)
return create_model(DefaultModel, model_path, device, datatype, pool)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
model_path: Path,
device: torch.device,
dtype: torch.dtype,
pool: str = "cls",
trust_remote: bool = False,
):
model = AutoModelForSequenceClassification.from_pretrained(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
model_path: Path,
device: torch.device,
dtype: torch.dtype,
pool: str,
pool: str = "cls",
trust_remote: bool = False,
):
model = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,14 @@ def forward(


class FlashBert(Model):
def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
def __init__(
self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
pool: str = "cls",
trust_remote: bool = False,
):
config = BertConfig.from_pretrained(model_path)

if hasattr(config, "max_seq_length"):
Expand All @@ -306,10 +313,6 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
model = FlashBertModel(f, device, dtype, config)
self.device = device
self.dtype = dtype
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
self.hidden_size = config.hidden_size

super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,12 @@ def forward(

class FlashMistral(Model):
def __init__(
self, model_path: Path, device: torch.device, dtype: torch.dtype, pool: str
self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
pool: str = "cls",
trust_remote: bool = False,
):
config = MistralConfig.from_pretrained(model_path)

Expand All @@ -379,10 +384,6 @@ def __init__(
model = FlashMistralModel(model_path, index_data, device, dtype, config)
self.device = device
self.dtype = dtype
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
self.hidden_size = config.hidden_size

super(FlashMistral, self).__init__(model=model, dtype=dtype, device=device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,12 @@ class FlashJinaBert(Model):
def __init__(
self,
model_path: Path,
config: AutoConfig,
device: torch.device,
dtype: torch.dtype,
pool: str,
pool: str = "mean",
trust_remote: bool = True,
):
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote)
if hasattr(config, "max_seq_length"):
self.max_input_length = config.max_seq_length
else:
Expand All @@ -494,10 +495,6 @@ def __init__(
self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
self.device = device
self.dtype = dtype
if device.type == "hpu":
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
self.hidden_size = config.hidden_size

super(FlashJinaBert, self).__init__(model=model, dtype=dtype, device=device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
model_path: Path,
device: torch.device,
dtype: torch.dtype,
pool: str = "cls",
trust_remote: bool = False,
):
model = (
Expand Down