This project is meant to provide a starting point for new TensorFlow projects. It showcases

- `tf.estimator.Estimator`-based training using custom `input_fn` and `model_fn` functions, with standard `tf.estimator.EstimatorSpec` definitions.
- Image files read using `tf.gfile.FastGFile` for source-agnostic, lock-free file loading.
- JPEGs decoded efficiently using `tf.image.decode_and_crop_jpeg`.
- Usage of pretrained models via `tensorflow_hub.Module`.
- `tf.data.Dataset` with `.list_files()` and `.from_generator()` examples.
- Interleaved `TFRecord` input streams using `tf.data.TFRecordDataset` and `tf.contrib.data.parallel_interleave` (see the sketch after this list).
- GPU prefetching using `tf.contrib.data.prefetch_to_device`.
- Automatic snapshotting of the parameters with the best validation loss into a separate directory using a custom `SessionRunHook`.
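As a taste of the input-pipeline features above, here is a minimal sketch of an interleaved `TFRecord` pipeline with GPU prefetching. The feature keys, file pattern, and image size are illustrative assumptions; the real definitions live in the project's input pipeline modules.

```python
import tensorflow as tf

def parse_example(serialized):
    # Hypothetical feature spec; the real one depends on how
    # convert_dataset.py writes the records.
    features = tf.parse_single_example(serialized, {
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/label': tf.FixedLenFeature([], tf.int64),
    })
    # Decode the JPEG bytes on the fly (the project uses
    # tf.image.decode_and_crop_jpeg; plain decoding is shown here).
    image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    image = tf.image.resize_images(image, [224, 224])  # fixed size for batching
    return image, features['image/label']

def input_fn():
    files = tf.data.Dataset.list_files('dataset/train/*.tfrecord')
    # Read several TFRecord shards in parallel, interleaving their records.
    dataset = files.apply(tf.contrib.data.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=4, sloppy=True))
    dataset = dataset.map(parse_example, num_parallel_calls=4)
    dataset = dataset.batch(32)
    # prefetch_to_device should be the final transformation in the pipeline.
    dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/gpu:0'))
    return dataset
```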
Inspirations and sources:
- Importing Data
- Input Pipeline Performance Guide
- Preparing a large-scale image dataset with TensorFlow's TFRecord files
- Getting Text into Tensorflow with the Dataset API
- How to write into and read from a TFRecords file in TensorFlow
- Use HParams and YAML to Better Manage Hyperparameters in Tensorflow
- generator-tf
The code is organized into two directories:

- `project`: project modules such as networks, input pipelines, etc.
- `library`: scripts and boilerplate code
Two configuration files exist:

- `project.yaml`: Serialized command-line options
- `hyperparameters.yaml`: Model hyperparameters
Here's an example `hyperparameters.yaml`, with a default hyper-parameter set (conveniently called `default`) and an additional set named `mobilenet`.
Here, the `mobilenet` set inherits everything from `default` and overrides only the parameters it explicitly redefines.
```yaml
default: &DEFAULT
  # batch_size: 100
  # num_epoch: 1000
  # optimizer: Adam
  learning_rate: 1e-4
  dropout_rate: 0.5
  l2_regularization: 1e-8
  xentropy_label_smoothing: 0.
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_epsilon: 1e-8

mobilenet:
  <<: *DEFAULT
  learning_rate: 1e-5
  fine_tuning: True
```
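A hyper-parameter set defined this way might be loaded into a `tf.contrib.training.HParams` object roughly as follows. This is a sketch under the assumption that the YAML merge key is resolved by PyYAML; `load_hparams` is a hypothetical helper, not the project's actual API.

```python
import yaml
import tensorflow as tf

def load_hparams(path, set_name):
    # Hypothetical helper: read all hyper-parameter sets from the YAML
    # file; the merge key (<<: *DEFAULT) is resolved by PyYAML itself.
    with open(path) as f:
        sets = yaml.safe_load(f)
    return tf.contrib.training.HParams(**sets[set_name])

hparams = load_hparams('hyperparameters.yaml', 'mobilenet')
print(hparams.learning_rate)  # the value from the 'mobilenet' set
```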
Likewise, the `project.yaml` contains serialized command-line parameters:
```yaml
default: &DEFAULT
  train_batch_size: 32
  train_epochs: 1000
  epochs_between_evals: 100
  hyperparameter_file: hyperparameters.yaml
  hyperparameter_set: default
  model: latest
  model_dir: out/current/checkpoints
  best_model_dir: out/current/best

gtx1080ti:
  <<: *DEFAULT
  train_batch_size: 512

thinkpadx201t:
  <<: *DEFAULT
  train_batch_size: 10
  train_epochs: 10
  epochs_between_evals: 1
  random_seed: 0
```
By selecting a configuration set on startup using the `--config_set` command-line option, known-good configurations can be stored and versioned easily. Configuration provided on the command line overrides values defined in `project.yaml`, allowing for quick iteration.
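The override order described above could be implemented roughly like this. The flag names mirror `project.yaml`, but the parser and merge logic are an illustrative assumption, not the project's actual `run.py` code.

```python
import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument('--config_set', default='default')
parser.add_argument('--train_batch_size', type=int)  # one flag per option
args, _ = parser.parse_known_args()

# Start from the selected set in project.yaml ...
with open('project.yaml') as f:
    config = yaml.safe_load(f)[args.config_set]

# ... then let explicitly given command-line values take precedence.
for key, value in vars(args).items():
    if value is not None:
        config[key] = value
```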
In order to run a training session (manually overriding configuration from `project.yaml`), try
```bash
python run.py \
    --xla \
    --epochs_between_evals 1000 \
    --train_epochs 10000 \
    --learning_rate 0.0001
```
In order to improve processing speed later on, the image files are converted to `TFRecord` format first. For this, run
```bash
python convert_dataset.py \
    --dataset_dir dataset/train \
    --tfrecord_filename train \
    --tfrecord_dir dataset/train \
    --max_edge 384

python convert_dataset.py \
    --dataset_dir dataset/test \
    --tfrecord_filename test \
    --tfrecord_dir dataset/test \
    --max_edge 384
```
This example stores image data as JPEG-encoded raw bytes and decodes them on the fly in the input pipeline. While this leads to much smaller TFRecord files compared to storing raw pixel values, it adds noticeable decoding latency; it is a trade-off between storage size and preprocessing time.
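Conceptually, storing the encoded bytes amounts to something like the following sketch. The feature keys and the `image_to_example` helper are illustrative assumptions; the actual schema is defined by `convert_dataset.py`.

```python
import tensorflow as tf

def image_to_example(path, label):
    # Read the JPEG file's raw bytes without decoding them ...
    with tf.gfile.FastGFile(path, 'rb') as f:
        encoded = f.read()
    # ... and wrap them, together with the label, in a tf.train.Example.
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded])),
        'image/label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }))

with tf.python_io.TFRecordWriter('dataset/train/train.tfrecord') as writer:
    writer.write(image_to_example('cat.jpg', label=0).SerializeToString())
```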
In order to use TensorFlow Hub, install it using e.g.

```bash
pip install tensorflow-hub
```

When initializing a Conda environment from `environment.yaml`, this is already taken care of.
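Once installed, a pretrained module can be instantiated as in the following sketch; the module URL (a MobileNet v2 feature extractor) is only an example, and the module the project actually uses may differ.

```python
import tensorflow as tf
import tensorflow_hub as hub

# Instantiate a pretrained image feature extractor from TF Hub
# (example URL; the module used by the project may be a different one).
module = hub.Module(
    'https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/2')
images = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
features = module(images)  # [batch, 1280] feature vectors
```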