@@ -1,10 +1,10 @@
 from os.path import join
 
 import torch
-from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload
+from commode_utils.callbacks import ModelCheckpointWithUploadCallback, PrintEpochResultCallback
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
-from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, RichProgressBar
+from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, TQDMProgressBar
 from pytorch_lightning.loggers import WandbLogger
 
 
@@ -22,7 +22,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
     )
 
     # define model checkpoint callback
-    checkpoint_callback = ModelCheckpointWithUpload(
+    checkpoint_callback = ModelCheckpointWithUploadCallback(
         dirpath=join(wandb_logger.experiment.dir, "checkpoints"),
         filename="{epoch:02d}-val_loss={val/loss:.4f}",
         monitor="val/loss",
@@ -39,7 +39,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
     # define learning rate logger
     lr_logger = LearningRateMonitor("step")
     # define progress bar callback
-    progress_bar = RichProgressBar(refresh_rate_per_second=config.progress_bar_refresh_rate)
+    progress_bar = TQDMProgressBar(refresh_rate=config.progress_bar_refresh_rate)
     trainer = Trainer(
        max_epochs=params.n_epochs,
        gradient_clip_val=params.clip_norm,