This directory contains the code required to retarget a TensorFlow model so that it runs on TPUs.
The example below shows how to use the TPU Estimator with a simple convolutional network (the model definition itself is elided).
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
def model_fn(features, labels, mode, params):
    """Build the loss and training op for the TPU Estimator.

    Args:
      features: Input features produced by ``input_fn``; consumed by the
        (elided) network that computes the logits.
      labels: One-hot labels, passed as ``onehot_labels`` to
        ``tf.losses.softmax_cross_entropy``.
      mode: Estimator mode key. It is not inspected here: a loss and a train
        op are always built, so this sketch is only suited to training.
      params: Parameter dict supplied by the estimator (unused here).

    Returns:
      A ``tf.estimator.EstimatorSpec`` carrying ``loss`` and ``train_op``.
    """
    # Define the model to construct the logits from `features`.
    # `...` is a syntactically valid placeholder; replace it with the network.
    logits = ...
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    # Wrap the plain SGD optimizer for TPU use; CrossShardOptimizer presumably
    # combines gradients across TPU shards (see the tpu_optimizer docs).
    # NOTE: FLAGS is assumed to be a module-level flags object defining
    # `learning_rate`; it is not shown in this snippet.
    optimizer = tpu_optimizer.CrossShardOptimizer(
        tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate))
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
def input_fn(params):
    """Placeholder input pipeline for the estimator.

    Replace the body with code that builds the training input. As written it
    ignores ``params`` and returns ``None``.

    Args:
      params: Dict supplied by the estimator; presumably carries settings such
        as the batch size (confirm against the TpuEstimator documentation).
    """
    return None
def main(unused_argv=None):
    """Configure a TPU estimator and train it for ``FLAGS.train_steps`` steps.

    Args:
      unused_argv: Ignored. Accepted so that ``main`` can be handed to an
        entry point such as ``tf.app.run``, which calls ``main(argv)``; the
        default keeps a plain ``main()`` call working.
    """
    # NOTE: FLAGS is assumed to be a module-level flags object defining
    # `master`, `use_tpu`, `batch_size` and `train_steps`; it is not shown here.
    run_config = tpu_config.RunConfig(
        master=FLAGS.master,
        # ... additional RunConfig options ...
    )
    estimator = tpu_estimator.TpuEstimator(
        model_fn=model_fn,
        use_tpu=FLAGS.use_tpu,
        config=run_config,
        batch_size=FLAGS.batch_size,
    )
    estimator.train(input_fn=input_fn, max_steps=FLAGS.train_steps)
For a complete, executable example, see our open-source TPU models.