From 1b111d8a30d790b99d61574a549fa050c912af7f Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Tue, 21 Jan 2025 09:00:17 +0200 Subject: [PATCH] Unthunk tangents (if any) before returning gradient (#1551) --- Project.toml | 2 +- src/compiler/chainrules.jl | 11 +++++------ src/compiler/interface.jl | 4 ++-- src/lib/lib.jl | 3 +++ test/chainrules.jl | 28 ++++++++++++++++++++++++++++ 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 9c284634b..260eb6903 100644 --- a/Project.toml +++ b/Project.toml @@ -57,7 +57,7 @@ Requires = "1.1" SpecialFunctions = "1.6, 2" Statistics = "1" Tracker = "0.2" -ZygoteRules = "0.2.5" +ZygoteRules = "0.2.7" julia = "1.6" [extras] diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index c3bb9e208..e0e09a63b 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -1,11 +1,10 @@ # ToDo: Move some of this to ZygoteRules, or move unthunk_tangent for Tuple and NamedTuple from # Zygote rules here? -function unthunk_tangent end -@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) -@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x -@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x -@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) -unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) +@inline ZygoteRules.unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) +@inline ZygoteRules.unthunk_tangent(x::NTuple{N,<:Number}) where N = x +@inline ZygoteRules.unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x +@inline ZygoteRules.unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x) +ZygoteRules.unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d]) @non_differentiable unthunk_tangent(::IdDict) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 8f251d761..a5da774a8 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -152,7 +152,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d function gradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - return _project_all(args, grad) + return _project_all(args, unthunk_tangent(grad)) end # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! @@ -218,7 +218,7 @@ function withgradient(f, args...) else back(sensitivity(y)) end - results = _project_all(args, grad) + results = _project_all(args, unthunk_tangent(grad)) (val=y, grad=results) end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 179951033..90e596d95 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -40,6 +40,9 @@ end accum(x::NamedTuple, y::ChainRulesCore.Tangent) = accum(x, wrap_chainrules_output(y)) accum(x::ChainRulesCore.Tangent, y::NamedTuple) = accum(wrap_chainrules_output(x), y) +accum(x::Nothing, y::AbstractThunk) = y +accum(x::AbstractThunk, y::Nothing) = x + accum(x, y::AbstractThunk) = @thunk(accum(x, unthunk(y))) accum(x::AbstractThunk, y) = @thunk(accum(unthunk(x), y)) accum(x::AbstractThunk, y::AbstractThunk) = @thunk(accum(unthunk(x), unthunk(y))) diff --git a/test/chainrules.jl b/test/chainrules.jl index 3017a9e18..ed8e98b94 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -428,3 +428,31 @@ end @test Zygote.wrap_chainrules_input([[2.0; 4.0], [1.0; 3.0]]) == [[2.0; 4.0], [1.0; 3.0]] @test Zygote.wrap_chainrules_input([nothing; 4.0]) == [0.0; 4.0] # ChainRules uses the numeric zero where possible end + +@testset "Lazy" begin + custom_add(x, y) = x + y + function ChainRulesCore.rrule(::typeof(custom_add), x, y) + function pullback(Δ) + return NoTangent(), unthunk(Δ), @thunk(error("Should not compute.")) + end + custom_add(x, y), pullback + end + + x, y = 1f0, 1f0 + Zygote.gradient(x) do x + sum(custom_add(x, y)) + end +end + +@testset "No thunks in the gradient" begin + struct CustomDense + w::Matrix{Float32} + end + (d::CustomDense)(x) = d.w * x + + layers = [CustomDense(rand(Float32, 3, 3))] + x = ones(Float32, 3) + g = gradient(layers -> sum(layers[1](x)), layers)[1] + @test g[1] isa NamedTuple + @test g[1].w isa Array +end