Skip to content

Commit 0a84cf3

Browse files
fix document embedding extraction
1 parent 4063cc8 commit 0a84cf3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

flair/embeddings/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -628,9 +628,9 @@ def _extract_document_embeddings(self, sentence_hidden_states, sentences):
628628
index_of_cls_token = 0 if self.initial_cls_token else -1
629629
embedding_all_document_layers = sentence_hidden_state[:, index_of_cls_token, :]
630630
elif self.cls_pooling == "mean":
631-
embedding_all_document_layers = sentence_hidden_state.mean(dim=2)
631+
embedding_all_document_layers = sentence_hidden_state.mean(dim=1)
632632
elif self.cls_pooling == "max":
633-
_, embedding_all_document_layers = sentence_hidden_state.max(dim=2)
633+
embedding_all_document_layers, _ = sentence_hidden_state.max(dim=1)
634634
else:
635635
raise ValueError(f"cls pooling method: `{self.cls_pooling}` is not implemented")
636636
if self.layer_mean:

0 commit comments

Comments
 (0)