Use rich progress bar instead of tqdm #114


Merged · 2 commits · Nov 15, 2021

2 changes: 0 additions & 2 deletions code2seq/code2class_wrapper.py
@@ -27,8 +27,6 @@ def train_code2class(config: DictConfig):

     # Load data module
     data_module = PathContextDataModule(config.data_folder, config.data, is_class=True)
-    data_module.prepare_data()
-    data_module.setup()

     # Load model
     code2class = Code2Class(config.model, config.optimizer, data_module.vocabulary)
2 changes: 0 additions & 2 deletions code2seq/code2seq_wrapper.py
@@ -27,8 +27,6 @@ def train_code2seq(config: DictConfig):

     # Load data module
     data_module = PathContextDataModule(config.data_folder, config.data)
-    data_module.prepare_data()
-    data_module.setup()

     # Load model
     code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
17 changes: 9 additions & 8 deletions code2seq/data/path_context_data_module.py
@@ -18,15 +18,15 @@ class PathContextDataModule(LightningDataModule):
     _val = "val"
     _test = "test"

-    _vocabulary: Optional[Vocabulary] = None
-
     def __init__(self, data_dir: str, config: DictConfig, is_class: bool = False):
         super().__init__()
         self._config = config
         self._data_dir = data_dir
         self._name = basename(data_dir)
         self._is_class = is_class

+        self._vocabulary = self.setup_vocabulary()
+
     @property
     def vocabulary(self) -> Vocabulary:
         if self._vocabulary is None:
@@ -41,14 +41,12 @@ def prepare_data(self):
             raise ValueError(f"Config doesn't contain url for, can't download it automatically")
         download_dataset(self._config.url, self._data_dir, self._name)

-    def setup(self, stage: Optional[str] = None):
-        if not exists(join(self._data_dir, Vocabulary.vocab_filename)):
+    def setup_vocabulary(self) -> Vocabulary:
+        vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename)
+        if not exists(vocabulary_path):
             print("Can't find vocabulary, collect it from train holdout")
             build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), Vocabulary)
-        vocabulary_path = join(self._data_dir, Vocabulary.vocab_filename)
-        self._vocabulary = Vocabulary(
-            vocabulary_path, self._config.labels_count, self._config.tokens_count, self._is_class
-        )
+        return Vocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count, self._is_class)

     @staticmethod
     def collate_wrapper(batch: List[Optional[LabeledPathContext]]) -> BatchedLabeledPathContext:
@@ -88,6 +86,9 @@ def val_dataloader(self, *args, **kwargs) -> DataLoader:
     def test_dataloader(self, *args, **kwargs) -> DataLoader:
         return self._shared_dataloader(self._test)

+    def predict_dataloader(self, *args, **kwargs) -> DataLoader:
+        return self.test_dataloader(*args, **kwargs)
+
     def transfer_batch_to_device(
         self, batch: BatchedLabeledPathContext, device: torch.device, dataloader_idx: int
     ) -> BatchedLabeledPathContext:
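Note on the change above: because the vocabulary is now built eagerly in __init__ via setup_vocabulary(), callers can read data_module.vocabulary right after constructing the data module, which is why the training wrappers in this PR drop their explicit prepare_data()/setup() calls (prepare_data(), which downloads the dataset, is still invoked by the Trainer during fit). A minimal sketch of how an entry point can use the module after this change; the build_model helper is an illustrative assumption, not part of the diff:

from code2seq.data.path_context_data_module import PathContextDataModule
from code2seq.model.code2seq import Code2Seq

def build_model(config):
    # Vocabulary is available immediately after construction, no setup() call needed.
    data_module = PathContextDataModule(config.data_folder, config.data)
    model = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
    return model, data_module
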
6 changes: 3 additions & 3 deletions code2seq/data/typed_path_context_data_module.py
@@ -11,7 +11,7 @@


 class TypedPathContextDataModule(PathContextDataModule):
-    _vocabulary: Optional[TypedVocabulary] = None
+    _vocabulary: TypedVocabulary

     def __init__(self, data_dir: str, config: DictConfig):
         super().__init__(data_dir, config)
@@ -27,12 +27,12 @@ def _create_dataset(self, holdout_file: str, random_context: bool) -> TypedPathContextDataset:
             raise RuntimeError(f"Setup vocabulary before creating data loaders")
         return TypedPathContextDataset(holdout_file, self._config, self._vocabulary, random_context)

-    def setup(self, stage: Optional[str] = None):
+    def setup_vocabulary(self) -> TypedVocabulary:
         if not exists(join(self._data_dir, TypedVocabulary.vocab_filename)):
             print("Can't find vocabulary, collect it from train holdout")
             build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), TypedVocabulary)
         vocabulary_path = join(self._data_dir, TypedVocabulary.vocab_filename)
-        self._vocabulary = TypedVocabulary(
+        return TypedVocabulary(
             vocabulary_path, self._config.labels_count, self._config.tokens_count, self._config.types_count
         )

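The two data-module diffs above follow a template-method pattern: the base class builds and stores the vocabulary once in __init__, and subclasses only override setup_vocabulary() to return a more specific vocabulary type. A toy illustration with hypothetical names (not code from the repository):

class BaseDataModule:
    def __init__(self, data_dir: str):
        self._data_dir = data_dir
        # Subclasses customize this hook; the base __init__ stores the result.
        self._vocabulary = self.setup_vocabulary()

    def setup_vocabulary(self) -> str:
        return f"vocabulary built from {self._data_dir}"


class TypedDataModule(BaseDataModule):
    def setup_vocabulary(self) -> str:
        return f"typed vocabulary built from {self._data_dir}"


print(TypedDataModule("data/java-test")._vocabulary)  # typed vocabulary built from data/java-test
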
22 changes: 15 additions & 7 deletions code2seq/model/code2seq.py
@@ -3,6 +3,7 @@
 import torch
 from commode_utils.losses import SequenceCrossEntropyLoss
 from commode_utils.metrics import SequentialF1Score, ClassificationMetrics
+from commode_utils.metrics.chrF import ChrF
 from commode_utils.modules import LSTMDecoderStep, Decoder
 from omegaconf import DictConfig
 from pytorch_lightning import LightningModule
@@ -41,6 +42,10 @@ def __init__(
             f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx)
             for holdout in ["train", "val", "test"]
         }
+        id2label = {v: k for k, v in vocabulary.label_to_id.items()}
+        metrics.update(
+            {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]}
+        )
         self.__metrics = MetricCollection(metrics)

         self._encoder = self._get_encoder(model_config)
@@ -102,18 +107,18 @@ def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict:
         target_sequence = batch.labels if step == "train" else None
         # [seq length; batch size; vocab size]
         logits, _ = self.logits_from_batch(batch, target_sequence)
-        loss = self.__loss(logits[1:], batch.labels[1:])
+        result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])}

         with torch.no_grad():
             prediction = logits.argmax(-1)
             metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels)
+            result.update(
+                {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall}
+            )
+            if step != "train":
+                result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels)

-        return {
-            f"{step}/loss": loss,
-            f"{step}/f1": metric.f1_score,
-            f"{step}/precision": metric.precision,
-            f"{step}/recall": metric.recall,
-        }
+        return result

     def training_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict:  # type: ignore
         result = self._shared_step(batch, "train")
@@ -143,6 +148,9 @@ def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str):
                 f"{step}/recall": metric.recall,
             }
             self.__metrics[f"{step}_f1"].reset()
+            if step != "train":
+                log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute()
+                self.__metrics[f"{step}_chrf"].reset()
         self.log_dict(log, on_step=False, on_epoch=True)

     def training_epoch_end(self, step_outputs: EPOCH_OUTPUT):
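The chrF metric introduced above appears to follow the stateful torchmetrics protocol, judging by how the diff uses it: calling the metric object on a batch updates its internal state and returns the batch-level score (logged from _shared_step), while compute() aggregates over the whole epoch and reset() clears the state in _shared_epoch_end. A minimal sketch of that lifecycle with a hypothetical custom metric (torchmetrics is already a project dependency; MeanValue is not part of the PR):

import torch
from torchmetrics import Metric


class MeanValue(Metric):
    # Tracks a running mean across batches, mirroring the update/compute/reset lifecycle.
    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, values: torch.Tensor):
        self.total += values.sum()
        self.count += values.numel()

    def compute(self) -> torch.Tensor:
        return self.total / self.count


metric = MeanValue()
metric(torch.tensor([1.0, 2.0]))  # per-batch call: updates state, returns the batch value
metric(torch.tensor([3.0]))
print(metric.compute())           # epoch-level aggregate: tensor(2.)
metric.reset()
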
2 changes: 0 additions & 2 deletions code2seq/typed_code2seq_wrapper.py
@@ -27,8 +27,6 @@ def train_typed_code2seq(config: DictConfig):

     # Load data module
     data_module = TypedPathContextDataModule(config.data_folder, config.data)
-    data_module.prepare_data()
-    data_module.setup()

     # Load model
     typed_code2seq = TypedCode2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing)
2 changes: 1 addition & 1 deletion code2seq/utils/common.py
@@ -4,7 +4,7 @@
 def filter_warnings():
     # "The dataloader does not have many workers which may be a bottleneck."
     filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.utilities.distributed", lineno=50)
-    filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.data_loading", lineno=105)
+    filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.data_loading", lineno=110)
     # "Please also save or load the state of the optimizer when saving or loading the scheduler."
     filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=216)  # save
     filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=234)  # load
18 changes: 8 additions & 10 deletions code2seq/utils/train.py
@@ -1,8 +1,10 @@
+from os.path import join
+
 import torch
 from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
-from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
+from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, RichProgressBar
 from pytorch_lightning.loggers import WandbLogger


@@ -21,7 +23,7 @@ def train(model: LightningModule, data_module: LightningDataModule, config: DictConfig):

     # define model checkpoint callback
     checkpoint_callback = ModelCheckpointWithUpload(
-        dirpath=wandb_logger.experiment.dir,
+        dirpath=join(wandb_logger.experiment.dir, "checkpoints"),
         filename="{epoch:02d}-val_loss={val/loss:.4f}",
         monitor="val/loss",
         every_n_epochs=params.save_every_epoch,
@@ -36,6 +38,8 @@ def train(model: LightningModule, data_module: LightningDataModule, config: DictConfig):
     gpu = 1 if torch.cuda.is_available() else None
     # define learning rate logger
     lr_logger = LearningRateMonitor("step")
+    # define progress bar callback
+    progress_bar = RichProgressBar(refresh_rate_per_second=config.progress_bar_refresh_rate)
     trainer = Trainer(
         max_epochs=params.n_epochs,
         gradient_clip_val=params.clip_norm,
@@ -44,15 +48,9 @@ def train(model: LightningModule, data_module: LightningDataModule, config: DictConfig):
         log_every_n_steps=params.log_every_n_steps,
         logger=wandb_logger,
         gpus=gpu,
-        progress_bar_refresh_rate=config.progress_bar_refresh_rate,
-        callbacks=[
-            lr_logger,
-            early_stopping_callback,
-            checkpoint_callback,
-            print_epoch_result_callback,
-        ],
+        callbacks=[lr_logger, early_stopping_callback, checkpoint_callback, print_epoch_result_callback, progress_bar],
         resume_from_checkpoint=config.get("checkpoint", None),
     )

     trainer.fit(model=model, datamodule=data_module)
-    trainer.test()
+    trainer.test(datamodule=data_module, ckpt_path="best")
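
The train.py diff above moves the refresh rate from the Trainer's progress_bar_refresh_rate argument into a RichProgressBar callback, which requires the rich package at runtime, and passes ckpt_path="best" to trainer.test so the final evaluation reloads the best checkpoint tracked by the checkpoint callback instead of testing the last in-memory weights. A standalone sketch of the same wiring, with MyModel and MyDataModule as placeholder classes and PyTorch Lightning ~=1.5 assumed:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar

progress_bar = RichProgressBar(refresh_rate_per_second=1)
trainer = Trainer(max_epochs=1, callbacks=[progress_bar])
# trainer.fit(MyModel(), datamodule=MyDataModule())
# trainer.test(datamodule=MyDataModule(), ckpt_path="best")
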
2 changes: 1 addition & 1 deletion config/code2seq-java-med.yaml
@@ -28,7 +28,7 @@ data:
   random_context: true

   batch_size: 512
-  test_batch_size: 768
+  test_batch_size: 512

 model:
   # Encoder
1 change: 0 additions & 1 deletion config/code2seq-java-test.yaml
@@ -3,7 +3,6 @@ data_folder: ../data/code2seq/java-test
 checkpoint: null

 seed: 7
-# Training in notebooks (e.g. Google Colab) may crash with too small value
 progress_bar_refresh_rate: 1
 print_config: true

6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,7 +1,7 @@
 torch==1.10.0
-pytorch-lightning==1.4.9
-torchmetrics==0.5.1
+pytorch-lightning==1.5.1
+torchmetrics==0.6.0
 tqdm==4.62.3
 wandb==0.12.6
 omegaconf==2.1.1
-commode-utils==0.3.12
+commode-utils==0.4.0
10 changes: 4 additions & 6 deletions setup.py
@@ -1,18 +1,16 @@
 from setuptools import setup, find_packages

-VERSION = "1.1.1"
+VERSION = "1.2.0"

 with open("README.md") as readme_file:
     readme = readme_file.read()

 install_requires = [
-    "torch>=1.9.0",
-    "pytorch-lightning~=1.4.2",
-    "torchmetrics~=0.5.0",
-    "tqdm~=4.62.1",
+    "torch>=1.10.0",
+    "pytorch-lightning~=1.5.0",
     "wandb~=0.12.0",
     "omegaconf~=2.1.1",
-    "commode-utils>=0.3.8",
+    "commode-utils>=0.4.0",
 ]

 setup_args = dict(