Commit: Add zeroshot classifiers

These models find the closest text among a set of texts, given a
query. One is based on the RoBERTa model, while the other is based on a
sentence-transformers model trained with a contrastive loss.
Dedalo314 committed Jul 29, 2023
1 parent 6c48af7 commit 43e515e
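
For context, a minimal usage sketch of the API this commit introduces (class and method names are taken from the diff below; instantiating a classifier downloads the corresponding model from the Hugging Face Hub):

from openbb_chat.classifiers.stransformer import STransformerZeroshotClassifier

# Pick the key whose embedding has the highest cosine similarity to the query.
clf = STransformerZeroshotClassifier()
key, score, idx = clf.classify("Here is a dog", ["dog", "cat"])
print(key, score, idx)  # expected: "dog", a similarity in [-1, 1], 0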
Showing 5 changed files with 201 additions and 0 deletions.
91 changes: 91 additions & 0 deletions openbb_chat/classifiers/abstract_zeroshot_classifier.py
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple

import torch
from rich.progress import track
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)


class AbstractZeroshotClassifier(ABC):
    """Abstract parent class of zero-shot classifiers.

    It instantiates the necessary `transformers` models.
    """

    def __init__(
        self,
        model_id: str,
        use_automodel_for_seq: bool = False,
        model_kwargs: Optional[dict] = None,
        tokenizer_kwargs: Optional[dict] = None,
    ):
        """Init method.

        Args:
            model_id (`str`):
                Name of the HF model to use.
            use_automodel_for_seq (`bool`):
                Whether to use `AutoModelForSequenceClassification` instead of
                `AutoModel` (the default).
            model_kwargs (`Optional[dict]`):
                Extra kwargs forwarded to the model's `from_pretrained`.
            tokenizer_kwargs (`Optional[dict]`):
                Extra kwargs forwarded to the tokenizer's `from_pretrained`.
        """
        # None instead of {} as default: mutable defaults are shared across calls.
        model_kwargs = model_kwargs or {}
        tokenizer_kwargs = tokenizer_kwargs or {}

        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs)

        if not use_automodel_for_seq:
            self.model = AutoModel.from_pretrained(model_id, **model_kwargs)
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_id, **model_kwargs
            )
        self.model.eval()

    @torch.no_grad()  # inference only: no gradients are needed for classification
    def classify(self, query: str, keys: List[str]) -> Tuple[str, float, int]:
        """Given a query, the most similar key is returned.

        Args:
            query (`str`):
                Short text to classify into one key.
            keys (`List[str]`):
                All possible classes, each one a short description.

        Returns:
            `str`: The most similar key.
            `float`: The score obtained.
            `int`: Index of the key.
        """
        inputs = self.tokenizer(query, return_tensors="pt")
        target_embed = self._compute_embed(inputs)

        max_cosine_sim = -1  # min. possible cosine similarity
        most_sim_descr = ""
        selected_index = -1
        for idx, descr in track(enumerate(keys), total=len(keys), description="Processing..."):
            inputs = self.tokenizer(descr, return_tensors="pt")
            descr_embed = self._compute_embed(inputs)

            # Dot product of L2-normalized embeddings equals the cosine similarity.
            cosine_sim = torch.sum(
                torch.nn.functional.normalize(target_embed)
                * torch.nn.functional.normalize(descr_embed)
            )
            if cosine_sim > max_cosine_sim:
                most_sim_descr = descr
                max_cosine_sim = cosine_sim
                selected_index = idx

        return most_sim_descr, float(max_cosine_sim), selected_index

    @abstractmethod
    def _compute_embed(self, inputs: dict) -> torch.Tensor:
        """Computes the final embedding of the text given the tokenized inputs and the model.

        Args:
            inputs (`dict`): tokenized inputs to the model.

        Returns:
            `torch.Tensor`: embedding of the text. Shape (1, model_dim).
        """
        raise NotImplementedError
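
The similarity used in `classify` is the dot product of two L2-normalized embeddings, which is exactly the cosine similarity. A quick standalone check of that identity (plain PyTorch, independent of the classes in this commit):

import torch
import torch.nn.functional as F

a = torch.randn(1, 8)
b = torch.randn(1, 8)

# Dot product of unit vectors == cosine similarity.
manual = torch.sum(F.normalize(a) * F.normalize(b))
builtin = F.cosine_similarity(a, b).squeeze()
assert torch.allclose(manual, builtin)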
18 changes: 18 additions & 0 deletions openbb_chat/classifiers/roberta.py
@@ -0,0 +1,18 @@
import torch

from openbb_chat.classifiers.abstract_zeroshot_classifier import (
    AbstractZeroshotClassifier,
)


class RoBERTaZeroshotClassifier(AbstractZeroshotClassifier):
    """Zero-shot classifier based on RoBERTa-style models with a pooling layer."""

    def __init__(self, model_id: str = "roberta-base", *args, **kwargs):
        """Override __init__ to set default model_id."""
        super().__init__(model_id, *args, **kwargs)

    def _compute_embed(self, inputs: dict) -> torch.Tensor:
        """Override parent method to use the RoBERTa pooler output (the first
        token's hidden state passed through a dense layer and tanh)."""
        return self.model(**inputs).pooler_output
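
Note that `pooler_output` is not the raw first-token hidden state: for BERT/RoBERTa-style models it is that hidden state passed through a learned dense layer and a tanh. A sketch of that relationship, assuming `roberta-base` can be downloaded:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModel.from_pretrained("roberta-base")
model.eval()

with torch.no_grad():
    out = model(**tokenizer("Here is a dog", return_tensors="pt"))
    # pooler_output == tanh(dense(hidden state of the first token))
    manual = torch.tanh(model.pooler.dense(out.last_hidden_state[:, 0]))

assert torch.allclose(manual, out.pooler_output, atol=1e-6)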
38 changes: 38 additions & 0 deletions openbb_chat/classifiers/stransformer.py
@@ -0,0 +1,38 @@
import torch

from openbb_chat.classifiers.abstract_zeroshot_classifier import (
    AbstractZeroshotClassifier,
)


class STransformerZeroshotClassifier(AbstractZeroshotClassifier):
    """Zero-shot classifier based on `sentence-transformers` models."""

    def __init__(self, model_id: str = "sentence-transformers/all-MiniLM-L6-v2", *args, **kwargs):
        """Override __init__ to set default model_id."""
        super().__init__(model_id, *args, **kwargs)

    def _compute_embed(self, inputs: dict) -> torch.Tensor:
        """Override parent method to use `sentence-transformers` models in HF."""
        return self._mean_pooling(self.model(**inputs), inputs["attention_mask"])

    def _mean_pooling(self, model_output: object, attention_mask: torch.Tensor) -> torch.Tensor:
        """Mean pooling: take the attention mask into account for correct averaging.

        Args:
            model_output (`object`): output of a `sentence-transformers` model in HF.
            attention_mask (`torch.Tensor`): attention mask denoting padding.

        Returns:
            `torch.Tensor`: final embedding computed from the model output.
        """
        # Code from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
        # First element of model_output contains all token embeddings.
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
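
To see what the masked mean pooling does, here is a toy check with one real token and one padding token; only the unmasked position should contribute to the pooled embedding:

import torch

token_embeddings = torch.tensor([[[1.0, 3.0], [5.0, 7.0]]])  # (batch=1, seq=2, dim=2)
attention_mask = torch.tensor([[1, 0]])  # the second position is padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)
print(pooled)  # tensor([[1., 3.]]) -- the padded token is ignored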
21 changes: 21 additions & 0 deletions tests/classifiers/test_roberta.py
@@ -0,0 +1,21 @@
from unittest.mock import patch

from transformers import AutoModel, AutoTokenizer

from openbb_chat.classifiers.roberta import RoBERTaZeroshotClassifier


@patch("torch.sum")
@patch.object(AutoTokenizer, "from_pretrained")
@patch.object(AutoModel, "from_pretrained")
def test_classify(
    mocked_automodel_frompretrained,
    mocked_tokenizer_frompretrained,
    mocked_torch_sum,
):
    mocked_torch_sum.return_value = 1

    roberta_zeroshot = RoBERTaZeroshotClassifier()
    key, score, idx = roberta_zeroshot.classify("Here is a dog", ["dog", "cat"])

    mocked_automodel_frompretrained.assert_called_once_with("roberta-base")
    mocked_tokenizer_frompretrained.assert_called_once_with("roberta-base")
    mocked_torch_sum.assert_called()
    assert key == "dog"
    assert score == 1
    assert idx == 0
33 changes: 33 additions & 0 deletions tests/classifiers/test_stransformers.py
@@ -0,0 +1,33 @@
from unittest.mock import patch

from transformers import AutoModel, AutoTokenizer

from openbb_chat.classifiers.stransformer import STransformerZeroshotClassifier


@patch("torch.clamp")
@patch("torch.sum")
@patch.object(AutoTokenizer, "from_pretrained")
@patch.object(AutoModel, "from_pretrained")
def test_classify(
    mocked_automodel_frompretrained,
    mocked_tokenizer_frompretrained,
    mocked_torch_sum,
    mocked_torch_clamp,
):
    mocked_torch_sum.return_value = 1

    stransformer_zeroshot = STransformerZeroshotClassifier()
    key, score, idx = stransformer_zeroshot.classify("Here is a dog", ["dog", "cat"])

    mocked_automodel_frompretrained.assert_called_once_with(
        "sentence-transformers/all-MiniLM-L6-v2"
    )
    mocked_tokenizer_frompretrained.assert_called_once_with(
        "sentence-transformers/all-MiniLM-L6-v2"
    )
    mocked_torch_sum.assert_called()
    assert key == "dog"
    assert score == 1
    assert idx == 0
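
A note on why these assertions hold: with `torch.sum` patched to return a constant 1, every candidate key scores 1, and only the first one passes the strict `cosine_sim > max_cosine_sim` comparison (1 > -1), so "dog" at index 0 is kept. A minimal standalone illustration of the same patching pattern:

from unittest.mock import patch

import torch

# Under the patch, torch.sum ignores its input and returns 1; every candidate
# then ties, and only the first beats the initial max of -1 in classify().
with patch("torch.sum", return_value=1):
    assert torch.sum(torch.zeros(3)) == 1
    assert torch.sum(torch.ones(3)) == 1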
