Merge pull request #215 from malmaud/ox/moregrads
Allow more options for Gradient calculation
oxinabox authored May 20, 2017
2 parents 0988ad7 + 3ab5a9f commit cbc03da
Showing 3 changed files with 65 additions and 15 deletions.
32 changes: 27 additions & 5 deletions src/core.jl
@@ -1370,15 +1370,37 @@ Base.haskey(graph::Graph, name) = isnull(get_node_by_name(graph, name))



node_name(::Void) = nothing
node_name(xs::AbstractVector)=node_name.(xs)

function gradients(y, x::AbstractArray)
x_names = [node_name(node) for node in x]
y_name = node_name(y)
"""
gradients(ys, xs, grad_ys=nothing)
Constructs symbolic partial derivatives of the sum of `ys` with respect to each `x` in `xs`.
`ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys` is a list of `Tensor`, holding the gradients received by the `ys`; the list must be the same length as `ys`.
`gradients()` adds ops to the graph to output the partial derivatives of `ys` with respect to `xs`. It returns a list of `Tensor` of length `length(xs)`, where each tensor is the `sum(dy/dx)` for `y` in `ys`.
`grad_ys` is a tensor or list of tensors which holds the initial gradients for each `y` in `ys`.
If `ys` is a single tensor, `grad_ys` must be a single tensor; if `ys` is a list of tensors, then `grad_ys` must likewise be a list of the same size.
When `grad_ys` is `nothing`, it is effectively defaulted to a tensor of ones with the shape of `y` for each `y` in `ys`.
`grad_ys` can be partially specified, with some gradients given and others left to default, by passing their values as `nothing`.
A user can provide their own initial `grad_ys` to compute the derivatives using a different initial gradient for each `y` (e.g., if one wanted to weight the gradient differently for each value in each `y`).
Each element of `grad_ys` must be a `Tensor` (or an `Operation`); if a plain Julia value is to be used, it should be wrapped in a `constant`.
see: [Python Docs](https://www.tensorflow.org/versions/master/api_docs/python/tf/gradients)
"""
function gradients(y, x::AbstractArray, grad_y=nothing)
x_names = node_name(x)
y_names = node_name(y)
grad_y_names = node_name(grad_y)
meta_graph = train.export_meta_graph()
b = IOBuffer()
writeproto(b, meta_graph)
graph_proto = @compat take!(b)
node_protos, grad_names = @py_proc py_gradients($graph_proto, $x_names, $y_name)
node_protos, grad_names = @py_proc py_gradients($graph_proto, $x_names, $y_names, $grad_y_names)
extend_graph(node_protos)
out = []
for name in grad_names
@@ -1393,7 +1415,7 @@ function gradients(y, x::AbstractArray)
return out
end

gradients(y, x) = gradients(y, [x])[1]
gradients(y, x, grad_y=nothing) = gradients(y, [x], grad_y)[1]

function get_num_outputs(op::Operation)
@tfcall(:TF_OperationNumOutputs, Cint, (Ptr{Void},), op.ptr) |> Int
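For orientation, a minimal usage sketch of the extended `gradients` signature, mirroring the new tests further down (the session and variable names here are illustrative only, not part of the diff):

using TensorFlow

sess = Session(Graph())
A = get_variable("A", (1,), Float32)
B = get_variable("B", (1,), Float32)

g_default = gradients(2A, A)                    # default seed of ones; run(sess, g_default) == [2.0f0]
g_list    = gradients(3A + 5B, [A, B])          # list form returns one gradient per x
g_seeded  = gradients(7A, A, constant([5.0f0])) # seeded with 5; run(sess, g_seeded) == [35.0f0]

The gradient of 2A with respect to A does not depend on the value of A, so no variable initialisation is needed before running these tensors.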
23 changes: 14 additions & 9 deletions src/py.jl
@@ -59,23 +59,28 @@ function make_py_graph(graph_proto)
end

function to_protos(py_graph)
n_nodes = length(py_graph[:node])
py_graph_def = py_graph[:as_graph_def]()
n_nodes = length(py_graph_def[:node])
protos = []
for node_idx in 1:n_nodes
node_py = py_graph[:node][node_idx]
node_py = py_graph_def[:node][node_idx]
proto = Vector{UInt8}(node_py[:SerializeToString]())
push!(protos, proto)
end
return protos
end

function py_gradients(jl_graph_proto, x_names, y_name)
function py_gradients(jl_graph_proto, x_names, y_names, grad_y_names)
py_graph = make_py_graph(jl_graph_proto)
to_py_node = node_name->py_graph[:get_tensor_by_name](string(node_name[1], ":", node_name[2]-1))
py_x = [to_py_node(node) for node in x_names]
py_y = to_py_node(y_name)
@py_catch grad_node = py_tf[][:gradients](py_y, py_x)
py_graph_def = py_graph[:as_graph_def]()

to_py_node(node_name) = py_graph[:get_tensor_by_name](string(node_name[1], ":", node_name[2]-1))
to_py_node(node_names::AbstractVector) = to_py_node.(node_names) # Broadcast via dispatch
to_py_node(::Void) = nothing

py_x = to_py_node(x_names)
py_y = to_py_node(y_names)
py_grad_y = to_py_node(grad_y_names)
@py_catch grad_node = py_tf[][:gradients](py_y, py_x, py_grad_y)
grad_names = []
for node in grad_node
if node === nothing
@@ -88,7 +93,7 @@ function py_gradients(jl_graph_proto, x_names, y_name)
push!(grad_names, node[:name])
end
end
return to_protos(py_graph_def), grad_names
return to_protos(py_graph), grad_names
end

const events_writer = Ref{PyObject}()
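The new `to_py_node` methods above resolve names to Python tensors by dispatch: `nothing` (an omitted `grad_ys`) passes straight through, a single `(name, port)` pair is looked up, and a vector of names is broadcast over element-wise. A simplified, standalone sketch of that pattern in the same Julia 0.6-era style as the file (the `lookup` function and its return strings are illustrative only):

lookup(::Void) = nothing                                 # omitted input stays `nothing`
lookup(name::Tuple) = string(name[1], ":", name[2] - 1)  # single (name, port) pair
lookup(names::AbstractVector) = lookup.(names)           # broadcast via dispatch

lookup(nothing)               # -> nothing
lookup(("x", 1))              # -> "x:0"
lookup([("x", 1), ("y", 2)])  # -> ["x:0", "y:1"]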
25 changes: 24 additions & 1 deletion test/core.jl
@@ -1,7 +1,6 @@
using Base.Test
using TensorFlow


@testset "Graph importing" begin
if tf_version() >= v"1.0.0-rc1"
graph_pb = read(joinpath(dirname(@__FILE__), "graph.pb"))
@@ -97,3 +96,27 @@ end
end
end
end


@testset "Gradients" begin
let
sess = Session(Graph())
A = get_variable("A", (1,), Float32)
B = get_variable("B", (1,), Float32)

@test [[2.0f0]] == run(sess, gradients(2A, [A]))
@test [2.0f0] == run(sess, gradients(2A, A))

@test [[3.0f0], [5.0f0]] == run(sess, gradients(3A+5B, [A, B]))
@test [[8.0f0]] == run(sess, gradients([3A, 5A], [A]))

@test [[9.0f0], [3.0f0]] == run(sess, gradients([2A+3B, 7A], [A, B]))

@test [35.0f0] == run(sess, gradients(7A, A, constant([5.0f0])))
@test [68.0f0] == run(sess, gradients([7A,3A], A, [constant([5.0f0]), constant([11.0f0])]))
@test [38.0f0] == run(sess, gradients([7A,3A], A, [constant([5.0f0]), nothing]))

end

end
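For reference, the seeded cases in this testset follow directly from the rule in the docstring: `gradients(7A, A, constant([5.0f0]))` seeds d(7A)/dA with 5, giving 7*5 = 35; `gradients([7A, 3A], A, [constant([5.0f0]), constant([11.0f0])])` gives 7*5 + 3*11 = 68; and leaving the second seed as `nothing` defaults it to ones, giving 7*5 + 3*1 = 38.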
