diff --git a/openbb_chat/classifiers/abstract_zeroshot_classifier.py b/openbb_chat/classifiers/abstract_zeroshot_classifier.py new file mode 100644 index 0000000..8492bfe --- /dev/null +++ b/openbb_chat/classifiers/abstract_zeroshot_classifier.py @@ -0,0 +1,99 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple + +import torch +from rich.progress import track +from transformers import ( +    AutoModel, +    AutoModelForSequenceClassification, +    AutoTokenizer, +) + + +class AbstractZeroshotClassifier(ABC): +    """Abstract parent class of zero-shot classifiers. + +    It instantiates the necessary `transformers` model and tokenizer. +    """ + +    def __init__( +        self, +        model_id: str, +        use_automodel_for_seq: bool = False, +        model_kwargs: Optional[dict] = None, +        tokenizer_kwargs: Optional[dict] = None, +    ): +        """Init method. + +        Args: +            model_id (`str`): +                Name of the HF model to use. +            use_automodel_for_seq (`bool`): +                Whether to use `AutoModelForSequenceClassification` instead of +                `AutoModel` (default). +            model_kwargs (`dict`, optional): +                Keyword arguments forwarded to the model's `from_pretrained`. +            tokenizer_kwargs (`dict`, optional): +                Keyword arguments forwarded to the tokenizer's `from_pretrained`. +        """ +        model_kwargs = model_kwargs or {} +        tokenizer_kwargs = tokenizer_kwargs or {} + +        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs) + +        if not use_automodel_for_seq: +            self.model = AutoModel.from_pretrained(model_id, **model_kwargs) +        else: +            self.model = AutoModelForSequenceClassification.from_pretrained( +                model_id, **model_kwargs +            ) +        self.model.eval() + +    @torch.no_grad() +    def classify(self, query: str, keys: List[str]) -> Tuple[str, float, int]: +        """Return the most similar key to the given query. + +        Args: +            query (`str`): +                Short text to classify into one key. +            keys (`List[str]`): +                All possible classes, each one a short description. + +        Returns: +            `str`: The most similar key. +            `float`: The cosine similarity obtained. +            `int`: Index of the key. +        """ + +        inputs = self.tokenizer(query, return_tensors="pt") +        target_embed = self._compute_embed(inputs) + +        max_cosine_sim = -1  # min. possible cosine similarity +        most_sim_descr = "" +        selected_index = -1 +        for idx, descr in track(enumerate(keys), total=len(keys), description="Processing..."): +            inputs = self.tokenizer(descr, return_tensors="pt") +            descr_embed = self._compute_embed(inputs) + +            cosine_sim = torch.sum( +                torch.nn.functional.normalize(target_embed) +                * torch.nn.functional.normalize(descr_embed) +            ) +            if cosine_sim > max_cosine_sim: +                most_sim_descr = descr +                max_cosine_sim = cosine_sim +                selected_index = idx + +        return most_sim_descr, max_cosine_sim, selected_index + +    @abstractmethod +    def _compute_embed(self, inputs: dict) -> torch.Tensor: +        """Compute the final embedding of the text given the tokenized inputs and the model. + +        Args: +            inputs (`dict`): tokenized inputs to the model. + +        Returns: +            `torch.Tensor`: embedding of the text. Shape (1, model_dim).
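For reference, `classify` ranks keys by the cosine similarity between L2-normalized embeddings. A minimal sketch (illustrative tensors only, not part of the diff) showing that the `torch.sum(normalize(...) * normalize(...))` expression used above matches `torch.nn.functional.cosine_similarity` for single-row embeddings:

```python
import torch
import torch.nn.functional as F

# Illustrative (1, dim) embeddings; in the classifier they come from _compute_embed.
a = torch.tensor([[1.0, 2.0, 3.0]])
b = torch.tensor([[2.0, 4.0, 6.0]])

# For single-row tensors, summing the elementwise product of the
# L2-normalized rows is exactly the cosine similarity.
manual = torch.sum(F.normalize(a) * F.normalize(b))
builtin = F.cosine_similarity(a, b)
print(manual.item(), builtin.item())  # both ~1.0, since b is a scaled copy of a
```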
+ """ + raise NotImplementedError diff --git a/openbb_chat/classifiers/roberta.py b/openbb_chat/classifiers/roberta.py new file mode 100644 index 0000000..0c88d9c --- /dev/null +++ b/openbb_chat/classifiers/roberta.py @@ -0,0 +1,18 @@ +import torch +from transformers import PreTrainedModel + +from openbb_chat.classifiers.abstract_zeroshot_classifier import ( + AbstractZeroshotClassifier, +) + + +class RoBERTaZeroshotClassifier(AbstractZeroshotClassifier): + """Zero-shot classifier based on `sentence-transformers`.""" + + def __init__(self, model_id: str = "roberta-base", *args, **kwargs): + """Override __init__ to set default model_id.""" + super().__init__(model_id, *args, **kwargs) + + def _compute_embed(self, inputs: dict) -> torch.Tensor: + """Override parent method to use RoBERTa pooler output ([CLS] token).""" + return self.model(**inputs).pooler_output diff --git a/openbb_chat/classifiers/stransformer.py b/openbb_chat/classifiers/stransformer.py new file mode 100644 index 0000000..52f63a4 --- /dev/null +++ b/openbb_chat/classifiers/stransformer.py @@ -0,0 +1,38 @@ +import torch +from transformers import PreTrainedModel + +from openbb_chat.classifiers.abstract_zeroshot_classifier import ( + AbstractZeroshotClassifier, +) + + +class STransformerZeroshotClassifier(AbstractZeroshotClassifier): + """Zero-shot classifier based on `sentence-transformers`.""" + + def __init__(self, model_id: str = "sentence-transformers/all-MiniLM-L6-v2", *args, **kwargs): + """Override __init__ to set default model_id.""" + super().__init__(model_id, *args, **kwargs) + + def _compute_embed(self, inputs: dict) -> torch.Tensor: + """Override parent method to use `sentence-transformers` models in HF.""" + return self._mean_pooling(self.model(**inputs), inputs["attention_mask"]) + + def _mean_pooling(self, model_output: object, attention_mask: torch.Tensor) -> torch.Tensor: + """ + Mean Pooling - Take attention mask into account for correct averaging. + + Args: + model_output (`object`): output of `sentence-transformers` model in HF. + attention_mask (`torch.Tensor`): attention mask denoting padding. + + Returns: + `torch.Tensor`: final embedding computed from the model output. 
+ """ + # Code from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 + token_embeddings = model_output[ + 0 + ] # First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( + input_mask_expanded.sum(1), min=1e-9 + ) diff --git a/tests/classifiers/test_roberta.py b/tests/classifiers/test_roberta.py new file mode 100644 index 0000000..47bfdb2 --- /dev/null +++ b/tests/classifiers/test_roberta.py @@ -0,0 +1,21 @@ +from unittest.mock import patch + +import pytest +from transformers import AutoModel + +from openbb_chat.classifiers.roberta import RoBERTaZeroshotClassifier + + +@patch("torch.sum") +@patch.object(AutoModel, "from_pretrained") +def test_classify(mocked_automodel_frompretrained, mocked_torch_sum): + mocked_torch_sum.return_value = 1 + + roberta_zeroshot = RoBERTaZeroshotClassifier() + key, score, idx = roberta_zeroshot.classify("Here is a dog", ["dog", "cat"]) + + mocked_automodel_frompretrained.assert_called_once_with("roberta-base") + mocked_torch_sum.assert_called() + assert key == "dog" + assert score == 1 + assert idx == 0 diff --git a/tests/classifiers/test_stransformers.py b/tests/classifiers/test_stransformers.py new file mode 100644 index 0000000..bbd585b --- /dev/null +++ b/tests/classifiers/test_stransformers.py @@ -0,0 +1,33 @@ +from unittest.mock import patch + +import pytest +from transformers import AutoModel, AutoTokenizer + +from openbb_chat.classifiers.stransformer import STransformerZeroshotClassifier + + +@patch("torch.clamp") +@patch("torch.sum") +@patch.object(AutoTokenizer, "from_pretrained") +@patch.object(AutoModel, "from_pretrained") +def test_classify( + mocked_automodel_frompretrained, + mocked_tokenizer_frompretrained, + mocked_torch_sum, + mocked_torch_clamp, +): + mocked_torch_sum.return_value = 1 + + stransformer_zeroshot = STransformerZeroshotClassifier() + key, score, idx = stransformer_zeroshot.classify("Here is a dog", ["dog", "cat"]) + + mocked_automodel_frompretrained.assert_called_once_with( + "sentence-transformers/all-MiniLM-L6-v2" + ) + mocked_tokenizer_frompretrained.assert_called_once_with( + "sentence-transformers/all-MiniLM-L6-v2" + ) + mocked_torch_sum.assert_called() + assert key == "dog" + assert score == 1 + assert idx == 0