
TensorFlow Project Scaffold

This project is meant to provide a starting point for new TensorFlow projects. It showcases, among other things, a basic project structure, YAML-based configuration and hyper-parameter handling, conversion of datasets to TFRecord files with a matching input pipeline, and the use of TensorFlow Hub.

Inspirations and sources:

Structure of the project

  • project: project modules such as networks, input pipelines, etc.
  • library: scripts and boilerplate code

Two configuration files exist:

  • project.yaml: Serialized command-line options
  • hyperparameters.yaml: Model hyperparameters

Here's an example hyperparameters.yaml with a default hyper-parameter set (conveniently called default) and an additional set named mobilenet. The mobilenet set inherits from default and overrides only the parameters it explicitly redefines.

default: &DEFAULT
  # batch_size: 100
  # num_epoch: 1000
  # optimizer: Adam
  learning_rate: 1e-4
  dropout_rate: 0.5
  l2_regularization: 1e-8
  xentropy_label_smoothing: 0.
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_epsilon: 1e-8

mobilenet:
  <<: *DEFAULT
  learning_rate: 1e-5
  fine_tuning: True
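
Such a hyper-parameter set can be loaded with any YAML parser that understands merge keys. The following is a minimal sketch using PyYAML; the function name and error handling are illustrative and not part of this repository:

import yaml

def load_hyperparameters(path: str = 'hyperparameters.yaml',
                         set_name: str = 'default') -> dict:
    """Load one named hyper-parameter set from a YAML file.

    PyYAML resolves the `<<: *DEFAULT` merge key while loading, so the
    returned dictionary already contains the inherited defaults plus
    any values the selected set overrides.
    """
    with open(path, 'r') as f:
        sets = yaml.safe_load(f)
    if set_name not in sets:
        raise KeyError(f'Unknown hyper-parameter set: {set_name!r}')
    return sets[set_name]

params = load_hyperparameters(set_name='mobilenet')
# Note: some YAML 1.1 loaders parse values such as 1e-5 as strings,
# so an explicit float() cast may be required.
print(params['learning_rate'])  # 1e-5 from the mobilenet set, not the default 1e-4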

Likewise, the project.yaml contains serialized command-line parameters:

default: &DEFAULT
  train_batch_size: 32
  train_epochs: 1000
  epochs_between_evals: 100
  hyperparameter_file: hyperparameters.yaml
  hyperparameter_set: default
  model: latest
  model_dir: out/current/checkpoints
  best_model_dir: out/current/best

gtx1080ti:
  <<: *DEFAULT
  train_batch_size: 512

thinkpadx201t:
  <<: *DEFAULT
  train_batch_size: 10
  train_epochs: 10
  epochs_between_evals: 1
  random_seed: 0

By selecting a configuration set at startup using the --config_set command-line option, known-good configurations can be stored and versioned easily. Options provided on the command line override the values defined in project.yaml, allowing for quick iteration.
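
A minimal sketch of this override order, assuming PyYAML and argparse; apart from --config_set, the flag names below are illustrative rather than the exact options of run.py:

import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument('--config_set', default='default')
parser.add_argument('--train_batch_size', type=int)
parser.add_argument('--train_epochs', type=int)
args = parser.parse_args()

# Values from the selected set in project.yaml act as defaults ...
with open('project.yaml', 'r') as f:
    config = yaml.safe_load(f)[args.config_set]

# ... and explicitly provided command-line values (i.e. not None) win.
overrides = {key: value for key, value in vars(args).items()
             if key != 'config_set' and value is not None}
config.update(overrides)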

Run training

To run a training session (manually overriding some of the configuration from project.yaml), try

python run.py \
    --xla \
    --epochs_between_evals 1000 \
    --train_epochs 10000 \
    --learning_rate 0.0001 

Prepare the dataset

To speed up processing later on, the image files are first converted to TFRecord format. For this, run

python convert_dataset.py \
    --dataset_dir dataset/train \
    --tfrecord_filename train \
    --tfrecord_dir dataset/train \
    --max_edge 384
python convert_dataset.py \
    --dataset_dir dataset/test \
    --tfrecord_filename test \
    --tfrecord_dir dataset/test \
    --max_edge 384
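
Conceptually, the conversion serializes each image into a tf.train.Example that stores the JPEG bytes and a label. The sketch below illustrates this with made-up feature keys and file paths; convert_dataset.py may use different ones:

import tensorflow as tf

def _bytes_feature(value: bytes) -> tf.train.Feature:
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value: int) -> tf.train.Feature:
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Placeholder paths; the real converter walks dataset_dir itself.
with tf.io.TFRecordWriter('dataset/train/train.tfrecord') as writer:
    with open('dataset/train/class_a/0001.jpg', 'rb') as f:
        encoded_jpeg = f.read()
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_feature(encoded_jpeg),
        'image/class/label': _int64_feature(0),
    }))
    writer.write(example.SerializeToString())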

This example stores image data as JPEG-encoded bytes and decodes them on the fly in the input pipeline. While this yields much smaller TFRecord files than storing raw pixel values, it also adds noticeable decoding latency per example; the trade-off is between file size on disk and CPU cost in the input pipeline.
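
A minimal sketch of such an on-the-fly decoding pipeline, assuming a recent TensorFlow 2.x and the feature keys from the sketch above; the target image size is an arbitrary placeholder:

import tensorflow as tf

def _parse(serialized):
    features = tf.io.parse_single_example(serialized, {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/class/label': tf.io.FixedLenFeature([], tf.int64),
    })
    # Decoding happens here, on the fly, for every example that is read.
    image = tf.io.decode_jpeg(features['image/encoded'], channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, (224, 224))  # fixed size so batching works
    return image, features['image/class/label']

dataset = (tf.data.TFRecordDataset('dataset/train/train.tfrecord')
           .map(_parse, num_parallel_calls=tf.data.AUTOTUNE)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))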

TensorFlow Hub

To use TensorFlow Hub, install it using e.g.

pip install tensorflow-hub

When a Conda environment is created from environment.yaml, this is already taken care of.
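
A minimal sketch of using a TensorFlow Hub module as a frozen feature extractor, assuming TensorFlow 2.x with Keras; the module URL and the number of output classes are examples, not values prescribed by this project:

import tensorflow as tf
import tensorflow_hub as hub

feature_extractor = hub.KerasLayer(
    'https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4',
    input_shape=(224, 224, 3),
    trainable=False)  # True enables fine-tuning, cf. the mobilenet set above

model = tf.keras.Sequential([
    feature_extractor,
    tf.keras.layers.Dropout(0.5),      # matches dropout_rate in the defaults
    tf.keras.layers.Dense(10, activation='softmax'),  # 10 classes as a placeholder
])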
