diff --git a/README.md b/README.md index 17300730..15c4ac9a 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ PyTorch deep learning project made easy. * [Additional logging](#additional-logging) * [Validation data](#validation-data) * [Checkpoints](#checkpoints) - * [TensorboardX Visualization](#tensorboardx-visualization) + * [Tensorboard Visualization](#tensorboard-visualization) * [Contributing](#contributing) * [TODOs](#todos) * [License](#license) @@ -36,8 +36,8 @@ PyTorch deep learning project made easy. * Python >= 3.5 (3.6 recommended) * PyTorch >= 0.4 * tqdm (Optional for `test.py`) -* tensorboard >= 1.7.0 (Optional for TensorboardX) -* tensorboardX >= 1.2 (Optional for TensorboardX) +* tensorboard >= 1.7.0 (Optional for TensorboardX) or tensorboard >= 1.14 (Optional for pytorch.utils.tensorboard) +* tensorboardX >= 1.2 (Optional for TensorboardX), see [Tensorboard Visualization][#tensorboardx-visualization] ## Features * Clear folder structure which is suitable for many deep learning projects. @@ -329,8 +329,11 @@ A copy of config file will be saved in the same folder. } ``` -### TensorboardX Visualization -This template supports [TensorboardX](https://github.com/lanpa/tensorboardX) visualization. +### Tensorboard Visualization +This template supports Tensorboard visualization using either Pytorch 1.1's `torch.utils.tensorboard` capabilities or [TensorboardX](https://github.com/lanpa/tensorboardX). + +The template attempts to choose a writing module from a list of modules specified in the config file under "tensorboard.modules". It load the modules in the order specified, only moving on to the next one if the previous one failed. + * **TensorboardX Usage** 1. **Install** @@ -339,17 +342,44 @@ This template supports [TensorboardX](https://github.com/lanpa/tensorboardX) vis 2. **Run training** - Set `tensorboardX` option in config file true. + Set `tensorboard` option in config file to: + Set the "tensorboard" entry in the config to: + ``` + "tensorboard" :{ + "enabled": true, + "modules": ["tensorboardX", "torch.utils.tensorboard"] + } + ``` -3. **Open tensorboard server** +3. **Open Tensorboard server** Type `tensorboard --logdir saved/log/` at the project root, then server will open at `http://localhost:6006` +* **Pytorch 1.1 torch.utils.tensorboard Usage** + +1. **Install** + + Must have Pytorch 1.1 installed and `tensorboard >= 1.14` (`pip install tb-nightly`). + +2. **Run training** + + Set the "tensorboard" entry in the config to: + ``` + "tensorboard" :{ + "enabled": true, + "modules": ["torch.utils.tensorboard", "tensorboardX"] + } + ``` + +3. **Open Tensorboard server** + + Same as above. + By default, values of loss and metrics specified in config file, input images, and histogram of model parameters will be logged. If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc in the `trainer._train_epoch` method. -`add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` module. +`add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules. -**Note**: You don't have to specify current steps, since `WriterTensorboardX` class defined at `logger/visualization.py` will track current steps. +**Note**: You don't have to specify current steps, since `WriterTensorboard` class defined at `logger/visualization.py` will track current steps. ## Contributing Feel free to contribute any kind of function or enhancement, here the coding style follows PEP8 diff --git a/base/base_trainer.py b/base/base_trainer.py index 85fba032..1e249ba3 100644 --- a/base/base_trainer.py +++ b/base/base_trainer.py @@ -1,7 +1,7 @@ import torch from abc import abstractmethod from numpy import inf -from logger import WriterTensorboardX +from logger import TensorboardWriter class BaseTrainer: @@ -41,8 +41,9 @@ def __init__(self, model, loss, metrics, optimizer, config): self.start_epoch = 1 self.checkpoint_dir = config.save_dir - # setup visualization writer instance - self.writer = WriterTensorboardX(config.log_dir, self.logger, cfg_trainer['tensorboardX']) + + # setup visualization writer instance + self.writer = TensorboardWriter(config.log_dir, self.logger, config['tensorboard']) if config.resume is not None: self._resume_checkpoint(config.resume) diff --git a/config.json b/config.json index 4dd7be3b..fb00bf09 100644 --- a/config.json +++ b/config.json @@ -43,8 +43,10 @@ "verbosity": 2, "monitor": "min val_loss", - "early_stop": 10, - - "tensorboardX": true + "early_stop": 10 + }, + "tensorboard" :{ + "enabled": true, + "modules": ["tensorboardX", "torch.utils.tensorboard"] } } diff --git a/logger/visualization.py b/logger/visualization.py index 94f45f1d..b3630d3d 100644 --- a/logger/visualization.py +++ b/logger/visualization.py @@ -2,18 +2,36 @@ from utils import Timer -class WriterTensorboardX(): - def __init__(self, log_dir, logger, enable): +class TensorboardWriter(): + def __init__(self, log_dir, logger, config): self.writer = None - if enable: + + self.viz_methods = ["pytorch_tensorboard", "tensorboardX"] + self.selected_module = "" + + if config["enabled"]: log_dir = str(log_dir) - try: - self.writer = importlib.import_module('tensorboardX').SummaryWriter(log_dir) - except ImportError: - message = "Warning: TensorboardX visualization is configured to use, but currently not installed on " \ - "this machine. Please install the package by 'pip install tensorboardx' command or turn " \ - "off the option in the 'config.json' file." + + # Try to find a vizualization writer. + succeeded = False + for module in config["modules"]: + try: + self.writer = importlib.import_module(module).SummaryWriter(log_dir) + succeeded = True + self.selected_module = module + logger.info("Selected Tensorboard writer {}".format(module)) + break + except ImportError: + logger.warning("{} failed to load.".format(module)) + succeeded = False + + if (not succeeded): + message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ + "this machine. Please install either TensorboardX with 'pip install tensorboardx', " \ + "install PyTorch 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \ + "the 'config.json' file." logger.warning(message) + self.step = 0 self.mode = '' @@ -22,6 +40,11 @@ def __init__(self, log_dir, logger, enable): 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' ] self.tag_mode_exceptions = ['add_histogram', 'add_embedding'] + + if(self.selected_module == "pytorch.utils.tensorboard"): + self.tb_writer_ftns = self.tb_writer_ftns + self.tag_mode_exceptions + self.tag_mode_exceptions = [] + self.timer = Timer() def set_step(self, step, mode='train'): @@ -55,5 +78,5 @@ def wrapper(tag, data, *args, **kwargs): try: attr = object.__getattr__(name) except AttributeError: - raise AttributeError("type object 'WriterTensorboardX' has no attribute '{}'".format(name)) + raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) return attr