
Commit 68e2db3

Automated rollback of commit acab6a2
PiperOrigin-RevId: 258246018
allenlavoie authored and tensorflower-gardener committed Jul 15, 2019
1 parent 7a2ce46 commit 68e2db3
Showing 6 changed files with 288 additions and 21 deletions.
48 changes: 42 additions & 6 deletions tensorflow/c/eager/tape.h
@@ -173,6 +173,21 @@ class GradientTape {
bool persistent_;
};

// Describes a callback for special-cased and more efficient jvp computation.
//
// Could just be a simple typedef in ForwardAccumulator, but MSVC chokes on
// that.
template <typename Gradient>
class ForwardFunction
: public std::function<Status(const std::vector<Gradient*>&,
std::vector<Gradient*>*)> {
public:
template <typename lambda_type>
explicit ForwardFunction(lambda_type lambda)
: std::function<Status(const std::vector<Gradient*>&,
std::vector<Gradient*>*)>(lambda) {}
};

// Computes Jacobian-vector products using forward-mode automatic
// differentiation.
//
@@ -222,6 +237,12 @@ class ForwardAccumulator {
// between calls to ShouldRecord and Accumulator), and its outputs
// (`output_tensors`).
//
// If provided, a non-null `forward_function` will be used instead of the
// backward function (`backward_function_getter` /
// `backward_function_deleter`) to compute jvps for this operation. If
// `forward_function` is null, a GradientTape is used on the backward function
// to compute the jvp, which will waste computation when executing eagerly.
//
// Unlike GradientTape::RecordOperation, Accumulate runs gradient computation
// immediately. It stores the results, which feed into Accumulate for future
// operations and may be fetched by calling FetchJVP. ForwardAccumulator
@@ -237,6 +258,7 @@ class ForwardAccumulator {
const std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
const ForwardFunction<Gradient>* forward_function,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter);

@@ -930,6 +952,7 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
const std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
const ForwardFunction<Gradient>* forward_function,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (backward_tape_ != nullptr) {
@@ -981,23 +1004,36 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
}
}

// Avoid infinite recursion. Whichever forward function we run, it'll end up
// executing ops, and we don't want to watch those with this accumulator.
accumulating_ = true;
auto reset_accumulating =
gtl::MakeCleanup([this] { this->accumulating_ = false; });

std::vector<Gradient*> forward_grads;
TF_RETURN_IF_ERROR(
ForwardpropFromTape(output_tensors, backward_function_getter,
backward_function_deleter, in_grads, &forward_grads));

if (forward_function == nullptr) {
// We have no special-cased forward gradient. Fall back to running the
// backward function under a gradient tape.
TF_RETURN_IF_ERROR(ForwardpropFromTape(
output_tensors, backward_function_getter, backward_function_deleter,
in_grads, &forward_grads));
} else {
TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads));
}
for (int i = 0; i < forward_grads.size(); ++i) {
if (forward_grads[i] != nullptr) {
int64 tensor_id = output_tensors[i].GetID();
auto existing = accumulated_gradients_.find(tensor_id);
if (existing != accumulated_gradients_.end()) {
vspace_.DeleteGradient(existing->second);
// This is a somewhat odd case to be in, since it means we have two
// operations which supposedly both created the same Tensor. It comes up
// in recompute_grad, where the gradients have the same value. However,
// only the original gradient is connected to everything else, so we
// should still use that.
vspace_.DeleteGradient(forward_grads[i]);
} else {
accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
}
accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i];
}
}
return Status::OK();
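The ForwardFunction added above is just a callback from an op's input tangents to its output tangents (the C++ version receives only tangent vectors and returns a Status, so it presumably captures the op's inputs itself). A minimal, hypothetical Python illustration of that contract for Mul, using the same values as the Mul case in forwardprop_test.py below:

import tensorflow as tf

def mul_forward_function(inputs, input_tangents):
  # For z = a * b, the product rule gives the JVP dz = da * b + a * db.
  a, b = inputs
  ta, tb = input_tangents
  return [ta * b + a * tb]

# With a=[4.], b=[5.] and tangents [2.], [3.], the output tangent is
# [2. * 5. + 3. * 4.] = [22.], matching testForwardGradientFunction below.
print(mul_forward_function([tf.constant([4.]), tf.constant([5.])],
                           [tf.constant([2.]), tf.constant([3.])]))

A special-cased forward function like this is what lets Accumulate skip the backward-function replay described in the comment above.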
74 changes: 72 additions & 2 deletions tensorflow/python/eager/forwardprop.py
@@ -21,13 +21,83 @@
import functools

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute

from tensorflow.python.framework import ops

from tensorflow.python.ops import array_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


# TODO(allenl): Special-case op gradients and tf.functions to avoid unnecessary
# evaluation of gradient functions.
# TODO(allenl): experimental_relax_shapes for gradients which rely on static
# shape information may be underspecialized. We may want hand-written forward
# implementations.
@def_function.function(experimental_relax_shapes=True)
def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
"""Computes a Jacobian-vector product for an op.
Note that this function would be wasteful if executed eagerly. It runs the
backward gradient function and throws away the result just to record its
operations on a GradientTape. These unused ops are pruned away when this
function is traced.
Args:
op_name: A string, the type of operation being executed.
attr_tuple: Attributes of the operation.
inputs: A flat list of input Tensors to the operation.
outputs: A flat list of output Tensors from the operation.
tangents: A flat list of Tensors, same shape as `inputs`.
Returns:
A flat list of tangents corresponding to `outputs`.
"""
float_inputs = []
float_indices = []
nontrivial_tangents = []
for input_index, tensor in enumerate(inputs):
if tensor.dtype.is_floating:
float_inputs.append(tensor)
float_indices.append(input_index)
nontrivial_tangents.append(tangents[input_index])

with backprop.GradientTape() as transpose_tape:
with backprop.GradientTape() as backfunc_tape:
backfunc_tape.watch(float_inputs)
execute.record_gradient(op_name, inputs, attr_tuple, outputs,
"forward_op_replay")

forwardprop_aids = []
float_outputs = []
nontrivial_output_indices = []
for output_index, output in enumerate(outputs):
if output.dtype.is_floating:
forwardprop_aids.append(
array_ops.ones_like(output, name="unused_forwardprop_aid"))
float_outputs.append(output)
nontrivial_output_indices.append(output_index)

transpose_tape.watch(forwardprop_aids)
grads = backfunc_tape.gradient(
float_outputs,
float_inputs,
forwardprop_aids,
unconnected_gradients=UnconnectedGradients.ZERO)
nontrivial_output_tangents = transpose_tape.gradient(
grads, forwardprop_aids, output_gradients=nontrivial_tangents)
output_tangents = [None] * len(outputs)
for index, tangent in zip(nontrivial_output_indices,
nontrivial_output_tangents):
output_tangents[index] = tangent
return output_tangents


pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(_forward_gradient)


class ForwardGradientAccumulator(object):
"""Computes Jacobian-vector products using forward-mode autodiff.
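The fallback that _forward_gradient implements — replay the op's backward function under a tape, then differentiate that backward pass with respect to a dummy cotangent — can be reproduced at user level with two nested tapes. A minimal sketch, assuming ordinary tf.GradientTape and a unary elementwise f (not the traced, op-level machinery above):

import tensorflow as tf

def jvp_via_double_backprop(f, primal, tangent):
  """Computes J(f)(primal) applied to `tangent` using two backward passes."""
  with tf.GradientTape() as transpose_tape:
    with tf.GradientTape() as backfunc_tape:
      backfunc_tape.watch(primal)
      output = f(primal)
    # A dummy cotangent (the "forwardprop aid"). The backward function is
    # linear in it, so its actual value does not matter.
    aid = tf.ones_like(output)
    transpose_tape.watch(aid)
    grad = backfunc_tape.gradient(output, primal, output_gradients=aid)
  # Differentiating the backward pass with respect to the aid transposes it
  # again, which yields the forward-mode Jacobian-vector product.
  return transpose_tape.gradient(grad, aid, output_gradients=tangent)

# d/dx sin(x) at x = 0.1 with tangent 0.2 is cos(0.1) * 0.2, as in
# testJVPManual below.
print(jvp_via_double_backprop(tf.sin, tf.constant(0.1), tf.constant(0.2)))

Executed eagerly this wastes a full backward pass, which is why the commit both wraps _forward_gradient in def_function.function (so the unused backward ops are pruned when traced) and allows special-cased ForwardFunctions on the C++ side.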
78 changes: 69 additions & 9 deletions tensorflow/python/eager/forwardprop_test.py
@@ -22,6 +22,7 @@

import numpy as np

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import forwardprop
@@ -32,6 +33,7 @@
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import test
from tensorflow.python.util import nest

@@ -74,7 +76,10 @@ def _f(*params):
with backprop.GradientTape() as tape:
tape.watch(params)
primals_out = f(*params)
return tape.gradient(primals_out, params[argnums])
return tape.gradient(
primals_out,
params[argnums],
unconnected_gradients=UnconnectedGradients.ZERO)

return _f

@@ -93,8 +98,8 @@ def _test_gradients(testcase,
atol=1e-6):
"""Tests forward/backward jacobians of `f`'s [0, `order`)-order gradients."""
if order < 1:
raise ValueError("`order` should be a positive integer, got '{}'."
.format(order))
raise ValueError(
"`order` should be a positive integer, got '{}'.".format(order))
if order > 1:
_test_gradients(
testcase=testcase,
@@ -117,6 +122,60 @@ def _test_gradients(testcase,

class ForwardpropTest(test.TestCase):

def testForwardGradientFunction(self):
add_outputs = (constant_op.constant(4.),)
vp, = forwardprop._forward_gradient(
op_name="Add",
attr_tuple=(),
inputs=(constant_op.constant(1.), constant_op.constant(3.)),
outputs=add_outputs,
tangents=(
constant_op.constant(1.),
constant_op.constant(5.),
))
self.assertAllClose(1. + 5., self.evaluate(vp))

mul_outputs = (constant_op.constant([20.]),)
vp, = forwardprop._forward_gradient(
op_name="Mul",
attr_tuple=(),
inputs=(constant_op.constant([4.]), constant_op.constant([5.])),
outputs=mul_outputs,
tangents=(
constant_op.constant([2.]),
constant_op.constant([3.]),
))
self.assertAllClose([2. * 5. + 3. * 4.], self.evaluate(vp))

def testForwardGradientFunctionUsedByAccumulatorForOps(self):
previous_fn = forwardprop._forward_gradient
try:
with forwardprop.ForwardGradientAccumulator() as acc:
x = constant_op.constant(1.)
acc.watch(x, 2.)
y = x + x
pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(
lambda *args, **kwargs: [constant_op.constant(-15.)])
z = x + x
self.assertAllClose(4., acc.jvp(y))
self.assertAllClose(-15., acc.jvp(z))
finally:
pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(previous_fn)

@test_util.assert_no_new_pyobjects_executing_eagerly
def testFunctionCacheLimited(self):
# Every time this test is executed, it will create a slightly larger Tensor
# and push it through Add's gradient. Since we check for new pyobjects after
# the warmup, retracing each time without cleaning up old traces fails the
# test. It works because of experimental_relax_shapes.
execution_count = getattr(self, "_execution_count", 0)
self._execution_count = execution_count + 1
x = array_ops.zeros([execution_count])
with forwardprop.ForwardGradientAccumulator() as acc:
acc.watch(x, array_ops.ones_like(x))
y = x + x
self.assertAllClose(2. * array_ops.ones_like(x), acc.jvp(y))

@test_util.assert_no_new_pyobjects_executing_eagerly
def testMultipleWatchesAdd(self):
x = constant_op.constant(-2.)
@@ -151,14 +210,14 @@ def testDeadTensorsJVPCleared(self):
self.assertIsNone(derived_tensor_weak())
self.assertIsNone(derived_tensor_grad_weak())

@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testJVPManual(self):
primal, tangent = _jvp(math_ops.sin, (constant_op.constant(0.1),),
(constant_op.constant(0.2),))
self.assertAllClose(math_ops.sin(0.1), primal)
self.assertAllClose(math_ops.cos(0.1) * 0.2, tangent)

@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testNumericHigherOrder(self):

def f(x):
@@ -169,7 +228,7 @@ def f(x):
_test_gradients(
self, f, [constant_op.constant([[2.0, 3.0], [1.0, 4.0]])], order=3)

@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testCustomGradient(self):

@custom_gradient.custom_gradient
@@ -182,7 +241,7 @@ def grad(dy):

_test_gradients(self, f, [constant_op.constant([1., 2.])], order=3)

@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testCustomGradientRecomputeGrad(self):

@custom_gradient.recompute_grad
@@ -257,7 +316,7 @@ def fun(x):
tangents = constant_op.constant([3., 4., 5.])
_hvp(fun, (primals,), (tangents,))

@test_util.assert_no_new_tensors
@test_util.assert_no_new_pyobjects_executing_eagerly
def testHVPCorrectness(self):

def fun(x):
@@ -285,6 +344,7 @@ def fun(x):
self.assertAllClose(backback_hvp, forwardback_hvp_function)


if __name__ == '__main__':
if __name__ == "__main__":
# TODO(allenl): Also test with 1.x-style graph mode.
ops.enable_eager_execution()
test.main()
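
Taken together, the tests above use ForwardGradientAccumulator roughly as in this standalone sketch (usage only; it relies on the same internal modules as the tests and assumes eager execution is enabled, as the test's __main__ block arranges):

from tensorflow.python.eager import forwardprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

ops.enable_eager_execution()  # as in the test's __main__ block

x = constant_op.constant(3.)
with forwardprop.ForwardGradientAccumulator() as acc:
  # Watch x with a tangent of 1., i.e. differentiate with respect to x.
  acc.watch(x, constant_op.constant(1.))
  y = x * x
  # The JVP of y = x**2 with a tangent of 1. is 2 * x = 6.
  print(acc.jvp(y))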
8 changes: 8 additions & 0 deletions tensorflow/python/eager/pywrap_tfe.h
@@ -86,6 +86,14 @@ PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e);
// This function is not thread-safe.
PyObject* TFE_Py_RegisterGradientFunction(PyObject* e);

// Registers e as the forward_gradient_function. The registered function takes
// (op_name, attrs, inputs, outputs, tangents) and returns the output
// tangents. This function is used only for operations, not for custom gradients
// or functional ops.
//
// This function is not thread-safe.
PyObject* TFE_Py_RegisterForwardGradientFunction(PyObject* e);

// Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using
// `exception` if not nullptr, else using the class registered via
// TFE_Py_RegisterExceptionClass), and returns -1.
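For illustration, a replacement forward-gradient function with the documented (op_name, attrs, inputs, outputs, tangents) signature could wrap the default Python implementation. The wrapper below is hypothetical; it mirrors how testForwardGradientFunctionUsedByAccumulatorForOps swaps the registered function:

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import forwardprop

def _logging_forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
  # Log each op whose JVP is requested, then defer to the default
  # implementation registered by forwardprop.py.
  print("forward gradient for", op_name)
  return forwardprop._forward_gradient(op_name, attr_tuple, inputs, outputs,
                                       tangents)

pywrap_tensorflow.TFE_Py_RegisterForwardGradientFunction(
    _logging_forward_gradient)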
[Diffs for the remaining two changed files were not loaded.]
