Skip to content

Commit 6024956

Browse files
marib00rootmarib00
authored
Last token pooling for Huggingface models like SFR-Embedding-Mistral (#11373)
* Added last token pooling for Huggingface models like Salesforce/SFR-Embedding-Mistral * fixed whitespace * Added overloaded method signatures --------- Co-authored-by: root <root@maid-beast.staff.bournemouth.ac.uk> Co-authored-by: marib00 <newborn09current@icloud.com>
1 parent cc8e1ee commit 6024956

File tree

2 files changed

+29
-2
lines changed
  • llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface

2 files changed

+29
-2
lines changed

llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,14 @@ def _embed(self, sentences: List[str]) -> List[List[float]]:
160160

161161
model_output = self._model(**encoded_input)
162162

163+
context_layer: "torch.Tensor" = model_output[0]
163164
if self.pooling == Pooling.CLS:
164-
context_layer: "torch.Tensor" = model_output[0]
165165
embeddings = self.pooling.cls_pooling(context_layer)
166+
elif self.pooling == Pooling.LAST:
167+
embeddings = self.pooling.last_pooling(context_layer)
166168
else:
167169
embeddings = self._mean_pooling(
168-
token_embeddings=model_output[0],
170+
token_embeddings=context_layer,
169171
attention_mask=encoded_input["attention_mask"],
170172
)
171173

llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface/pooling.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ class Pooling(str, Enum):
1212

1313
CLS = "cls"
1414
MEAN = "mean"
15+
LAST = "last" # last token pooling
1516

1617
def __call__(self, array: np.ndarray) -> np.ndarray:
1718
if self == self.CLS:
1819
return self.cls_pooling(array)
20+
elif self == self.LAST:
21+
return self.last_pooling(array)
1922
return self.mean_pooling(array)
2023

2124
@classmethod
@@ -47,3 +50,25 @@ def mean_pooling(cls, array: np.ndarray) -> np.ndarray:
4750
if len(array.shape) == 2:
4851
return array.mean(axis=0)
4952
raise NotImplementedError(f"Unhandled shape {array.shape}.")
53+
54+
@classmethod
55+
@overload
56+
def last_pooling(cls, array: np.ndarray) -> np.ndarray:
57+
...
58+
59+
@classmethod
60+
@overload
61+
# TODO: Remove this `type: ignore` after the false positive problem
62+
# is addressed in mypy: https://github.com/python/mypy/issues/15683 .
63+
def last_pooling(cls, array: "torch.Tensor") -> "torch.Tensor": # type: ignore
64+
...
65+
66+
@classmethod
67+
def last_pooling(
68+
cls, array: "Union[np.ndarray, torch.Tensor]"
69+
) -> "Union[np.ndarray, torch.Tensor]":
70+
if len(array.shape) == 3:
71+
return array[:, -1]
72+
if len(array.shape) == 2:
73+
return array[-1]
74+
raise NotImplementedError(f"Unhandled shape {array.shape}.")

0 commit comments

Comments
 (0)