Skip to content

Latest commit

 

History

History

tpu

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 
 
 
 
 

TPU support for TensorFlow

This directory contains code required to re-target a TensorFlow model to run on TPUs.

Example usage - TPU Estimator

The example below shows how to use the TPU Estimator with a simple convolutional network.

import tensorflow as tf

from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer

def model_fn(features, labels, mode, params):
  """Build the training graph for the TPU Estimator.

  Args:
    features: Model inputs (presumably produced by `input_fn`); consumed by
      the network definition that replaces the placeholder below.
    labels: One-hot encoded labels, passed as `onehot_labels` to the loss.
    mode: Estimator mode. It is forwarded to the returned spec; this example
      always builds a training op and does not branch on it.
    params: Extra parameters supplied by the estimator (unused here).

  Returns:
    A `tf.estimator.EstimatorSpec` carrying the loss and the training op.
  """
  # Define the model to construct the logits. `...` is a placeholder that
  # keeps the snippet syntactically valid; replace it with your network's
  # output logits tensor.
  logits = ...
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  # Wrap the base optimizer in CrossShardOptimizer for TPU training.
  # FLAGS is assumed to be defined by the surrounding script (not shown).
  optimizer = tpu_optimizer.CrossShardOptimizer(
    tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate))
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

def input_fn(params):
  """Placeholder for the input pipeline passed to `estimator.train`.

  Args:
    params: Parameters passed in by the estimator.
  """
  # Fill in: build and return the input data for training here.
  return None

def main():
  """Configure a TPU-targeted estimator and run training.

  Reads `master`, `use_tpu`, `batch_size` and `train_steps` from FLAGS,
  which is assumed to be defined by the surrounding script (not shown).
  """
  config = tpu_config.RunConfig(
    master=FLAGS.master,
    # ... additional RunConfig options ...
  )
  tpu_est = tpu_estimator.TpuEstimator(
    model_fn=model_fn,
    use_tpu=FLAGS.use_tpu,
    config=config,
    batch_size=FLAGS.batch_size,
  )
  tpu_est.train(input_fn=input_fn, max_steps=FLAGS.train_steps)

For a complete, executable example, see our open source TPU models.