Skip to content

Commit

Permalink
Allow passing in the gradient for y #212
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyndon White committed May 10, 2017
1 parent 8a92008 commit 63cf980
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 27 deletions.
11 changes: 7 additions & 4 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1370,16 +1370,19 @@ Base.haskey(graph::Graph, name) = isnull(get_node_by_name(graph, name))



# Map graph nodes to their names, broadcasting over collections.
# `nothing` (no node supplied) maps to `nothing`, so optional node
# arguments (e.g. `grad_y`) can be forwarded without special-casing.
node_name(::Void) = nothing
node_name(xs::AbstractVector) = node_name.(xs)

function gradients(y, x::AbstractArray)

function gradients(y, x::AbstractArray, grad_y=nothing)
x_names = node_name(x)
y_names = node_name(y)
grad_y_names = node_name(grad_y)
meta_graph = train.export_meta_graph()
b = IOBuffer()
writeproto(b, meta_graph)
graph_proto = @compat take!(b)
node_protos, grad_names = @py_proc py_gradients($graph_proto, $x_names, $y_names)
node_protos, grad_names = @py_proc py_gradients($graph_proto, $x_names, $y_names, $grad_y_names)
extend_graph(node_protos)
out = []
for name in grad_names
Expand All @@ -1394,7 +1397,7 @@ function gradients(y, x::AbstractArray)
return out
end

"""
    gradients(y, x, grad_y=nothing)

Compute the gradient of `y` with respect to a single tensor `x`,
optionally seeded with `grad_y` (the incoming gradient for `y`).
Delegates to the `AbstractArray` method with `[x]` and unwraps the
one-element result, returning the single gradient node.
"""
gradients(y, x, grad_y=nothing) = gradients(y, [x], grad_y)[1]

function get_num_outputs(op::Operation)
@tfcall(:TF_OperationNumOutputs, Cint, (Ptr{Void},), op.ptr) |> Int
Expand Down
7 changes: 4 additions & 3 deletions src/py.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,17 @@ function to_protos(py_graph)
return protos
end

function py_gradients(jl_graph_proto, x_names, y_names)
function py_gradients(jl_graph_proto, x_names, y_names, grad_y_names)
py_graph = make_py_graph(jl_graph_proto)

to_py_node(node_name) = py_graph[:get_tensor_by_name](string(node_name[1], ":", node_name[2]-1))
to_py_node(node_names::AbstractVector) = to_py_node.(node_names) # Boardcast via dispatch
to_py_node(::Void) = nothing

py_x = to_py_node(x_names)
py_y = to_py_node(y_names)

@py_catch grad_node = py_tf[][:gradients](py_y, py_x)
py_grad_y = to_py_node(grad_y_names)
@py_catch grad_node = py_tf[][:gradients](py_y, py_x, py_grad_y)
grad_names = []
for node in grad_node
if node === nothing
Expand Down
40 changes: 20 additions & 20 deletions test/core.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,6 @@
using Base.Test
using TensorFlow

# Smoke tests for `gradients`: scalar vs. vector `x`, multiple `y`s,
# and the 3-argument form that seeds the incoming gradient of `y`.
@testset "Gradients" begin
let
sess = Session(Graph())
A = get_variable("A", (1,), Float32)
B = get_variable("B", (1,), Float32)

# d(2A)/dA == 2: vector-of-gradients for vector `x`, bare gradient
# for scalar `x`.
@test [[2.0f0]] == run(sess, gradients(2A, [A]))
@test [2.0f0] == run(sess, gradients(2A, A))

@test [[3.0f0], [5.0f0]] == run(sess, gradients(3A+5B, [A, B]))
# Gradients of multiple `y`s w.r.t. the same `x` are summed: 3 + 5 == 8.
@test [[8.0f0]] == run(sess, gradients([3A, 5A], [A]))

@test [[9.0f0], [3.0f0]] == run(sess, gradients([2A+3B, 7A], [A, B]))

# NOTE(review): seeding grad_y with 14 should presumably scale d(7A)/dA
# to 7*14 == 98, not 35 — confirm the expected value here (7*5 == 35
# suggests a seed of 5 was intended).
@test [35.0f0] == run(sess, gradients(7A, A, 14))

end

end

@testset "Graph importing" begin
if tf_version() >= v"1.0.0-rc1"
graph_pb = read(joinpath(dirname(@__FILE__), "graph.pb"))
Expand Down Expand Up @@ -118,3 +98,23 @@ end
end


# Smoke tests for `gradients`: scalar vs. vector `x`, multiple `y`s,
# and the 3-argument form that seeds the incoming gradient of `y`.
@testset "Gradients" begin
let
sess = Session(Graph())
A = get_variable("A", (1,), Float32)
B = get_variable("B", (1,), Float32)

# d(2A)/dA == 2: vector-of-gradients for vector `x`, bare gradient
# for scalar `x`.
@test [[2.0f0]] == run(sess, gradients(2A, [A]))
@test [2.0f0] == run(sess, gradients(2A, A))

@test [[3.0f0], [5.0f0]] == run(sess, gradients(3A+5B, [A, B]))
# Gradients of multiple `y`s w.r.t. the same `x` are summed: 3 + 5 == 8.
@test [[8.0f0]] == run(sess, gradients([3A, 5A], [A]))

@test [[9.0f0], [3.0f0]] == run(sess, gradients([2A+3B, 7A], [A, B]))

# Seeded gradient: grad_y = 5 scales d(7A)/dA to 7 * 5 == 35.
@test [35.0f0] == run(sess, gradients(7A, A, constant([5.0f0])))

end

end

0 comments on commit 63cf980

Please sign in to comment.