forked from stanford-futuredata/ColBERT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
36 lines (24 loc) · 1.23 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from colbert.infra.run import Run
from colbert.infra.launcher import Launcher
from colbert.infra.config import ColBERTConfig, RunConfig
from colbert.training.training import train
class Trainer:
    """Drive a ColBERT training run through a ``Launcher``.

    The training resources (triples, queries, collection) are stored on the
    instance and copied into the config when :meth:`train` runs. The starting
    checkpoint is passed to :meth:`train` directly rather than read from the
    config.
    """

    # Class-level default so best_checkpoint_path() returns None instead of
    # raising AttributeError when it is called before train() has run.
    _best_checkpoint_path = None

    def __init__(self, triples, queries, collection, config=None):
        """Create a trainer.

        Args:
            triples: Training triples resource; written into the config and
                also passed positionally to the launched ``train`` function.
            queries: Queries resource, handled the same way as ``triples``.
            collection: Collection resource, handled the same way as ``triples``.
            config: Optional config combined with the active ``Run().config``
                via ``ColBERTConfig.from_existing`` (presumably explicit values
                take precedence over the run config -- confirm in ColBERTConfig).
        """
        self.config = ColBERTConfig.from_existing(config, Run().config)
        self.triples = triples
        self.queries = queries
        self.collection = collection

    def configure(self, **kw_args):
        """Forward keyword settings to ``self.config.configure``."""
        self.config.configure(**kw_args)

    def train(self, checkpoint='bert-base-uncased'):
        """Launch training starting from ``checkpoint``.

        Note that config.checkpoint is ignored. Only the supplied checkpoint
        here is used.

        Args:
            checkpoint: Starting checkpoint written into the config before the
                launch. Defaults to ``'bert-base-uncased'``.

        Returns:
            The value returned by ``Launcher.launch``. The same value is stored
            on the instance and exposed via :meth:`best_checkpoint_path`.
        """
        # Resources don't come from the config object. They come from the input parameters.
        # TODO: After the API stabilizes, make this "self.config.assign()" to emphasize this distinction.
        self.configure(triples=self.triples, queries=self.queries, collection=self.collection)
        self.configure(checkpoint=checkpoint)

        # The bare name ``train`` resolves to the module-level imported function,
        # not this method: method names are not in scope inside a method body.
        launcher = Launcher(train)
        self._best_checkpoint_path = launcher.launch(self.config, self.triples, self.queries, self.collection)
        return self._best_checkpoint_path

    def best_checkpoint_path(self):
        """Return the value produced by the most recent :meth:`train` call.

        Returns None if :meth:`train` has not completed on this instance.
        """
        return self._best_checkpoint_path