diff --git a/src/core.jl b/src/core.jl
index 3a4b5265..ebfafb88 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -1370,15 +1370,37 @@
 Base.haskey(graph::Graph, name) = isnull(get_node_by_name(graph, name))
 
+node_name(::Void) = nothing
+node_name(xs::AbstractVector) = node_name.(xs)
 
-function gradients(y, x::AbstractArray)
-    x_names = [node_name(node) for node in x]
-    y_name = node_name(y)
+"""
+    gradients(ys, xs, grad_ys=nothing)
+
+Constructs symbolic partial derivatives of the sum of `ys` with respect to each `x` in `xs`.
+
+`ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys` is a list of `Tensor`s holding the gradients received by the `ys`; the list must be the same length as `ys`.
+
+`gradients` adds ops to the graph that output the partial derivatives of `ys` with respect to `xs`. It returns a list of `Tensor`s of length `length(xs)`, where each tensor is the sum of `dy/dx` over all `y` in `ys`.
+
+`grad_ys` is a tensor or list of tensors holding the initial gradient for each `y` in `ys`.
+If `ys` is a single tensor, `grad_ys` must be a single tensor; if `ys` is a list of tensors, then likewise `grad_ys` must be a list of the same length.
+When `grad_ys` is `nothing`, it effectively defaults to a tensor of ones with the shape of `y` for each `y` in `ys`.
+`grad_ys` can be partially specified, with some gradients given and others left to default, by passing their values as `nothing`.
+A user can provide their own initial `grad_ys` to compute the derivatives using a different initial gradient for each `y` (e.g., to weight the gradient differently for each value in each `y`).
+
+Each element of `grad_ys` must be a `Tensor` (or an `Operation`); to use a plain Julia value, wrap it in `constant` first.
+
+See the [Python docs](https://www.tensorflow.org/versions/master/api_docs/python/tf/gradients).
+"""
+function gradients(y, x::AbstractArray, grad_y=nothing)
+    x_names = node_name(x)
+    y_names = node_name(y)
+    grad_y_names = node_name(grad_y)
     meta_graph = train.export_meta_graph()
     b = IOBuffer()
     writeproto(b, meta_graph)
     graph_proto = @compat take!(b)
-    node_protos, grad_names = @py_proc py_gradients($graph_proto, $x_names, $y_name)
+    node_protos, grad_names = @py_proc py_gradients($graph_proto, $x_names, $y_names, $grad_y_names)
     extend_graph(node_protos)
     out = []
     for name in grad_names
@@ -1393,7 +1415,7 @@ function gradients(y, x::AbstractArray)
     return out
 end
 
-gradients(y, x) = gradients(y, [x])[1]
+gradients(y, x, grad_y=nothing) = gradients(y, [x], grad_y)[1]
 
 function get_num_outputs(op::Operation)
     @tfcall(:TF_OperationNumOutputs, Cint, (Ptr{Void},), op.ptr) |> Int
diff --git a/src/py.jl b/src/py.jl
index a00f31ed..84a766f0 100644
--- a/src/py.jl
+++ b/src/py.jl
@@ -59,23 +59,28 @@ function make_py_graph(graph_proto)
 end
 
 function to_protos(py_graph)
-    n_nodes = length(py_graph[:node])
+    py_graph_def = py_graph[:as_graph_def]()
+    n_nodes = length(py_graph_def[:node])
     protos = []
     for node_idx in 1:n_nodes
-        node_py = py_graph[:node][node_idx]
+        node_py = py_graph_def[:node][node_idx]
         proto = Vector{UInt8}(node_py[:SerializeToString]())
         push!(protos, proto)
     end
     return protos
 end
 
-function py_gradients(jl_graph_proto, x_names, y_name)
+function py_gradients(jl_graph_proto, x_names, y_names, grad_y_names)
     py_graph = make_py_graph(jl_graph_proto)
-    to_py_node = node_name->py_graph[:get_tensor_by_name](string(node_name[1], ":", node_name[2]-1))
-    py_x = [to_py_node(node) for node in x_names]
-    py_y = to_py_node(y_name)
-    @py_catch grad_node = py_tf[][:gradients](py_y, py_x)
-    py_graph_def = py_graph[:as_graph_def]()
+
+    to_py_node(node_name) = py_graph[:get_tensor_by_name](string(node_name[1], ":", node_name[2]-1))
+    to_py_node(node_names::AbstractVector) = to_py_node.(node_names) # Broadcast via dispatch
+    to_py_node(::Void) = nothing
+
+    py_x = to_py_node(x_names)
+    py_y = to_py_node(y_names)
+    py_grad_y = to_py_node(grad_y_names)
+    @py_catch grad_node = py_tf[][:gradients](py_y, py_x, py_grad_y)
     grad_names = []
     for node in grad_node
         if node === nothing
@@ -88,7 +93,7 @@
             push!(grad_names, node[:name])
         end
     end
-    return to_protos(py_graph_def), grad_names
+    return to_protos(py_graph), grad_names
 end
 
 const events_writer = Ref{PyObject}()
diff --git a/test/core.jl b/test/core.jl
index 08c50947..6f33a399 100644
--- a/test/core.jl
+++ b/test/core.jl
@@ -1,7 +1,6 @@
 using Base.Test
 using TensorFlow
 
-
 @testset "Graph importing" begin
     if tf_version() >= v"1.0.0-rc1"
         graph_pb = read(joinpath(dirname(@__FILE__), "graph.pb"))
@@ -97,3 +96,27 @@ end
         end
     end
 end
+
+
+@testset "Gradients" begin
+    let
+        sess = Session(Graph())
+        A = get_variable("A", (1,), Float32)
+        B = get_variable("B", (1,), Float32)
+
+        @test [[2.0f0]] == run(sess, gradients(2A, [A]))
+        @test [2.0f0] == run(sess, gradients(2A, A))
+
+        @test [[3.0f0], [5.0f0]] == run(sess, gradients(3A+5B, [A, B]))
+        @test [[8.0f0]] == run(sess, gradients([3A, 5A], [A]))
+
+        @test [[9.0f0], [3.0f0]] == run(sess, gradients([2A+3B, 7A], [A, B]))
+
+        @test [35.0f0] == run(sess, gradients(7A, A, constant([5.0f0])))
+        @test [68.0f0] == run(sess, gradients([7A, 3A], A, [constant([5.0f0]), constant([11.0f0])]))
+        @test [38.0f0] == run(sess, gradients([7A, 3A], A, [constant([5.0f0]), nothing]))
+
+    end
+
+end
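
A minimal usage sketch of the new initial-gradient argument, using only the API exercised by the tests above (as in the tests, the variables never need to be initialized here, because these gradients are constant):

```julia
using TensorFlow

sess = Session(Graph())
A = get_variable("A", (1,), Float32)

# Default: grad_ys is `nothing`, i.e. an implicit tensor of ones,
# so this returns d(7A)/dA == [7.0f0].
run(sess, gradients(7A, A))

# Explicit: seed the backward pass with 5 instead of 1, giving
# 5 * d(7A)/dA == [35.0f0]. Plain values must be wrapped in `constant`.
run(sess, gradients(7A, A, constant([5.0f0])))
```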
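The `node_name` and `to_py_node` changes both rely on the same idiom: instead of branching on the argument's type inside one function, each case gets its own method, so single names, vectors of names, and `nothing` are all handled by dispatch. A self-contained sketch of the pattern on Julia 0.6 (the `label` function is hypothetical, purely illustrative):

```julia
label(x::Number) = string("#", x)       # base case: a single value
label(xs::AbstractVector) = label.(xs)  # vectors: broadcast the base case
label(::Void) = nothing                 # an absent input passes through

label(3)        # "#3"
label([1, 2])   # ["#1", "#2"]
label(nothing)  # nothing
```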