diff --git a/src/core.jl b/src/core.jl
index 06f9b16b..89bfe919 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -1370,16 +1370,19 @@ Base.haskey(graph::Graph, name) = isnull(get_node_by_name(graph, name))
 
 
 
-node_name(xs::AbstractVector) = node_name.(xs)
+node_name(::Void) = nothing
+node_name(xs::AbstractVector)=node_name.(xs)
 
-function gradients(y, x::AbstractArray)
+
+function gradients(y, x::AbstractArray, grad_y=nothing)
     x_names = node_name(x)
     y_names = node_name(y)
+    grad_y_names = node_name(grad_y)
     meta_graph = train.export_meta_graph()
     b = IOBuffer()
     writeproto(b, meta_graph)
     graph_proto = @compat take!(b)
-    node_protos, grad_names = @py_proc py_gradients($graph_proto, $x_names, $y_names)
+    node_protos, grad_names = @py_proc py_gradients($graph_proto, $x_names, $y_names, $grad_y_names)
     extend_graph(node_protos)
     out = []
     for name in grad_names
@@ -1394,7 +1397,7 @@ function gradients(y, x::AbstractArray)
     return out
 end
 
-gradients(y, x) = gradients(y, [x])[1]
+gradients(y, x, grad_y=nothing) = gradients(y, [x], grad_y)[1]
 
 function get_num_outputs(op::Operation)
     @tfcall(:TF_OperationNumOutputs, Cint, (Ptr{Void},), op.ptr) |> Int
diff --git a/src/py.jl b/src/py.jl
index 352d1ee0..84a766f0 100644
--- a/src/py.jl
+++ b/src/py.jl
@@ -70,16 +70,17 @@ function to_protos(py_graph)
     return protos
 end
 
-function py_gradients(jl_graph_proto, x_names, y_names)
+function py_gradients(jl_graph_proto, x_names, y_names, grad_y_names)
     py_graph = make_py_graph(jl_graph_proto)
 
     to_py_node(node_name) = py_graph[:get_tensor_by_name](string(node_name[1], ":", node_name[2]-1))
     to_py_node(node_names::AbstractVector) = to_py_node.(node_names) # Boardcast via dispatch
+    to_py_node(::Void) = nothing
 
     py_x = to_py_node(x_names)
     py_y = to_py_node(y_names)
-
-    @py_catch grad_node = py_tf[][:gradients](py_y, py_x)
+    py_grad_y = to_py_node(grad_y_names)
+    @py_catch grad_node = py_tf[][:gradients](py_y, py_x, py_grad_y)
     grad_names = []
     for node in grad_node
         if node === nothing
diff --git a/test/core.jl b/test/core.jl
index d4d43ffb..ce3b7845 100644
--- a/test/core.jl
+++ b/test/core.jl
@@ -1,26 +1,6 @@
 using Base.Test
 using TensorFlow
 
-@testset "Gradients" begin
-    let
-        sess = Session(Graph())
-        A = get_variable("A", (1,), Float32)
-        B = get_variable("B", (1,), Float32)
-
-        @test [[2.0f0]] == run(sess, gradients(2A, [A]))
-        @test [2.0f0] == run(sess, gradients(2A, A))
-
-        @test [[3.0f0], [5.0f0]] == run(sess, gradients(3A+5B, [A, B]))
-        @test [[8.0f0]] == run(sess, gradients([3A, 5A], [A]))
-
-        @test [[9.0f0], [3.0f0]] == run(sess, gradients([2A+3B, 7A], [A, B]))
-
-        @test [35.0f0] == run(sess, gradients(7A, A, 14))
-
-    end
-
-end
-
 @testset "Graph importing" begin
     if tf_version() >= v"1.0.0-rc1"
         graph_pb = read(joinpath(dirname(@__FILE__), "graph.pb"))
@@ -118,3 +98,23 @@ end
 end
 
 
+@testset "Gradients" begin
+    let
+        sess = Session(Graph())
+        A = get_variable("A", (1,), Float32)
+        B = get_variable("B", (1,), Float32)
+
+        @test [[2.0f0]] == run(sess, gradients(2A, [A]))
+        @test [2.0f0] == run(sess, gradients(2A, A))
+
+        @test [[3.0f0], [5.0f0]] == run(sess, gradients(3A+5B, [A, B]))
+        @test [[8.0f0]] == run(sess, gradients([3A, 5A], [A]))
+
+        @test [[9.0f0], [3.0f0]] == run(sess, gradients([2A+3B, 7A], [A, B]))
+
+        @test [35.0f0] == run(sess, gradients(7A, A, constant([5.0f0])))
+
+    end
+
+end
+
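
A minimal usage sketch (not part of the patch itself) of the optional `grad_y` argument introduced above, mirroring the relocated test: passing a tensor as the third argument seeds the backpropagated gradient, so the call returns the vector-Jacobian product instead of the plain gradient.

```julia
using TensorFlow

sess = Session(Graph())
A = get_variable("A", (1,), Float32)

# Plain gradient: d(7A)/dA == 7.
run(sess, gradients(7A, A))                      # [7.0f0]

# Seeded gradient: grad_y = [5.0f0] scales the result to 7 * 5 == 35.
run(sess, gradients(7A, A, constant([5.0f0])))   # [35.0f0]
```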