Skip to content

Commit

Permalink
Merge pull request victoresque#55 from christopherbate/master
Browse files Browse the repository at this point in the history
Add pytorch 1.1 utils.tensorboard support.
  • Loading branch information
SunQpark authored Jun 11, 2019
2 parents 83d9038 + c04075a commit 41505c2
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 27 deletions.
48 changes: 39 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ PyTorch deep learning project made easy.
* [Additional logging](#additional-logging)
* [Validation data](#validation-data)
* [Checkpoints](#checkpoints)
* [TensorboardX Visualization](#tensorboardx-visualization)
* [Tensorboard Visualization](#tensorboard-visualization)
* [Contributing](#contributing)
* [TODOs](#todos)
* [License](#license)
Expand All @@ -36,8 +36,8 @@ PyTorch deep learning project made easy.
* Python >= 3.5 (3.6 recommended)
* PyTorch >= 0.4
* tqdm (Optional for `test.py`)
* tensorboard >= 1.7.0 (Optional for TensorboardX)
* tensorboardX >= 1.2 (Optional for TensorboardX)
* tensorboard >= 1.7.0 (Optional for TensorboardX) or tensorboard >= 1.14 (Optional for pytorch.utils.tensorboard)
* tensorboardX >= 1.2 (Optional for TensorboardX), see [Tensorboard Visualization][#tensorboardx-visualization]

## Features
* Clear folder structure which is suitable for many deep learning projects.
Expand Down Expand Up @@ -329,8 +329,11 @@ A copy of config file will be saved in the same folder.
}
```

### TensorboardX Visualization
This template supports [TensorboardX](https://github.com/lanpa/tensorboardX) visualization.
### Tensorboard Visualization
This template supports Tensorboard visualization using either Pytorch 1.1's `torch.utils.tensorboard` capabilities or [TensorboardX](https://github.com/lanpa/tensorboardX).

The template attempts to choose a writing module from a list of modules specified in the config file under "tensorboard.modules". It load the modules in the order specified, only moving on to the next one if the previous one failed.

* **TensorboardX Usage**

1. **Install**
Expand All @@ -339,17 +342,44 @@ This template supports [TensorboardX](https://github.com/lanpa/tensorboardX) vis

2. **Run training**

Set `tensorboardX` option in config file true.
Set `tensorboard` option in config file to:
Set the "tensorboard" entry in the config to:
```
"tensorboard" :{
"enabled": true,
"modules": ["tensorboardX", "torch.utils.tensorboard"]
}
```
3. **Open tensorboard server**
3. **Open Tensorboard server**
Type `tensorboard --logdir saved/log/` at the project root, then server will open at `http://localhost:6006`
* **Pytorch 1.1 torch.utils.tensorboard Usage**
1. **Install**
Must have Pytorch 1.1 installed and `tensorboard >= 1.14` (`pip install tb-nightly`).
2. **Run training**
Set the "tensorboard" entry in the config to:
```
"tensorboard" :{
"enabled": true,
"modules": ["torch.utils.tensorboard", "tensorboardX"]
}
```
3. **Open Tensorboard server**
Same as above.
By default, values of loss and metrics specified in config file, input images, and histogram of model parameters will be logged.
If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc in the `trainer._train_epoch` method.
`add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` module.
`add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules.
**Note**: You don't have to specify current steps, since `WriterTensorboardX` class defined at `logger/visualization.py` will track current steps.
**Note**: You don't have to specify current steps, since `WriterTensorboard` class defined at `logger/visualization.py` will track current steps.
## Contributing
Feel free to contribute any kind of function or enhancement, here the coding style follows PEP8
Expand Down
7 changes: 4 additions & 3 deletions base/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from abc import abstractmethod
from numpy import inf
from logger import WriterTensorboardX
from logger import TensorboardWriter


class BaseTrainer:
Expand Down Expand Up @@ -41,8 +41,9 @@ def __init__(self, model, loss, metrics, optimizer, config):
self.start_epoch = 1

self.checkpoint_dir = config.save_dir
# setup visualization writer instance
self.writer = WriterTensorboardX(config.log_dir, self.logger, cfg_trainer['tensorboardX'])

# setup visualization writer instance
self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])

if config.resume is not None:
self._resume_checkpoint(config.resume)
Expand Down
4 changes: 2 additions & 2 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

"monitor": "min val_loss",
"early_stop": 10,
"tensorboardX": true

"tensorboard": true
}
}
41 changes: 28 additions & 13 deletions logger/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,41 @@
from utils import Timer


class WriterTensorboardX():
def __init__(self, log_dir, logger, enable):
class TensorboardWriter():
def __init__(self, log_dir, logger, enabled):
self.writer = None
if enable:
self.selected_module = ""

if enabled:
log_dir = str(log_dir)
try:
self.writer = importlib.import_module('tensorboardX').SummaryWriter(log_dir)
except ImportError:
message = "Warning: TensorboardX visualization is configured to use, but currently not installed on " \
"this machine. Please install the package by 'pip install tensorboardx' command or turn " \
"off the option in the 'config.json' file."

# Retrieve vizualization writer.
succeeded = False
for module in ["torch.utils.tensorboard", "tensorboardX"]:
try:
self.writer = importlib.import_module(module).SummaryWriter(log_dir)
succeeded = True
break
except ImportError:
succeeded = False
self.selected_module = module

if not succeeded:
message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
"this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
"PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
"the 'config.json' file."
logger.warning(message)

self.step = 0
self.mode = ''

self.tb_writer_ftns = [
self.tb_writer_ftns = {
'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
]
self.tag_mode_exceptions = ['add_histogram', 'add_embedding']
}
self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}

self.timer = Timer()

def set_step(self, step, mode='train'):
Expand Down Expand Up @@ -55,5 +70,5 @@ def wrapper(tag, data, *args, **kwargs):
try:
attr = object.__getattr__(name)
except AttributeError:
raise AttributeError("type object 'WriterTensorboardX' has no attribute '{}'".format(name))
raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
return attr

0 comments on commit 41505c2

Please sign in to comment.