Skip to content

Commit

Permalink
Moved test out of run_all_in_graph_and_eager_mode sparse_image_wrap. (t…
Browse files Browse the repository at this point in the history
…ensorflow#1430)

* Moved test out of run_all_in_graph_and_eager_mode sparse_image_wrap.

See tensorflow#1328

* Update tensorflow_addons/image/sparse_image_warp_test.py
  • Loading branch information
gabrieldemarmiesse authored Mar 27, 2020
1 parent 665b750 commit 249c118
Showing 1 changed file with 26 additions and 33 deletions.
59 changes: 26 additions & 33 deletions tensorflow_addons/image/sparse_image_warp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,42 +231,35 @@ def testSmileyFace(self):
# than that in the saved png file loaded into target_image.
self.assertAllClose(target_image, out_image, atol=2, rtol=1e-3)

def testThatBackpropRuns(self):
"""Run optimization to ensure that gradients can be computed."""
batch_size = 1
image_height = 9
image_width = 12
image = tf.Variable(
np.random.uniform(size=[batch_size, image_height, image_width, 3]),
dtype=tf.float32,
)
control_point_locations = [[3.0, 3.0]]
control_point_locations = tf.constant(
np.float32(np.expand_dims(control_point_locations, 0))
)
control_point_displacements = [[0.25, -0.5]]
control_point_displacements = tf.constant(
np.float32(np.expand_dims(control_point_displacements, 0))
)

def loss_fn():
warped_image, _ = sparse_image_warp(
image,
control_point_locations,
control_point_locations + control_point_displacements,
num_boundary_points=3,
)
loss = tf.reduce_mean(tf.abs(warped_image - image))
return loss

optimizer = tf.keras.optimizers.SGD(
learning_rate=0.001, momentum=0.9, clipnorm=1.0
def test_that_backprop_runs():
"""Making sure the gradients can be computed."""
batch_size = 1
image_height = 9
image_width = 12
image = tf.Variable(
np.random.uniform(size=[batch_size, image_height, image_width, 3]),
dtype=tf.float32,
)
control_point_locations = [[3.0, 3.0]]
control_point_locations = tf.constant(
np.float32(np.expand_dims(control_point_locations, 0))
)
control_point_displacements = [[0.25, -0.5]]
control_point_displacements = tf.constant(
np.float32(np.expand_dims(control_point_displacements, 0))
)

with tf.GradientTape() as t:
warped_image, _ = sparse_image_warp(
image,
control_point_locations,
control_point_locations + control_point_displacements,
num_boundary_points=3,
)
opt_op = optimizer.minimize(loss_fn, [image])

self.evaluate(tf.compat.v1.global_variables_initializer())
for _ in range(5):
self.evaluate(opt_op)
gradients = t.gradient(warped_image, image).numpy()
assert np.sum(np.abs(gradients)) != 0


if __name__ == "__main__":
Expand Down

0 comments on commit 249c118

Please sign in to comment.