Skip to content

Commit

Permalink
Add pytorch 1.1 utils.tensorboard support.
Browse files Browse the repository at this point in the history
  • Loading branch information
christopherbate committed Jun 1, 2019
1 parent 83d9038 commit 7857fd3
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 25 deletions.
48 changes: 39 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ PyTorch deep learning project made easy.
* [Additional logging](#additional-logging)
* [Validation data](#validation-data)
* [Checkpoints](#checkpoints)
* [TensorboardX Visualization](#tensorboardx-visualization)
* [Tensorboard Visualization](#tensorboard-visualization)
* [Contributing](#contributing)
* [TODOs](#todos)
* [License](#license)
Expand All @@ -36,8 +36,8 @@ PyTorch deep learning project made easy.
* Python >= 3.5 (3.6 recommended)
* PyTorch >= 0.4
* tqdm (Optional for `test.py`)
* tensorboard >= 1.7.0 (Optional for TensorboardX)
* tensorboardX >= 1.2 (Optional for TensorboardX)
* tensorboard >= 1.7.0 (Optional for TensorboardX) or tensorboard >= 1.14 (Optional for pytorch.utils.tensorboard)
* tensorboardX >= 1.2 (Optional for TensorboardX), see [Tensorboard Visualization][#tensorboardx-visualization]

## Features
* Clear folder structure which is suitable for many deep learning projects.
Expand Down Expand Up @@ -329,8 +329,11 @@ A copy of config file will be saved in the same folder.
}
```

### TensorboardX Visualization
This template supports [TensorboardX](https://github.com/lanpa/tensorboardX) visualization.
### Tensorboard Visualization
This template supports Tensorboard visualization using either Pytorch 1.1's `torch.utils.tensorboard` capabilities or [TensorboardX](https://github.com/lanpa/tensorboardX).

The template attempts to choose a writing module from a list of modules specified in the config file under "tensorboard.modules". It load the modules in the order specified, only moving on to the next one if the previous one failed.

* **TensorboardX Usage**

1. **Install**
Expand All @@ -339,17 +342,44 @@ This template supports [TensorboardX](https://github.com/lanpa/tensorboardX) vis

2. **Run training**

Set `tensorboardX` option in config file true.
Set `tensorboard` option in config file to:
Set the "tensorboard" entry in the config to:
```
"tensorboard" :{
"enabled": true,
"modules": ["tensorboardX", "torch.utils.tensorboard"]
}
```
3. **Open tensorboard server**
3. **Open Tensorboard server**
Type `tensorboard --logdir saved/log/` at the project root, then server will open at `http://localhost:6006`
* **Pytorch 1.1 torch.utils.tensorboard Usage**
1. **Install**
Must have Pytorch 1.1 installed and `tensorboard >= 1.14` (`pip install tb-nightly`).
2. **Run training**
Set the "tensorboard" entry in the config to:
```
"tensorboard" :{
"enabled": true,
"modules": ["torch.utils.tensorboard", "tensorboardX"]
}
```
3. **Open Tensorboard server**
Same as above.
By default, values of loss and metrics specified in config file, input images, and histogram of model parameters will be logged.
If you need more visualizations, use `add_scalar('tag', data)`, `add_image('tag', image)`, etc in the `trainer._train_epoch` method.
`add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` module.
`add_something()` methods in this template are basically wrappers for those of `tensorboardX.SummaryWriter` and `torch.utils.tensorboard.SummaryWriter` modules.
**Note**: You don't have to specify current steps, since `WriterTensorboardX` class defined at `logger/visualization.py` will track current steps.
**Note**: You don't have to specify current steps, since `WriterTensorboard` class defined at `logger/visualization.py` will track current steps.
## Contributing
Feel free to contribute any kind of function or enhancement, here the coding style follows PEP8
Expand Down
7 changes: 4 additions & 3 deletions base/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from abc import abstractmethod
from numpy import inf
from logger import WriterTensorboardX
from logger import TensorboardWriter


class BaseTrainer:
Expand Down Expand Up @@ -41,8 +41,9 @@ def __init__(self, model, loss, metrics, optimizer, config):
self.start_epoch = 1

self.checkpoint_dir = config.save_dir
# setup visualization writer instance
self.writer = WriterTensorboardX(config.log_dir, self.logger, cfg_trainer['tensorboardX'])

# setup visualization writer instance
self.writer = TensorboardWriter(config.log_dir, self.logger, config['tensorboard'])

if config.resume is not None:
self._resume_checkpoint(config.resume)
Expand Down
8 changes: 5 additions & 3 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
"verbosity": 2,

"monitor": "min val_loss",
"early_stop": 10,

"tensorboardX": true
"early_stop": 10
},
"tensorboard" :{
"enabled": true,
"modules": ["tensorboardX", "torch.utils.tensorboard"]
}
}
43 changes: 33 additions & 10 deletions logger/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,36 @@
from utils import Timer


class WriterTensorboardX():
def __init__(self, log_dir, logger, enable):
class TensorboardWriter():
def __init__(self, log_dir, logger, config):
self.writer = None
if enable:

self.viz_methods = ["pytorch_tensorboard", "tensorboardX"]
self.selected_module = ""

if config["enabled"]:
log_dir = str(log_dir)
try:
self.writer = importlib.import_module('tensorboardX').SummaryWriter(log_dir)
except ImportError:
message = "Warning: TensorboardX visualization is configured to use, but currently not installed on " \
"this machine. Please install the package by 'pip install tensorboardx' command or turn " \
"off the option in the 'config.json' file."

# Try to find a vizualization writer.
succeeded = False
for module in config["modules"]:
try:
self.writer = importlib.import_module(module).SummaryWriter(log_dir)
succeeded = True
self.selected_module = module
logger.info("Selected Tensorboard writer {}".format(module))
break
except ImportError:
logger.warning("{} failed to load.".format(module))
succeeded = False

if (not succeeded):
message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
"this machine. Please install either TensorboardX with 'pip install tensorboardx', " \
"install PyTorch 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
"the 'config.json' file."
logger.warning(message)

self.step = 0
self.mode = ''

Expand All @@ -22,6 +40,11 @@ def __init__(self, log_dir, logger, enable):
'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
]
self.tag_mode_exceptions = ['add_histogram', 'add_embedding']

if(self.selected_module == "pytorch.utils.tensorboard"):
self.tb_writer_ftns = self.tb_writer_ftns + self.tag_mode_exceptions
self.tag_mode_exceptions = []

self.timer = Timer()

def set_step(self, step, mode='train'):
Expand Down Expand Up @@ -55,5 +78,5 @@ def wrapper(tag, data, *args, **kwargs):
try:
attr = object.__getattr__(name)
except AttributeError:
raise AttributeError("type object 'WriterTensorboardX' has no attribute '{}'".format(name))
raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
return attr

0 comments on commit 7857fd3

Please sign in to comment.