custom_pooling.py
import json
import os
from typing import Dict

import torch
from sentence_transformers.models import Pooling
from torch import Tensor


class CustomPooling(Pooling):
    """Pooling layer that pools token embeddings over a custom 'pooling_mask'
    feature (instead of the standard attention mask), supporting mean and max
    pooling."""

    def __init__(self, word_embedding_dimension, pooling_mode=None,
                 pooling_mode_max_tokens: bool = False,
                 pooling_mode_mean_tokens: bool = True):
        assert pooling_mode in {'mean', 'max'} or pooling_mode_mean_tokens or pooling_mode_max_tokens
        super().__init__(word_embedding_dimension, pooling_mode=pooling_mode,
                         pooling_mode_max_tokens=pooling_mode_max_tokens,
                         pooling_mode_mean_tokens=pooling_mode_mean_tokens)
        self.config_keys = ['word_embedding_dimension', 'pooling_mode_mean_tokens', 'pooling_mode_max_tokens']
        self.pooling_output_dimension = word_embedding_dimension

    def forward(self, features: Dict[str, Tensor]):
        token_embeddings = features['token_embeddings']
        pooling_mask = features['pooling_mask']

        ## Pooling strategy
        if self.pooling_mode_max_tokens:
            input_mask_expanded = pooling_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
            max_over_time = torch.max(token_embeddings, 1)[0]
            output_vector = max_over_time
        else:
            # Mean pooling: sum the masked embeddings and divide by the number
            # of unmasked tokens (clamped to avoid division by zero).
            input_mask_expanded = pooling_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
            sum_mask = input_mask_expanded.sum(1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            output_vector = sum_embeddings / sum_mask

        features.update({'sentence_embedding': output_vector})
        return features

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, 'config.json')) as fIn:
            config = json.load(fIn)
        return CustomPooling(**config)
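

# Minimal usage sketch (illustrative only, not part of the original module):
# it builds a features dict by hand and runs it through CustomPooling to show
# the expected keys and shapes. The non-standard 'pooling_mask' key is assumed
# to be produced by an upstream module in a real pipeline (e.g. derived from
# the tokenizer's attention_mask); that wiring is not shown here.
if __name__ == '__main__':
    pooling = CustomPooling(word_embedding_dimension=8, pooling_mode='mean')
    features = {
        'token_embeddings': torch.randn(2, 5, 8),          # (batch, tokens, dim)
        'pooling_mask': torch.tensor([[1, 1, 1, 0, 0],
                                      [1, 1, 1, 1, 1]]),   # 1 = include token in pooling
    }
    out = pooling(features)
    print(out['sentence_embedding'].shape)  # torch.Size([2, 8])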