File tree Expand file tree Collapse file tree 2 files changed +29
-2
lines changed
llama-index-integrations/embeddings/llama-index-embeddings-huggingface/llama_index/embeddings/huggingface Expand file tree Collapse file tree 2 files changed +29
-2
lines changed Original file line number Diff line number Diff line change @@ -160,12 +160,14 @@ def _embed(self, sentences: List[str]) -> List[List[float]]:
160
160
161
161
model_output = self ._model (** encoded_input )
162
162
163
+ context_layer : "torch.Tensor" = model_output [0 ]
163
164
if self .pooling == Pooling .CLS :
164
- context_layer : "torch.Tensor" = model_output [0 ]
165
165
embeddings = self .pooling .cls_pooling (context_layer )
166
+ elif self .pooling == Pooling .LAST :
167
+ embeddings = self .pooling .last_pooling (context_layer )
166
168
else :
167
169
embeddings = self ._mean_pooling (
168
- token_embeddings = model_output [ 0 ] ,
170
+ token_embeddings = context_layer ,
169
171
attention_mask = encoded_input ["attention_mask" ],
170
172
)
171
173
Original file line number Diff line number Diff line change @@ -12,10 +12,13 @@ class Pooling(str, Enum):
12
12
13
13
CLS = "cls"
14
14
MEAN = "mean"
15
+ LAST = "last" # last token pooling
15
16
16
17
def __call__ (self , array : np .ndarray ) -> np .ndarray :
17
18
if self == self .CLS :
18
19
return self .cls_pooling (array )
20
+ elif self == self .LAST :
21
+ return self .last_pooling (array )
19
22
return self .mean_pooling (array )
20
23
21
24
@classmethod
@@ -47,3 +50,25 @@ def mean_pooling(cls, array: np.ndarray) -> np.ndarray:
47
50
if len (array .shape ) == 2 :
48
51
return array .mean (axis = 0 )
49
52
raise NotImplementedError (f"Unhandled shape { array .shape } ." )
53
+
54
+ @classmethod
55
+ @overload
56
+ def last_pooling (cls , array : np .ndarray ) -> np .ndarray :
57
+ ...
58
+
59
+ @classmethod
60
+ @overload
61
+ # TODO: Remove this `type: ignore` after the false positive problem
62
+ # is addressed in mypy: https://github.com/python/mypy/issues/15683 .
63
+ def last_pooling (cls , array : "torch.Tensor" ) -> "torch.Tensor" : # type: ignore
64
+ ...
65
+
66
+ @classmethod
67
+ def last_pooling (
68
+ cls , array : "Union[np.ndarray, torch.Tensor]"
69
+ ) -> "Union[np.ndarray, torch.Tensor]" :
70
+ if len (array .shape ) == 3 :
71
+ return array [:, - 1 ]
72
+ if len (array .shape ) == 2 :
73
+ return array [- 1 ]
74
+ raise NotImplementedError (f"Unhandled shape { array .shape } ." )
You can’t perform that action at this time.
0 commit comments