-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
56 lines (44 loc) · 1.52 KB
/
train.py
File metadata and controls
56 lines (44 loc) · 1.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import yaml
import torch
from loguru import logger
from dataset import DataModule
from lightning_module import LightningModule
from utils.util import increment_path, setup_logger
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
def main(config_path: str = 'config.yaml'):
    """Train a model end to end from a YAML configuration file.

    Loads the config, creates a fresh experiment directory, configures file
    logging, seeds the RNGs, builds the data and Lightning modules, and runs
    ``Trainer.fit`` with early stopping on ``val_loss``.

    Args:
        config_path: Path to the YAML config. Keys read directly here:
            ``path.experiment`` and ``training.{seed, max_patience,
            num_epochs, max_grad_norm, grad_acc_steps}``; optionally
            ``training.log_every_n_steps`` (defaults to 10). ``DataModule``
            and ``LightningModule`` receive the whole config and may read
            further keys.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the YAML document is empty or not a mapping.
        KeyError: If a required config key is missing.
    """
    # Use a context manager so the config file handle is closed promptly
    # instead of being leaked by a bare open().
    with open(config_path, encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty document; fail with a clear message
    # rather than a later "NoneType is not subscriptable" TypeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(config).__name__}"
        )

    # logging
    # NOTE(review): increment_path presumably returns an unused variant of the
    # experiment path (exp, exp2, ...) — confirm. The directory is created
    # explicitly so the log file below does not depend on increment_path
    # having created it; exist_ok makes this a no-op if it already exists.
    exp_dir = increment_path(config['path']['experiment'])
    os.makedirs(exp_dir, exist_ok=True)
    setup_logger(os.path.join(exp_dir, 'train.log'))
    logger.info(f"Experiment path: {exp_dir}")

    training_cfg = config['training']

    # seed
    seed_everything(training_cfg['seed'])

    # dataset
    data_module = DataModule(config)
    model = LightningModule(config, exp_dir)

    # early stopping
    # Assumes the LightningModule logs a metric named 'val_loss' — verify.
    # With val_check_interval=1.0 below, patience counts epochs.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=training_cfg['max_patience'],
        mode='min',
        verbose=True,
    )

    trainer = Trainer(
        # Lightning's built-in experiment logger is disabled; this script
        # logs through loguru instead.
        logger=False,
        max_epochs=training_cfg['num_epochs'],
        gradient_clip_val=training_cfg['max_grad_norm'],
        accumulate_grad_batches=training_cfg['grad_acc_steps'],
        log_every_n_steps=training_cfg.get('log_every_n_steps', 10),
        # 1.0 == run validation once per full training epoch.
        val_check_interval=1.0,
        default_root_dir=exp_dir,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        enable_progress_bar=False,
        callbacks=[early_stop_callback],
    )
    trainer.fit(model, datamodule=data_module)
    logger.info("Training complete")
# Script entry point: train using the default 'config.yaml' in the current
# working directory. Importing this module does not start training.
if __name__ == '__main__':
    main()