Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the smart_batching_collate function #1852

Merged
merged 3 commits into the base branch from the feature branch
Dec 13, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import stat
from collections import OrderedDict
import warnings
from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional, Literal
from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional, Literal, TYPE_CHECKING
import numpy as np
from numpy import ndarray
import transformers
Expand All @@ -31,6 +31,10 @@
logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    # Imported only while a static type checker analyzes the file; at runtime
    # TYPE_CHECKING is False, so this import never executes. Used for the
    # "InputExample" forward reference in smart_batching_collate's annotation,
    # presumably to avoid a circular import with the readers module — confirm.
    from sentence_transformers.readers import InputExample


def get_device_name() -> Literal["mps", "cuda", "cpu"]:
"""
Returns the name of the device where this module is running on.
Expand Down Expand Up @@ -564,36 +568,21 @@ def on_rm_error(func, path, exc_info):

return push_return

def smart_batching_collate(self, batch: List["InputExample"]) -> Tuple[List[Dict[str, Tensor]], Tensor]:
    """
    Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model.

    Here, batch is a list of InputExample instances: [InputExample(...), ...]

    :param batch:
        a batch from a SmartBatchingDataset; every example must have the same
        number of texts.
    :return:
        a tuple (sentence_features, labels): one tokenized feature dict per
        text position, and a tensor of the examples' labels.
    """
    # Gather each example's texts, then transpose with zip(*...) so that the
    # i-th group contains the i-th text of every example in the batch.
    texts = [example.texts for example in batch]
    sentence_features = [self.tokenize(sentence) for sentence in zip(*texts)]
    labels = torch.tensor([example.label for example in batch])
    return sentence_features, labels


def _text_length(self, text: Union[List[int], List[List[int]]]):
"""
Help function to get the length for the input text. Text can be either
Expand Down