
A TensorFlow implementation of the recurrent models of visual attention

gklezd/recurrent-attention-model

 
 


Recurrent Visual Attention Model (RAM)

  • A TensorFlow implementation of Recurrent Models of Visual Attention (NIPS 2014).
  • Model architecture from the paper: (figure: ram)
  • For image classification/recognition tasks, this model looks at different locations of the original image over several steps before making the final prediction. Because each step processes only a small glimpse rather than the whole image, the amount of computation can be controlled independently of the image size, and the sequence of glimpse locations makes the results more interpretable.
  • The core network is an RNN. It takes the representation from the glimpse network and the previous hidden state as input and produces the new hidden state.
  • The location network and the action network take the new hidden state of the RNN as input and produce, respectively, the next location (where the glimpse network extracts a new representation) and the new action.
  • This repository includes experiments on the original MNIST and the translated MNIST datasets. The classification accuracy on the translated MNIST dataset is 97.82%.
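The recurrent loop described above can be sketched in plain Python. This is a toy illustration, not the code in lib/model/ram.py: the four sub-networks are hypothetical stand-ins with fixed arithmetic instead of trained layers, locations are assumed to lie in [-1, 1] with (0, 0) at the image centre, the first glimpse is assumed to be at the centre, and each new location is sampled from a Gaussian whose standard deviation plays the role of the "location std" hyperparameter.

```python
import math
import random

# Hypothetical stand-ins for the sub-networks; in the repository these are
# trained TensorFlow layers defined in lib/model/ram.py.

def glimpse_network(image, loc):
    """Toy glimpse: the pixel nearest to loc plus the location itself ("what" and "where")."""
    n = len(image)
    # Map loc from [-1, 1] to pixel indices (assumption: (0, 0) is the image centre).
    r = min(n - 1, max(0, round((loc[0] + 1) / 2 * (n - 1))))
    c = min(n - 1, max(0, round((loc[1] + 1) / 2 * (n - 1))))
    return [image[r][c], loc[0], loc[1]]

def core_network(g, h):
    """Simple RNN update with identity weights: h_t = tanh(g_t + h_{t-1})."""
    return [math.tanh(gi + hi) for gi, hi in zip(g, h)]

def location_network(h, std, rng):
    """Mean location from the hidden state, then a Gaussian sample clipped to [-1, 1]."""
    mean = [math.tanh(h[1]), math.tanh(h[2])]
    return [max(-1.0, min(1.0, rng.gauss(m, std))) for m in mean]

def action_network(h, n_classes=10):
    """Softmax over toy class logits derived from the hidden state."""
    logits = [h[0] * k for k in range(n_classes)]
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ram_forward(image, n_steps=6, std=0.03, seed=0):
    """Glimpse -> core RNN -> next location, repeated; classify from the final hidden state."""
    rng = random.Random(seed)
    h = [0.0, 0.0, 0.0]
    loc = [0.0, 0.0]  # assumption: the first glimpse is taken at the centre
    locations = []
    for t in range(n_steps):
        locations.append(loc)
        g = glimpse_network(image, loc)
        h = core_network(g, h)
        if t < n_steps - 1:
            loc = location_network(h, std, rng)
    return action_network(h), locations
```

The point of the sketch is the data flow: a single hidden state carries information across glimpses, the location network decides where to look next, and only the final hidden state is used for the prediction.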

Requirements

Implementation Details

  • The RAM model is defined in lib/model/ram.py.
  • An example of MNIST dataset classification is in example/mnist.py.
  • The action network, core network and glimpse network are trained with the classification cross-entropy loss.
  • The location network is trained with the REINFORCE algorithm.
  • The reward baseline network is trained with the MSE loss between the baseline and the reward.
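The three training signals above can be written out for a single example. This is a plain-Python sketch rather than the repository's TensorFlow code, and the function names are hypothetical. It assumes a reward of 1 for a correct final classification and 0 otherwise (the scheme used in the paper), a Gaussian location policy with a fixed standard deviation (the "location std" hyperparameter), and one baseline value per glimpse step.

```python
import math

def gaussian_log_prob(x, mean, std):
    """log N(x | mean, std^2 I), summed over the location coordinates."""
    return sum(
        -0.5 * ((xi - mi) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)
        for xi, mi in zip(x, mean)
    )

def ram_losses(class_probs, label, loc_samples, loc_means, baselines, std=0.03):
    """Per-example losses for one glimpse sequence.

    class_probs: softmax output of the action network after the last glimpse.
    loc_samples / loc_means: sampled locations and their Gaussian means, one per step.
    baselines: baseline predictions b_t, one per step.
    """
    # 1) Classification cross-entropy (action, core and glimpse networks).
    cross_entropy = -math.log(class_probs[label])

    # 2) Reward: 1 if the final prediction is correct, else 0 (assumption, as in the paper).
    prediction = max(range(len(class_probs)), key=lambda k: class_probs[k])
    reward = 1.0 if prediction == label else 0.0

    # 3) REINFORCE surrogate for the location network. In an autodiff framework the
    #    advantage (reward - b_t) and the sampled locations are treated as constants
    #    (e.g. with tf.stop_gradient), so gradients flow only through loc_means.
    reinforce = -sum(
        gaussian_log_prob(l, m, std) * (reward - b)
        for l, m, b in zip(loc_samples, loc_means, baselines)
    )

    # 4) Baseline regression: mean squared error between b_t and the reward.
    baseline_mse = sum((reward - b) ** 2 for b in baselines) / len(baselines)
    return cross_entropy, reinforce, baseline_mse
```

Subtracting the baseline leaves the expected policy gradient unchanged but reduces its variance, which is why the baseline is trained separately to predict the reward.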

Result

Centered MNIST

Hyperparameters

name                 value
step                 6
smallest glimpse     8
glimpse scales       1
location std         0.03
pixel to unit width  12
batch size           128
max epoch            1000
learning rate        1e-3 (initial), decayed by a factor of 0.97 every 500 steps
  • Locations of glimpses at different steps: (figure: center)
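The learning-rate row in the hyperparameter table can be read as a staircase exponential decay. The sketch below assumes the rate drops in discrete jumps every 500 training steps rather than decaying smoothly; under that assumption it matches what TensorFlow's `tf.keras.optimizers.schedules.ExponentialDecay` computes with `staircase=True` (or `tf.compat.v1.train.exponential_decay` in TF1-style code). The function name is hypothetical.

```python
def learning_rate(global_step, initial_lr=1e-3, decay_rate=0.97, decay_steps=500):
    """Staircase exponential decay: lr = initial_lr * decay_rate ** floor(step / decay_steps)."""
    # Assumption: the rate is multiplied by decay_rate once every decay_steps
    # training steps (integer division), not decayed continuously.
    return initial_lr * decay_rate ** (global_step // decay_steps)
```

For example, after 10,000 steps the rate would be 1e-3 × 0.97^20 ≈ 5.4e-4.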

Translated MNIST (accuracy 97.82%)

Hyperparameters

name                 value
image size           60 x 60
step                 6
smallest glimpse     12
glimpse scales       3
location std         0.03
pixel to unit width  26
batch size           128
max epoch            2000
learning rate        1e-3 (initial), decayed by a factor of 0.97 every 500 steps
  • Locations of glimpses at different steps: (figure: trans)
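The glimpse parameters in the tables (smallest glimpse, glimpse scales, pixel to unit width) can be illustrated with a multi-resolution glimpse sensor in plain Python. The sketch follows the retina described in the paper, where each successive patch is twice as wide as the previous one and all patches are reduced to the smallest size. Several details are assumptions rather than what lib/model/ram.py necessarily does: resizing by average pooling, zero padding outside the image, (row, column) coordinate order, and a location mapping in which (0, 0) is the image centre and one model unit equals "pixel to unit width" pixels. All function names are hypothetical.

```python
def loc_to_pixel(loc, image_size, unit_pixel):
    """Map a location in model units to pixel coordinates (assumption: (0, 0) is the centre)."""
    return [int(round(image_size / 2 + l * unit_pixel)) for l in loc]

def extract_patch(image, center, size):
    """Square size x size crop around center (row, col), zero-padded outside the image."""
    top, left = center[0] - size // 2, center[1] - size // 2
    rows, cols = len(image), len(image[0])
    return [
        [image[r][c] if 0 <= r < rows and 0 <= c < cols else 0.0
         for c in range(left, left + size)]
        for r in range(top, top + size)
    ]

def avg_pool(patch, out_size):
    """Downsample a square patch to out_size x out_size by averaging non-overlapping blocks."""
    f = len(patch) // out_size
    return [
        [sum(patch[i * f + a][j * f + b] for a in range(f) for b in range(f)) / (f * f)
         for j in range(out_size)]
        for i in range(out_size)
    ]

def glimpse_sensor(image, loc, smallest=12, scales=3, unit_pixel=26):
    """Return `scales` patches of width smallest * 2**k, each resized to smallest x smallest."""
    center = loc_to_pixel(loc, len(image), unit_pixel)
    return [
        avg_pool(extract_patch(image, center, smallest * 2 ** k), smallest)
        for k in range(scales)
    ]
```

With the translated-MNIST settings this produces 12 x 12, 24 x 24 and 48 x 48 crops, all reduced to 12 x 12. The centered-MNIST settings would correspond to `glimpse_sensor(image, loc, smallest=8, scales=1, unit_pixel=12)` on a 28 x 28 image.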

Usage

  • Download the MNIST dataset from here.
  • Set up the path in example/read_mnist.py: DATA_PATH is the directory containing the MNIST dataset.
  • Set up the paths in example/mnist.py: SAVE_PATH is the directory where trained models are saved during training, and RESULT_PATH is the directory where prediction results are saved.
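A minimal illustration of these settings, assuming they are plain module-level string constants in the two scripts. The directory values below are placeholders to replace with your own locations.

```python
import os

# In example/read_mnist.py: directory containing the downloaded MNIST files (placeholder).
DATA_PATH = os.path.expanduser('~/data/MNIST_data/')

# In example/mnist.py: where trained models and prediction results are written (placeholders).
SAVE_PATH = os.path.expanduser('~/ram/checkpoints/')
RESULT_PATH = os.path.expanduser('~/ram/results/')
```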

Argument

Run the script examples/read_mnist.py to train and test the model. Here are all the arguments:

  • --train: Train the model.
  • --eval: Evaluate the model.
  • --predict: Use the trained model to save output images showing the glimpse positions at each step for samples from the test set.
  • --batch: Batch size. Default: 128.
  • --epoch: Max number of epochs. Default: 1000.
  • --lr: Initial learning rate. Default: 1e-3.
  • --load: The epoch ID of trained model to be restored for evaluation or prediction.
  • --dataset: Dataset used for the experiment. Default: center. Use center for the original MNIST and translate for the translated MNIST (60 x 60).
  • --step: Number of glimpses the model takes before making the final prediction. Default: 1.
  • --glimpse: Size of the smallest glimpse. Default: 12.
  • --scale: Number of glimpse scales. Default: 3.
  • --std: Standard deviation for sampling glimpse locations. Default: 0.11.
  • --pixel: Number of pixels corresponding to one unit width in the model's location coordinate system. Default: 26.
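The arguments listed above map naturally onto Python's argparse. The following is a hypothetical reconstruction for illustration only: the option names and defaults come from the list, but the types, the choices for --dataset, and the None default for --load are assumptions, and the actual script may parse its arguments differently.

```python
import argparse

def get_args(argv=None):
    """Hypothetical parser for the documented command-line options."""
    parser = argparse.ArgumentParser(description='Recurrent Visual Attention Model on MNIST')
    # Mode flags.
    parser.add_argument('--train', action='store_true', help='Train the model')
    parser.add_argument('--eval', action='store_true', help='Evaluate the model')
    parser.add_argument('--predict', action='store_true', help='Save glimpse visualisations')
    # Training settings.
    parser.add_argument('--batch', type=int, default=128, help='Batch size')
    parser.add_argument('--epoch', type=int, default=1000, help='Max number of epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='Initial learning rate')
    parser.add_argument('--load', type=int, default=None, help='Epoch ID of the model to restore')
    parser.add_argument('--dataset', type=str, default='center', choices=['center', 'translate'])
    # Model settings.
    parser.add_argument('--step', type=int, default=1, help='Number of glimpses')
    parser.add_argument('--glimpse', type=int, default=12, help='Size of the smallest glimpse')
    parser.add_argument('--scale', type=int, default=3, help='Number of glimpse scales')
    parser.add_argument('--std', type=float, default=0.11, help='Std of location sampling')
    parser.add_argument('--pixel', type=int, default=26, help='Pixels per unit width')
    return parser.parse_args(argv)
```

Several defaults differ from the hyperparameter tables (for example --step 1 and --std 0.11 versus 6 and 0.03), so reproducing the translated-MNIST setting would presumably mean passing flags such as --dataset translate --step 6 --std 0.03 --epoch 2000 explicitly.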

Train the model

  • Go to examples/, then run
python --train --dataset DATASET
  • Use --dataset center for the original MNIST or --dataset translate for the translated MNIST (60 x 60).

Evaluate the model

  • Go to examples/, then run
python --eval --dataset DATASET --load RESTORE_ID
  • The accuracy on the test set will be printed.

Reference implementation:

Languages

  • Python 97.8%
  • MATLAB 2.2%