Skip to content

Commit

Permalink
added deprecation warnings to ml module
Browse files Browse the repository at this point in the history
  • Loading branch information
LEFTA98 committed Aug 2, 2022
1 parent 727c0db commit 35ccb3b
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 28 deletions.
6 changes: 6 additions & 0 deletions eland/ml/_model_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from abc import ABC
from typing import Any, Dict, List, Optional, Sequence

from warnings import warn


def add_if_exists(d: Dict[str, Any], k: str, v: Any) -> None:
if v is not None:
Expand All @@ -34,6 +36,7 @@ def __init__(
target_type: Optional[str] = None,
classification_labels: Optional[Sequence[str]] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self._target_type = target_type
self._feature_names = feature_names
self._classification_labels = classification_labels
Expand Down Expand Up @@ -72,6 +75,7 @@ def __init__(
leaf_value: Optional[List[float]] = None,
number_samples: Optional[int] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self._node_idx = node_idx
self._decision_type = decision_type
self._left_child = left_child
Expand Down Expand Up @@ -114,6 +118,7 @@ def __init__(
tree_structure: Optional[Sequence[TreeNode]] = None,
classification_labels: Optional[Sequence[str]] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(
feature_names=feature_names,
target_type=target_type,
Expand All @@ -139,6 +144,7 @@ def __init__(
classification_labels: Optional[Sequence[str]] = None,
classification_weights: Optional[Sequence[float]] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(
feature_names=feature_names,
target_type=target_type,
Expand Down
4 changes: 3 additions & 1 deletion eland/ml/ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from .common import TYPE_CLASSIFICATION, TYPE_REGRESSION
from .transformers import get_model_transformer

from warnings import warn

if TYPE_CHECKING:
from elasticsearch import Elasticsearch
from numpy.typing import ArrayLike, DTypeLike
Expand All @@ -52,7 +54,6 @@
except ImportError:
pass


class MLModel:
"""
A machine learning model managed by Elasticsearch.
Expand All @@ -79,6 +80,7 @@ def __init__(
model_id: str
The unique identifier of the trained inference model in Elasticsearch.
"""
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self._client: Elasticsearch = ensure_es_client(es_client)
self._model_id = model_id
self._trained_model_config_cache: Optional[Dict[str, Any]] = None
Expand Down
2 changes: 2 additions & 0 deletions eland/ml/pytorch/_pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from eland.common import ensure_es_client
from eland.ml.pytorch.nlp_ml_model import NlpTrainedModelConfig
from warnings import warn

if TYPE_CHECKING:
from elasticsearch import Elasticsearch
Expand All @@ -57,6 +58,7 @@ def __init__(
es_client: Union[str, List[str], Tuple[str, ...], "Elasticsearch"],
model_id: str,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self._client: Elasticsearch = ensure_es_client(es_client)
self.model_id = model_id

Expand Down
15 changes: 15 additions & 0 deletions eland/ml/pytorch/nlp_ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import typing as t
from warnings import warn


class NlpTokenizationConfig:
Expand All @@ -30,6 +31,7 @@ def __init__(
] = None,
span: t.Optional[int] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self.name = configuration_type
self.with_special_tokens = with_special_tokens
self.max_sequence_length = max_sequence_length
Expand All @@ -56,6 +58,7 @@ def __init__(
] = None,
span: t.Optional[int] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(
configuration_type="roberta",
with_special_tokens=with_special_tokens,
Expand All @@ -78,6 +81,7 @@ def __init__(
] = None,
span: t.Optional[int] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(
configuration_type="bert",
with_special_tokens=with_special_tokens,
Expand All @@ -100,6 +104,7 @@ def __init__(
] = None,
span: t.Optional[int] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(
configuration_type="mpnet",
with_special_tokens=with_special_tokens,
Expand All @@ -112,6 +117,7 @@ def __init__(

class InferenceConfig:
def __init__(self, *, configuration_type: str):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self.name = configuration_type

def to_dict(self) -> t.Dict[str, t.Any]:
Expand All @@ -133,6 +139,7 @@ def __init__(
results_field: t.Optional[str] = None,
num_top_classes: t.Optional[int] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(configuration_type="text_classification")
self.results_field = results_field
self.num_top_classes = num_top_classes
Expand All @@ -151,6 +158,7 @@ def __init__(
labels: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
hypothesis_template: t.Optional[str] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(configuration_type="zero_shot_classification")
self.tokenization = tokenization
self.hypothesis_template = hypothesis_template
Expand All @@ -168,6 +176,7 @@ def __init__(
results_field: t.Optional[str] = None,
num_top_classes: t.Optional[int] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(configuration_type="fill_mask")
self.num_top_classes = num_top_classes
self.tokenization = tokenization
Expand All @@ -182,6 +191,7 @@ def __init__(
classification_labels: t.Union[t.List[str], t.Tuple[str, ...]],
results_field: t.Optional[str] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(configuration_type="ner")
self.tokenization = tokenization
self.classification_labels = classification_labels
Expand All @@ -195,6 +205,7 @@ def __init__(
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(configuration_type="pass_through")
self.tokenization = tokenization
self.results_field = results_field
Expand All @@ -210,6 +221,7 @@ def __init__(
question: t.Optional[str] = None,
num_top_classes: t.Optional[int] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(configuration_type="question_answering")
self.tokenization = tokenization
self.results_field = results_field
Expand All @@ -225,13 +237,15 @@ def __init__(
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__(configuration_type="text_embedding")
self.tokenization = tokenization
self.results_field = results_field


class TrainedModelInput:
def __init__(self, *, field_names: t.List[str]):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self.field_names = field_names

def to_dict(self) -> t.Dict[str, t.Any]:
Expand All @@ -250,6 +264,7 @@ def __init__(
default_field_map: t.Optional[t.Mapping[str, str]] = None,
tags: t.Optional[t.Union[t.List[str], t.Tuple[str, ...]]] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self.tags = tags
self.default_field_map = default_field_map
self.description = description
Expand Down
2 changes: 2 additions & 0 deletions eland/ml/pytorch/traceable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import torch # type: ignore
from torch import nn
from warnings import warn

TracedModelTypes = Union[
torch.nn.Module,
Expand All @@ -37,6 +38,7 @@ def __init__(
self,
model: nn.Module,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self._model = model

def quantize(self) -> None:
Expand Down
26 changes: 26 additions & 0 deletions eland/ml/pytorch/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
ZeroShotClassificationInferenceOptions,
)
from eland.ml.pytorch.traceable_model import TraceableModel
from warnings import warn

DEFAULT_OUTPUT_KEY = "sentence_embedding"
SUPPORTED_TASK_TYPES = {
Expand Down Expand Up @@ -112,6 +113,7 @@ class TaskTypeError(Exception):


def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]:
warn('func is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
if model_config.architectures is None:
if model_config.name_or_path.startswith("sentence-transformers/"):
return "text_embedding"
Expand Down Expand Up @@ -150,6 +152,7 @@ class _QuestionAnsweringWrapperModule(nn.Module): # type: ignore
"""

def __init__(self, model: PreTrainedModel):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__()
self._hf_model = model
self.config = model.config
Expand All @@ -174,6 +177,7 @@ def from_pretrained(model_id: str) -> Optional[Any]:

class _QuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
def __init__(self, model: PreTrainedModel):
    """Emit a deprecation warning, then delegate to the parent initializer.

    Parameters
    ----------
    model: PreTrainedModel
        Forwarded unchanged to the parent initializer as ``model``.
    """
    # NOTE: the parent _QuestionAnsweringWrapperModule.__init__ in this file
    # issues the same warning, so one construction warns twice.
    warn(
        'class is deprecated, this currently only supports ElasticSearch client',
        category=DeprecationWarning,
        stacklevel=2,
    )
    super().__init__(model=model)

def forward(
Expand Down Expand Up @@ -204,6 +208,7 @@ def forward(

class _TwoParameterQuestionAnsweringWrapper(_QuestionAnsweringWrapperModule):
def __init__(self, model: PreTrainedModel):
    """Emit a deprecation warning, then delegate to the parent initializer.

    Parameters
    ----------
    model: PreTrainedModel
        Forwarded unchanged to the parent initializer as ``model``.
    """
    # NOTE: the parent _QuestionAnsweringWrapperModule.__init__ in this file
    # issues the same warning, so one construction warns twice.
    warn(
        'class is deprecated, this currently only supports ElasticSearch client',
        category=DeprecationWarning,
        stacklevel=2,
    )
    super().__init__(model=model)

def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
Expand All @@ -225,6 +230,7 @@ class _DistilBertWrapper(nn.Module): # type: ignore
"""

def __init__(self, model: transformers.PreTrainedModel):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__()
self._model = model
self.config = model.config
Expand Down Expand Up @@ -256,6 +262,7 @@ class _SentenceTransformerWrapperModule(nn.Module): # type: ignore
"""

def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__()
self._hf_model = model
self._st_model = SentenceTransformer(model.config.name_or_path)
Expand Down Expand Up @@ -310,6 +317,7 @@ def _replace_transformer_layer(self) -> None:

class _SentenceTransformerWrapper(_SentenceTransformerWrapperModule):
def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
    """Emit a deprecation warning, then delegate to the parent initializer.

    Parameters
    ----------
    model: PreTrainedModel
        Forwarded unchanged to the parent initializer as ``model``.
    output_key: str
        Forwarded unchanged as ``output_key``; defaults to
        ``DEFAULT_OUTPUT_KEY``.
    """
    # NOTE: the parent _SentenceTransformerWrapperModule.__init__ in this
    # file issues the same warning, so one construction warns twice.
    warn(
        'class is deprecated, this currently only supports ElasticSearch client',
        category=DeprecationWarning,
        stacklevel=2,
    )
    super().__init__(model=model, output_key=output_key)

def forward(
Expand Down Expand Up @@ -337,6 +345,7 @@ def forward(

class _TwoParameterSentenceTransformerWrapper(_SentenceTransformerWrapperModule):
def __init__(self, model: PreTrainedModel, output_key: str = DEFAULT_OUTPUT_KEY):
    """Emit a deprecation warning, then delegate to the parent initializer.

    Parameters
    ----------
    model: PreTrainedModel
        Forwarded unchanged to the parent initializer as ``model``.
    output_key: str
        Forwarded unchanged as ``output_key``; defaults to
        ``DEFAULT_OUTPUT_KEY``.
    """
    # NOTE: the parent _SentenceTransformerWrapperModule.__init__ in this
    # file issues the same warning, so one construction warns twice.
    warn(
        'class is deprecated, this currently only supports ElasticSearch client',
        category=DeprecationWarning,
        stacklevel=2,
    )
    super().__init__(model=model, output_key=output_key)

def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
Expand Down Expand Up @@ -365,6 +374,7 @@ def __init__(
self,
model: Union[transformers.DPRContextEncoder, transformers.DPRQuestionEncoder],
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super().__init__()
self._model = model
self.config = model.config
Expand Down Expand Up @@ -419,6 +429,7 @@ def __init__(
_DistilBertWrapper,
],
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
super(_TransformerTraceableModel, self).__init__(model=model)
self._tokenizer = tokenizer

Expand Down Expand Up @@ -462,6 +473,8 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:

class _TraceableClassificationModel(_TransformerTraceableModel, ABC):
def classification_labels(self) -> Optional[List[str]]:
warn('method is deprecated, this currently only supports ElasticSearch client', DeprecationWarning,
stacklevel=2)
id_label_items = self._model.config.id2label.items()
labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])] # type: ignore

Expand All @@ -471,6 +484,8 @@ def classification_labels(self) -> Optional[List[str]]:

class _TraceableFillMaskModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
warn('method is deprecated, this currently only supports ElasticSearch client', DeprecationWarning,
stacklevel=2)
return self._tokenizer(
"Who was Jim Henson?",
"[MASK] Henson was a puppeteer",
Expand All @@ -481,6 +496,8 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:

class _TraceableNerModel(_TraceableClassificationModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
warn('method is deprecated, this currently only supports ElasticSearch client', DeprecationWarning,
stacklevel=2)
return self._tokenizer(
(
"Hugging Face Inc. is a company based in New York City. "
Expand All @@ -493,6 +510,8 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:

class _TraceableTextClassificationModel(_TraceableClassificationModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
warn('method is deprecated, this currently only supports ElasticSearch client', DeprecationWarning,
stacklevel=2)
return self._tokenizer(
"This is an example sentence.",
padding="max_length",
Expand All @@ -502,6 +521,8 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:

class _TraceableTextEmbeddingModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
warn('method is deprecated, this currently only supports ElasticSearch client', DeprecationWarning,
stacklevel=2)
return self._tokenizer(
"This is an example sentence.",
padding="max_length",
Expand All @@ -511,6 +532,8 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:

class _TraceableZeroShotClassificationModel(_TraceableClassificationModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
warn('method is deprecated, this currently only supports ElasticSearch client', DeprecationWarning,
stacklevel=2)
return self._tokenizer(
"This is an example sentence.",
"This example is an example.",
Expand All @@ -521,6 +544,8 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:

class _TraceableQuestionAnsweringModel(_TransformerTraceableModel):
def _prepare_inputs(self) -> transformers.BatchEncoding:
warn('method is deprecated, this currently only supports ElasticSearch client', DeprecationWarning,
stacklevel=2)
return self._tokenizer(
"What is the meaning of life?"
"The meaning of life, according to the hitchikers guide, is 42.",
Expand All @@ -531,6 +556,7 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:

class TransformerModel:
def __init__(self, model_id: str, task_type: str, quantize: bool = False):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self._model_id = model_id
self._task_type = task_type.replace("-", "_")

Expand Down
2 changes: 2 additions & 0 deletions eland/ml/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Optional, Sequence

from .._model_serializer import ModelSerializer
from warnings import warn


class ModelTransformer:
Expand All @@ -28,6 +29,7 @@ def __init__(
classification_labels: Optional[Sequence[str]] = None,
classification_weights: Optional[Sequence[float]] = None,
):
warn('class is deprecated, this currently only supports ElasticSearch client', DeprecationWarning, stacklevel=2)
self._feature_names = feature_names
self._model = model
self._classification_labels = classification_labels
Expand Down
Loading

0 comments on commit 35ccb3b

Please sign in to comment.