Skip to content

Commit b277662

Browse files
committed
Update to last version of commode-utils
1 parent 76cca38 commit b277662

File tree

3 files changed

+6
-8
lines changed

3 files changed

+6
-8
lines changed

code2seq/utils/train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from os.path import join
22

33
import torch
4-
from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload
4+
from commode_utils.callbacks import ModelCheckpointWithUploadCallback, PrintEpochResultCallback
55
from omegaconf import DictConfig, OmegaConf
66
from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
7-
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, RichProgressBar
7+
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, TQDMProgressBar
88
from pytorch_lightning.loggers import WandbLogger
99

1010

@@ -22,7 +22,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
2222
)
2323

2424
# define model checkpoint callback
25-
checkpoint_callback = ModelCheckpointWithUpload(
25+
checkpoint_callback = ModelCheckpointWithUploadCallback(
2626
dirpath=join(wandb_logger.experiment.dir, "checkpoints"),
2727
filename="{epoch:02d}-val_loss={val/loss:.4f}",
2828
monitor="val/loss",
@@ -39,7 +39,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
3939
# define learning rate logger
4040
lr_logger = LearningRateMonitor("step")
4141
# define progress bar callback
42-
progress_bar = RichProgressBar(refresh_rate_per_second=config.progress_bar_refresh_rate)
42+
progress_bar = TQDMProgressBar(refresh_rate=config.progress_bar_refresh_rate)
4343
trainer = Trainer(
4444
max_epochs=params.n_epochs,
4545
gradient_clip_val=params.clip_norm,

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,4 @@ torchmetrics==0.6.0
44
tqdm==4.62.3
55
wandb==0.12.6
66
omegaconf==2.1.1
7-
commode-utils==0.4.0
8-
rich==10.13.0
7+
commode-utils==0.4.1

setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
"pytorch-lightning~=1.5.0",
1111
"wandb~=0.12.0",
1212
"omegaconf~=2.1.1",
13-
"commode-utils>=0.4.0",
14-
"rich>=10.0.0",
13+
"commode-utils>=0.4.1",
1514
]
1615

1716
setup_args = dict(

0 commit comments

Comments
 (0)