forked from jaywonchung/BERT4Rec-VAE-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbert.py
40 lines (30 loc) · 1.17 KB
/
bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from .base import AbstractTrainer
from .utils import recalls_and_ndcgs_for_ks
import torch.nn as nn
class BERTTrainer(AbstractTrainer):
    """Trainer for the BERT-style sequential recommender.

    Specialises ``AbstractTrainer`` by supplying the training loss
    (``calculate_loss``), the evaluation metrics (``calculate_metrics``) and
    the short code ``'bert'``. The extra logging hooks are intentionally no-ops.
    """

    def __init__(self, args, model, train_loader, val_loader, test_loader, export_root):
        super().__init__(args, model, train_loader, val_loader, test_loader, export_root)
        # Targets equal to 0 are excluded from the loss. Presumably 0 marks
        # positions that should not be predicted (padding / unmasked tokens);
        # TODO confirm against the dataloader that builds ``labels``.
        self.ce = nn.CrossEntropyLoss(ignore_index=0)

    @classmethod
    def code(cls):
        """Return this trainer's short code, ``'bert'``.

        Presumably used by the base class / a factory to select this trainer;
        confirm at the call site.
        """
        return 'bert'

    def add_extra_loggers(self):
        """No extra loggers are needed for this trainer."""
        pass

    def log_extra_train_info(self, log_data):
        """No extra training info is logged; ``log_data`` is ignored."""
        pass

    def log_extra_val_info(self, log_data):
        """No extra validation info is logged; ``log_data`` is ignored."""
        pass

    def calculate_loss(self, batch):
        """Return the cross-entropy loss for one training batch.

        Args:
            batch: pair ``(seqs, labels)``. ``self.model(seqs)`` yields
                B x T x V logits. ``labels`` must hold B*T elements so that it
                lines up position-by-position with the flattened logits.

        Returns:
            Scalar loss tensor. With ``CrossEntropyLoss``'s default ``'mean'``
            reduction this averages over positions whose label is not 0.
        """
        seqs, labels = batch
        logits = self.model(seqs)  # B x T x V
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # tensors (e.g. a model returning a permuted/transposed projection),
        # while ``reshape`` is identical for contiguous input and only copies
        # when it must.
        logits = logits.reshape(-1, logits.size(-1))  # (B*T) x V
        labels = labels.reshape(-1)  # B*T
        return self.ce(logits, labels)

    def calculate_metrics(self, batch):
        """Score candidate items at the final position and compute ranking metrics.

        Args:
            batch: triple ``(seqs, candidates, labels)``. ``candidates`` is a
                B x C integer index tensor into the model's vocabulary axis.
                ``labels`` is passed through unchanged to
                ``recalls_and_ndcgs_for_ks``.

        Returns:
            Whatever ``recalls_and_ndcgs_for_ks`` returns for ``self.metric_ks``.
            Presumably this is a mapping of Recall@k / NDCG@k values, and
            ``metric_ks`` is set by ``AbstractTrainer``; confirm in base/utils.
        """
        seqs, candidates, labels = batch
        scores = self.model(seqs)  # B x T x V
        # Only the last time step is ranked.
        scores = scores[:, -1, :]  # B x V
        # Keep just the scores of each row's candidate items.
        scores = scores.gather(1, candidates)  # B x C
        return recalls_and_ndcgs_for_ks(scores, labels, self.metric_ks)