trainer.py
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from transformers.trainer import Trainer, nested_detach
from transformers.trainer_utils import EvalPrediction, PredictionOutput

logger = logging.getLogger(__name__)

class COILTrainer(Trainer):
    def _save(self, output_dir: Optional[str] = None):
        # Write model weights, tokenizer files, and training args to the checkpoint dir.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        self.model.save_pretrained(output_dir)
        if self.tokenizer is not None and self.is_world_process_zero():
            self.tokenizer.save_pretrained(output_dir)
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _prepare_inputs(self, inputs):
        # Inputs arrive as a (query, document) pair of feature dicts; move every
        # tensor onto the target device and preserve the pair ordering so callers
        # can unpack the result positionally (see prediction_step below).
        prepared = []
        for group in inputs:
            prepared.append({
                k: v.to(self.args.device) if isinstance(v, torch.Tensor) else v
                for k, v in group.items()
            })
        return prepared

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        # Convert the warmup ratio into an integer step count before the base
        # Trainer builds the optimizer and LR scheduler.
        if self.args.warmup_ratio > 0:
            self.args.warmup_steps = int(num_training_steps * self.args.warmup_ratio)
        super().create_optimizer_and_scheduler(num_training_steps)
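
    def compute_loss(self, model, inputs):
        # Hedged addition, assumed rather than taken from the upstream file:
        # without an override, the base Trainer.training_step would call
        # model(**inputs) on the list returned by _prepare_inputs. This unpacks
        # the (query, document) pair and calls the model positionally, mirroring
        # prediction_step below.
        qry, doc = inputs
        return model(qry, doc).loss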

    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=True,  # keep batch size fixed so in-batch negatives stay uniform
            num_workers=self.args.dataloader_num_workers,
        )

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Tuple[Dict[str, Union[torch.Tensor, Any]], ...],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        qry, doc = self._prepare_inputs(inputs)[:2]
        if ignore_keys is None:
            # Nothing is filtered from the outputs here, so an empty default is safe.
            ignore_keys = []
        with torch.no_grad():
            if self.args.fp16:
                with autocast():
                    outputs = model(qry, doc)
            else:
                outputs = model(qry, doc)
            loss = outputs.loss
            logits = outputs.scores
        if prediction_loss_only:
            return (loss, None, None)
        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]
        labels = None
        return (loss, logits, labels)

    def prediction_loop(self, *args, **kwargs) -> PredictionOutput:
        pred_outs = super().prediction_loop(*args, **kwargs)
        preds, label_ids, metrics = pred_outs.predictions, pred_outs.label_ids, pred_outs.metrics
        preds = preds.squeeze()
        if self.compute_metrics is not None:
            metrics_no_label = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics_no_label = {}
        # Prefix bare metric names with "eval_" so they merge cleanly with the base metrics.
        for key in list(metrics_no_label.keys()):
            if not key.startswith("eval_"):
                metrics_no_label[f"eval_{key}"] = metrics_no_label.pop(key)
        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics={**metrics, **metrics_no_label})
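

if __name__ == "__main__":
    # Usage sketch only: a hedged, minimal wiring of COILTrainer into the
    # standard Hugging Face argument flow. `model`, `collator`, and `train_set`
    # are placeholders for the repo's actual COIL model, collator, and dataset
    # classes, which this file does not define.
    from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments

    parser = HfArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = ...      # COIL-style model whose outputs expose .loss and .scores
    collator = ...   # collator producing (query, document) feature-dict pairs
    train_set = ...  # dataset of query/document training examples

    trainer = COILTrainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        data_collator=collator,
        tokenizer=tokenizer,
    )
    trainer.train()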