# listMLE.py (35 lines, 24 loc, 1.39 KB)
import torch
def listMLE(y_pred, y_true, eps=1e-10, padded_value_indicator=-1):
    """
    ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm".

    Within every slate, items are ordered by descending ground-truth label (ties broken
    randomly). The loss is the negative log-likelihood of that permutation under the
    Plackett-Luce model parameterised by ``y_pred``:

        sum_i [ log(sum_{j >= i} exp(s_j)) - s_i ]

    Here ``s`` holds the predictions reordered by the ground truth. Per-slate losses are
    averaged over the batch.

    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon added inside the log, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item,
        e.g. -1; padded items contribute nothing to the loss, and a slate consisting only of
        padded items contributes 0
    :return: loss value, a scalar torch.Tensor
    """
    # An empty slate has nothing to rank. max() below would raise on a zero-size dimension,
    # so return a zero that stays attached to the autograd graph.
    if y_pred.shape[-1] == 0:
        return y_pred.sum() * 0.0

    # Shuffle for randomised tie resolution. The permutation is created on y_pred's device,
    # so indexing does not need a host-to-device copy.
    random_indices = torch.randperm(y_pred.shape[-1], device=y_pred.device)
    y_pred_shuffled = y_pred[:, random_indices]
    y_true_shuffled = y_true[:, random_indices]

    y_true_sorted, indices = y_true_shuffled.sort(descending=True, dim=-1)
    mask = y_true_sorted == padded_value_indicator

    preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices)
    # exp(-inf) == 0, so padded items drop out of every cumulative sum, wherever the sort
    # happened to place them.
    preds_sorted_by_true = preds_sorted_by_true.masked_fill(mask, float("-inf"))

    # Log-sum-exp shift by the per-slate maximum to keep exp() from overflowing.
    max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True)
    # A fully padded slate has max == -inf, and (-inf) - (-inf) would be NaN. Shift such
    # slates by 0 instead so no NaN appears in the forward or the backward pass.
    max_pred_values = max_pred_values.masked_fill(max_pred_values == float("-inf"), 0.0)
    preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values

    # Reverse cumulative sum: entry i holds sum_{j >= i} exp(s_j - max).
    cumsums = torch.cumsum(
        preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1
    ).flip(dims=[1])
    observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max
    # Padded positions hold +inf here (log(eps) - (-inf)); zero them out of the loss.
    observation_loss = observation_loss.masked_fill(mask, 0.0)
    return torch.mean(torch.sum(observation_loss, dim=1))