From 3ab5a9fb9c1a201a581d2298222aafd4390f3387 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 10 May 2017 17:04:57 +0800 Subject: [PATCH] =doc --- src/core.jl | 18 ++++++++++++++++++ test/core.jl | 2 ++ 2 files changed, 20 insertions(+) diff --git a/src/core.jl b/src/core.jl index 89bfe919..ebfafb88 100644 --- a/src/core.jl +++ b/src/core.jl @@ -1373,7 +1373,25 @@ Base.haskey(graph::Graph, name) = isnull(get_node_by_name(graph, name)) node_name(::Void) = nothing node_name(xs::AbstractVector)=node_name.(xs) +""" +gradients(ys, xs, grad_ys=nothing) + +Constructs symbolic partial derivatives of sum of ys w.r.t. x in xs. + +ys and xs are each a Tensor or a list of tensors. grad_ys is a list of Tensor, holding the gradients received by the ys. The list must be the same length as ys. + +gradients() adds ops to the graph to output the partial derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys. +`grad_ys` is a tensor or list of tensors which holds the initial gradients for each y in ys. +If `ys` is a single tensor `grad_ys` must be a single tensor; if `ys` is a list of tensors then likewise `grad_ys` must be a list of the same size. +When `grad_ys` is `nothing`, it is effectively defaulted to a tensor of '1's of the shape of y for each y in ys. +`grad_ys` can be partially specified, with some gradients given and others left to default, by passing their values as `nothing`. +A user can provide their own initial `grad_ys` to compute the derivatives using a different initial gradient for each y (e.g., if one wanted to weight the gradient differently for each value in each y). + +`grad_ys` must be a `Tensor` (or an `Operation`); if a plain Julia value is to be used, it should be wrapped into a `constant`. 
+ +see: [Python Docs](https://www.tensorflow.org/versions/master/api_docs/python/tf/gradients) +""" function gradients(y, x::AbstractArray, grad_y=nothing) x_names = node_name(x) y_names = node_name(y) diff --git a/test/core.jl b/test/core.jl index ce3b7845..6f33a399 100644 --- a/test/core.jl +++ b/test/core.jl @@ -113,6 +113,8 @@ end @test [[9.0f0], [3.0f0]] == run(sess, gradients([2A+3B, 7A], [A, B])) @test [35.0f0] == run(sess, gradients(7A, A, constant([5.0f0]))) + @test [68.0f0] == run(sess, gradients([7A,3A], A, [constant([5.0f0]), constant([11.0f0])])) + @test [38.0f0] == run(sess, gradients([7A,3A], A, [constant([5.0f0]), nothing])) end