Commit

=doc
Lyndon White committed May 10, 2017
1 parent 63cf980 commit 3ab5a9f
Showing 2 changed files with 20 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/core.jl
@@ -1373,7 +1373,25 @@ Base.haskey(graph::Graph, name) = isnull(get_node_by_name(graph, name))
node_name(::Void) = nothing
node_name(xs::AbstractVector)=node_name.(xs)

"""
gradients(ys, xs, grad_ys=nothing)
Constructs symbolic partial derivatives of sum of ys w.r.t. x in xs.
ys and xs are each a Tensor or a list of tensors. grad_ys is a list of Tensor, holding the gradients received by the ys. The list must be the same length as ys.
gradients() adds ops to the graph to output the partial derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.
`grad_ys` is a tensor or list of tensors which holds the initial gradients for each y in ys.
If `ys` is a single tensor `grad_ys` must be a single tensor; if `ys` is a list of tensors then likewise `grad_ys` must be a list of the smae size.
When `grad_ys` is `nothing`, it is effectiely defaulted to a tensor of '1's of the shape of y for each y in ys.
`grad_ys` can be partialy specified, with some gradients given and others left to default, py passing their values as `nothing`.
A user can provide their own initial `grad_ys` to compute the derivatives using a different initial gradient for each y (e.g., if one wanted to weight the gradient differently for each value in each y).
`grad_ys` must be a `Tensor` (or an `Operation`), if a plain julia value to be used, it should be wrapped into a `constant`.
see: [Python Docs](https://www.tensorflow.org/versions/master/api_docs/python/tf/gradients)
"""
function gradients(y, x::AbstractArray, grad_y=nothing)
x_names = node_name(x)
y_names = node_name(y)
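A minimal usage sketch of the `grad_ys` behavior described above, mirroring the tests in test/core.jl below; the session setup and the tensor `A` here are illustrative assumptions, not part of this commit:

using TensorFlow

# Assumed setup for illustration; the commit's tests define A and sess elsewhere.
sess = Session(Graph())
A = constant([1.0f0])

# Default initial gradient (ones): d(7A)/dA gives [7.0f0].
run(sess, gradients(7A, A))

# Seeded initial gradient of 5: 7 * 5 gives [35.0f0].
run(sess, gradients(7A, A, constant([5.0f0])))

# Partially specified grad_ys: 7*5 + 3*1 gives [38.0f0].
run(sess, gradients([7A, 3A], A, [constant([5.0f0]), nothing]))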
2 changes: 2 additions & 0 deletions test/core.jl
@@ -113,6 +113,8 @@ end
@test [[9.0f0], [3.0f0]] == run(sess, gradients([2A+3B, 7A], [A, B]))

@test [35.0f0] == run(sess, gradients(7A, A, constant([5.0f0])))
@test [68.0f0] == run(sess, gradients([7A,3A], A, [constant([5.0f0]), constant([11.0f0])]))
@test [38.0f0] == run(sess, gradients([7A,3A], A, [constant([5.0f0]), nothing]))

end

