Commit
These models find the closest text among a set of texts, given a query. One is based on the RoBERTa model, while the other is based on a sentence-transformers model trained with a contrastive loss.
Showing 5 changed files with 201 additions and 0 deletions.
@@ -0,0 +1,91 @@

from abc import ABC, abstractmethod
from typing import List, Tuple

import torch
from rich.progress import track
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)


class AbstractZeroshotClassifier(ABC):
    """Abstract parent class of zero-shot classifiers.

    It instantiates the necessary `transformers` models.
    """

    def __init__(
        self,
        model_id: str,
        use_automodel_for_seq: bool = False,
        model_kwargs: dict = {},
        tokenizer_kwargs: dict = {},
    ):
        """Init method.

        Args:
            model_id (`str`):
                Name of the HF model to use.
            use_automodel_for_seq (`bool`):
                Whether to use `AutoModelForSequenceClassification` or `AutoModel` (default).
            model_kwargs (`dict`):
                Extra keyword arguments passed to the model's `from_pretrained`.
            tokenizer_kwargs (`dict`):
                Extra keyword arguments passed to the tokenizer's `from_pretrained`.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)

        if not use_automodel_for_seq:
            self.model = AutoModel.from_pretrained(model_id, **model_kwargs)
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id, **model_kwargs
            )
        self.model.eval()

    def classify(self, query: str, keys: List[str]) -> Tuple[str, float, int]:
        """Given a query, return the most similar key.

        Args:
            query (`str`):
                Short text to classify into one key.
            keys (`List[str]`):
                All possible classes, each one a short description.

        Returns:
            `str`: The most similar key.
            `float`: The score obtained (cosine similarity).
            `int`: Index of the key.
        """
        inputs = self.tokenizer(query, return_tensors="pt")
        target_embed = self._compute_embed(inputs)

        max_cosine_sim = -1  # minimum possible cosine similarity
        most_sim_descr = ""
        selected_index = -1
        for idx, descr in track(enumerate(keys), total=len(keys), description="Processing..."):
            inputs = self.tokenizer(descr, return_tensors="pt")
            descr_embed = self._compute_embed(inputs)

            # cosine similarity between the query and the key description
            cosine_sim = torch.sum(
                torch.nn.functional.normalize(target_embed)
                * torch.nn.functional.normalize(descr_embed)
            )
            if cosine_sim > max_cosine_sim:
                most_sim_descr = descr
                max_cosine_sim = cosine_sim
                selected_index = idx

        return most_sim_descr, max_cosine_sim, selected_index

    @abstractmethod
    def _compute_embed(self, inputs: dict) -> torch.Tensor:
        """Compute the final embedding of the text given the tokenized inputs and the model.

        Args:
            inputs (`dict`): tokenized inputs to the model.

        Returns:
            `torch.Tensor`: embedding of the text. Shape (1, model_dim).
        """
        raise NotImplementedError
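The score computed in `classify` is plain cosine similarity: both embeddings are L2-normalized and their element-wise product is summed, which equals the dot product of two unit vectors. A minimal sketch of the equivalence, not part of the commit (the 384-dimensional size is just an illustrative stand-in for `model_dim`):

import torch
import torch.nn.functional as F

# stand-ins for target_embed and descr_embed, each of shape (1, model_dim)
a = torch.randn(1, 384)
b = torch.randn(1, 384)

# normalize to unit length, then sum the element-wise product
manual = torch.sum(F.normalize(a) * F.normalize(b))
# same value as PyTorch's built-in cosine similarity
builtin = F.cosine_similarity(a, b, dim=1).squeeze()

assert torch.allclose(manual, builtin, atol=1e-6)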
@@ -0,0 +1,18 @@

import torch

from openbb_chat.classifiers.abstract_zeroshot_classifier import (
    AbstractZeroshotClassifier,
)


class RoBERTaZeroshotClassifier(AbstractZeroshotClassifier):
    """Zero-shot classifier based on RoBERTa."""

    def __init__(self, model_id: str = "roberta-base", *args, **kwargs):
        """Override __init__ to set the default model_id."""
        super().__init__(model_id, *args, **kwargs)

    def _compute_embed(self, inputs: dict) -> torch.Tensor:
        """Override the parent method to use RoBERTa's pooler output ([CLS] token)."""
        return self.model(**inputs).pooler_output
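A minimal usage sketch for the classifier above, not part of the commit (instantiating it downloads `roberta-base` from the Hugging Face Hub on first use):

from openbb_chat.classifiers.roberta import RoBERTaZeroshotClassifier

# uses the default `roberta-base` checkpoint
classifier = RoBERTaZeroshotClassifier()

# returns the most similar key, its cosine similarity and its index
key, score, idx = classifier.classify("Here is a dog", ["dog", "cat"])
print(key, score, idx)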
@@ -0,0 +1,38 @@

import torch

from openbb_chat.classifiers.abstract_zeroshot_classifier import (
    AbstractZeroshotClassifier,
)


class STransformerZeroshotClassifier(AbstractZeroshotClassifier):
    """Zero-shot classifier based on `sentence-transformers`."""

    def __init__(self, model_id: str = "sentence-transformers/all-MiniLM-L6-v2", *args, **kwargs):
        """Override __init__ to set the default model_id."""
        super().__init__(model_id, *args, **kwargs)

    def _compute_embed(self, inputs: dict) -> torch.Tensor:
        """Override the parent method to use `sentence-transformers` models in HF."""
        return self._mean_pooling(self.model(**inputs), inputs["attention_mask"])

    def _mean_pooling(self, model_output: object, attention_mask: torch.Tensor) -> torch.Tensor:
        """Mean pooling - take the attention mask into account for correct averaging.

        Args:
            model_output (`object`): output of the `sentence-transformers` model in HF.
            attention_mask (`torch.Tensor`): attention mask denoting padding.

        Returns:
            `torch.Tensor`: final embedding computed from the model output.
        """
        # Code from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
        token_embeddings = model_output[0]  # first element contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
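`_mean_pooling` averages only the token embeddings whose attention-mask entry is 1, so padding positions do not dilute the sentence embedding. A hand-checkable sketch of the same computation on toy tensors, not part of the commit:

import torch

# one sequence of two tokens with embedding dimension 2; the second token is padding
token_embeddings = torch.tensor([[[1.0, 3.0], [5.0, 7.0]]])  # shape (1, 2, 2)
attention_mask = torch.tensor([[1, 0]])                      # shape (1, 2)

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

print(pooled)  # tensor([[1., 3.]]) -- only the non-padding token contributes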
@@ -0,0 +1,21 @@

from unittest.mock import patch

from transformers import AutoModel

from openbb_chat.classifiers.roberta import RoBERTaZeroshotClassifier


@patch("torch.sum")
@patch.object(AutoModel, "from_pretrained")
def test_classify(mocked_automodel_frompretrained, mocked_torch_sum):
    mocked_torch_sum.return_value = 1

    # the tokenizer is not mocked, so the real `roberta-base` tokenizer is loaded
    roberta_zeroshot = RoBERTaZeroshotClassifier()
    key, score, idx = roberta_zeroshot.classify("Here is a dog", ["dog", "cat"])

    mocked_automodel_frompretrained.assert_called_once_with("roberta-base")
    mocked_torch_sum.assert_called()
    assert key == "dog"
    assert score == 1
    assert idx == 0
@@ -0,0 +1,33 @@

from unittest.mock import patch

from transformers import AutoModel, AutoTokenizer

from openbb_chat.classifiers.stransformer import STransformerZeroshotClassifier


@patch("torch.clamp")
@patch("torch.sum")
@patch.object(AutoTokenizer, "from_pretrained")
@patch.object(AutoModel, "from_pretrained")
def test_classify(
    mocked_automodel_frompretrained,
    mocked_tokenizer_frompretrained,
    mocked_torch_sum,
    mocked_torch_clamp,
):
    mocked_torch_sum.return_value = 1

    stransformer_zeroshot = STransformerZeroshotClassifier()
    key, score, idx = stransformer_zeroshot.classify("Here is a dog", ["dog", "cat"])

    mocked_automodel_frompretrained.assert_called_once_with(
        "sentence-transformers/all-MiniLM-L6-v2"
    )
    mocked_tokenizer_frompretrained.assert_called_once_with(
        "sentence-transformers/all-MiniLM-L6-v2"
    )
    mocked_torch_sum.assert_called()
    assert key == "dog"
    assert score == 1
    assert idx == 0
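The exact paths of the two test modules are not shown in this diff, but both mocked tests can be selected by name with pytest, for example (note that the RoBERTa test loads the real `roberta-base` tokenizer, since the tokenizer is not mocked there):

import pytest

# run every test named test_classify, quietly
pytest.main(["-q", "-k", "test_classify"])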