
Switch DistStrat revised API examples to TensorFlow 2 style. #63

Merged (1 commit, Feb 7, 2019)

174 changes: 92 additions & 82 deletions rfcs/20181016-replicator.md

Below is a simple usage example for an image classification use case.

```python
with strategy.scope():
  model = tf.keras.applications.ResNet50(weights=None)
  optimizer = tf.keras.optimizers.SGD(learning_rate, momentum=0.9)

def input_fn(ctx):
  return imagenet.ImageNet(ctx.get_per_replica_batch_size(effective_batch_size))

input_iterator = strategy.make_input_iterator(input_fn)

@tf.function
def train_step():
  def step_fn(inputs):
    image, label = inputs

    with tf.GradientTape() as tape:
      logits = model(image)
      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=label)
      loss = tf.reduce_mean(cross_entropy)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    return loss

  per_replica_losses = strategy.run(step_fn, input_iterator)
  mean_loss = strategy.reduce(AggregationType.MEAN, per_replica_losses)
  return mean_loss

strategy.initialize()
input_iterator.initialize()
for _ in range(num_train_steps):
  loss = train_step()
strategy.finalize()
```
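
For reference, `imagenet.ImageNet` and `effective_batch_size` are helpers assumed by the example above; a minimal stand-in `input_fn` built on `tf.data` might look like the following sketch (the file pattern and preprocessing function are hypothetical):

```python
import tensorflow as tf

def input_fn(ctx):
  # The context tells each replica its share of the effective (global) batch.
  per_replica_batch_size = ctx.get_per_replica_batch_size(effective_batch_size)
  files = tf.data.Dataset.list_files(train_file_pattern)  # hypothetical pattern
  dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=4)
  dataset = dataset.map(parse_and_preprocess)  # hypothetical parser
  dataset = dataset.shuffle(10000).repeat()
  return dataset.batch(per_replica_batch_size, drop_remainder=True)
```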

#### Evaluation

```python
with strategy.scope():
  model = tf.keras.applications.ResNet50(weights=None)

def eval_input_fn(ctx):
  del ctx  # Unused.
  return imagenet.ImageNet(
      eval_batch_size, subset="valid", shuffle=False, num_epochs=1)

eval_input_iterator = strategy.make_input_iterator(
    eval_input_fn, input_replication_mode=InputReplicationMode.SINGLE)

@tf.function
def eval():
  def eval_top1_accuracy(inputs):
    image, label = inputs
    logits = model(image)
    predicted_label = tf.argmax(logits, axis=1)
    top1_acc = tf.reduce_mean(
        tf.cast(tf.equal(predicted_label, label), tf.float32))
    return top1_acc

  per_replica_top1_accs = strategy.run(eval_top1_accuracy, eval_input_iterator)
  mean_top1_acc = strategy.reduce(AggregationType.MEAN, per_replica_top1_accs)
  return mean_top1_acc

strategy.initialize()
while True:
  while not has_new_checkpoint():
    sleep(60)

  load_checkpoint()

  # Do a sweep over the entire validation set.
  eval_input_iterator.initialize()
  while True:
    try:
      top1_acc = eval()
      ...
    except tf.errors.OutOfRangeError:
      break
strategy.finalize()
```
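
`has_new_checkpoint`, `load_checkpoint`, and `sleep` are left undefined above; a minimal sketch using `tf.train.Checkpoint` (with an assumed `checkpoint_dir`, and `sleep` standing in for `time.sleep`) could be:

```python
import tensorflow as tf

checkpoint = tf.train.Checkpoint(model=model)
last_evaluated = None

def has_new_checkpoint():
  # True when a checkpoint newer than the last evaluated one has appeared.
  latest = tf.train.latest_checkpoint(checkpoint_dir)
  return latest is not None and latest != last_evaluated

def load_checkpoint():
  global last_evaluated
  last_evaluated = tf.train.latest_checkpoint(checkpoint_dir)
  checkpoint.restore(last_evaluated).expect_partial()
```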

#### Sharded Input Pipeline
…

### GANs

```python
with strategy.scope():
  discriminator = GoodfellowDiscriminator(DefaultDiscriminator2D())
  generator = DefaultGenerator2D()
  gan = GAN(discriminator, generator)
  disc_optimizer = tf.keras.optimizers.Adam(disc_learning_rate, beta_1=0.5, beta_2=0.9)
  gen_optimizer = tf.keras.optimizers.Adam(gen_learning_rate, beta_1=0.5, beta_2=0.9)

def discriminator_step(inputs):
  image, noise = inputs

  with tf.GradientTape() as tape:
    gan_output = gan.connect(image, noise)
    disc_loss, disc_vars = gan_output.discriminator_loss_and_vars()

  grads = tape.gradient(disc_loss, disc_vars)
  disc_optimizer.apply_gradients(list(zip(grads, disc_vars)))
  return disc_loss

def generator_step(inputs):
  image, noise = inputs

  with tf.GradientTape() as tape:
    gan_output = gan.connect(image, noise)
    gen_loss, gen_vars = gan_output.generator_loss_and_vars()

  grads = tape.gradient(gen_loss, gen_vars)
  gen_optimizer.apply_gradients(list(zip(grads, gen_vars)))
  return gen_loss

input_iterator = strategy.make_input_iterator(input_fn)

strategy.initialize()
input_iterator.initialize()
for _ in range(num_train_steps):
  for _ in range(num_disc_steps):
    per_replica_disc_losses = strategy.run(discriminator_step, input_iterator)
    mean_disc_loss = strategy.reduce(AggregationType.MEAN, per_replica_disc_losses)
  for _ in range(num_gen_steps):
    per_replica_gen_losses = strategy.run(generator_step, input_iterator)
    mean_gen_loss = strategy.reduce(AggregationType.MEAN, per_replica_gen_losses)
strategy.finalize()
```
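
Unlike the classification example, the GAN steps above run eagerly on every iteration; assuming `strategy.run` and `strategy.reduce` trace cleanly, one might wrap each phase in a `tf.function` the same way, as in this sketch:

```python
@tf.function
def disc_train_step():
  per_replica_losses = strategy.run(discriminator_step, input_iterator)
  return strategy.reduce(AggregationType.MEAN, per_replica_losses)

@tf.function
def gen_train_step():
  per_replica_losses = strategy.run(generator_step, input_iterator)
  return strategy.reduce(AggregationType.MEAN, per_replica_losses)

for _ in range(num_train_steps):
  for _ in range(num_disc_steps):
    mean_disc_loss = disc_train_step()
  for _ in range(num_gen_steps):
    mean_gen_loss = gen_train_step()
```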

### Reinforcement Learning
This is an example of a Reinforcement Learning system, converted to eager style.

```python
with strategy.scope():
  agent = Agent(num_actions, hidden_size, entropy_cost, baseline_cost)
  optimizer = tf.keras.optimizers.RMSprop(learning_rate)

# Queues of trajectories from actors.
queues = []
...

def learner_input(ctx):
  ...
  return dequeue_batch

def learner_step(trajectories):
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(agent.compute_loss(trajectories))

  agent_vars = agent.get_all_variables()
  grads = tape.gradient(loss, agent_vars)
  optimizer.apply_gradients(list(zip(grads, agent_vars)))
  return loss, agent_vars

# Create learner inputs.
...

strategy.initialize()
for _ in range(num_train_steps):
  per_replica_outputs = strategy.run(learner_step, learner_inputs)
  per_replica_losses, updated_agent_var_copies = zip(*per_replica_outputs)
  mean_loss = strategy.reduce(AggregationType.MEAN, per_replica_losses)

strategy.finalize()
```
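
The queue construction and the body of `learner_input` are collapsed in this diff; purely as an illustration, they might look like the following sketch (the queue shapes, `num_learner_replicas`, `learner_batch_size`, and the `ctx.replica_id` field are all assumptions):

```python
import tensorflow as tf

# One FIFO queue of actor trajectories per learner replica (shapes assumed).
queues = [
    tf.queue.FIFOQueue(
        capacity=queue_capacity,
        dtypes=[tf.float32],
        shapes=[trajectory_shape])
    for _ in range(num_learner_replicas)
]

def learner_input(ctx):
  # Each learner replica repeatedly dequeues a batch from its own queue.
  queue = queues[ctx.replica_id]  # assumed context field
  def dequeue_batch():
    return queue.dequeue_many(learner_batch_size)
  return dequeue_batch
```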