Allow for unthunk to return nothing #35

Open. mcabbott wants to merge 2 commits into base: master.


mcabbott (Member) commented Apr 9, 2025

Aims to fix FluxML/Zygote.jl#1567

In rules defined by @adjoint, there is always a second method back(::Nothing) = nothing, so that the method you write need not handle nothing. However, because of the way #17 added unthunk, a thunk that unthunks to nothing does not take this shortcut.

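Defining a separate method back(Δ::AbstractThunk) = back(unthunk_tangent(Δ)) instead should avoid that, since the unthunked value is dispatched on again and nothing reaches the back(::Nothing) method. This assumes that unthunk_tangent (eventually) gives us a non-thunk; otherwise the call would keep re-dispatching to the same method.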
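
The difference between the two approaches can be sketched with toy types. This is a minimal illustration, not the ZygoteRules code: `ToyAbstractThunk`, `ToyThunk`, `toy_unthunk`, `back` and `back_inline` are hypothetical stand-ins, and `back_inline` is only one reading of "unthunking inside the user's method".

```julia
# Hypothetical stand-ins, not the ChainRulesCore / ZygoteRules definitions.
abstract type ToyAbstractThunk end
struct ToyThunk{F} <: ToyAbstractThunk
    f::F
end
toy_unthunk(t::ToyThunk) = t.f()   # evaluate the delayed computation
toy_unthunk(x) = x                 # non-thunks pass through unchanged

# Approach sketched in this PR: unthunk in its own method, then dispatch again.
back(Δ) = (2 * Δ,)                 # what a rule author writes; assumes a number
back(::Nothing) = nothing          # shortcut for "no gradient"
back(Δ::ToyAbstractThunk) = back(toy_unthunk(Δ))

# Contrast (assumed shape of the problem): unthunk inside the author's method.
back_inline(Δ) = (2 * toy_unthunk(Δ),)
back_inline(::Nothing) = nothing

r_number  = back(3.0)                       # ordinary cotangent
r_thunk   = back(ToyThunk(() -> 3.0))       # thunk holding a number
r_nothing = back(ToyThunk(() -> nothing))   # thunk holding nothing hits the shortcut

# The inline version never re-dispatches, so it tries 2 * nothing.
inline_failed = try
    back_inline(ToyThunk(() -> nothing))
    false
catch err
    err isa MethodError
end
```

In the sketch, the thunk method only forwards, so whatever the thunk evaluates to goes through ordinary dispatch again and the author's method still never sees nothing.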

Cc @oschulz and @pxl-th, for work on FluxML/Zygote.jl#966

oschulz (Contributor) commented Apr 9, 2025

Sounds sensible to me, but I'm not sure I can judge all implications across Zygote's code.

mcabbott (Member, Author) commented Apr 9, 2025

Besides inference failures, the one failing test is this:

julia> gradient([2 3; 4 5]) do xs
           sum([x ^ 2 + y for x in xs, y in xs])
       end
([20.0 28.0; 36.0 44.0],)

julia> gradient([2 3; 4 5]) do xs
           sum([x ^ i for (i, x) in enumerate(xs)])
       end
([1.0 27.0; 8.0 500.0],)

julia> gradient([2 3; 4 5]) do xs
           sum([x ^ i + y for (i, x) in enumerate(xs), y in xs])
       end == ([8 112; 36 2004],)
ERROR: MethodError: Cannot `convert` an object of type Float64 to an object of type ChainRulesCore.ZeroTangent
The function `convert` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:126

Stacktrace:
  [1] cvt1
    @ ./essentials.jl:612 [inlined]
  [2] ntuple
    @ ./ntuple.jl:49 [inlined]
  [3] convert(::Type{Tuple{ChainRulesCore.ZeroTangent, Float64}}, x::Tuple{Float64, Float64})
    @ Base ./essentials.jl:614
  [4] setindex!
    @ ./array.jl:994 [inlined]
  [5] setindex!
    @ ./multidimensional.jl:704 [inlined]
  [6] macro expansion
    @ ./reducedim.jl:289 [inlined]
  [7] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [8] _mapreducedim!(f::Zygote.StaticGetter{1}, op::typeof(Zygote.accum), R::Array{Tuple{…}, 4}, A::Array{Tuple{…}, 4})
    @ Base ./reducedim.jl:287
  [9] mapreducedim!
    @ ./reducedim.jl:296 [inlined]
 [10] _mapreduce_dim
    @ ./reducedim.jl:340 [inlined]
 [11] mapreduce
    @ ./reducedim.jl:329 [inlined]
 [12] #742
    @ ~/.julia/dev/Zygote/src/lib/array.jl:287 [inlined]
 [13] map
    @ ./tuple.jl:406 [inlined]
 [14] productfunc(xs::Tuple{Base.Iterators.Enumerate{Matrix{…}}, Matrix{Int64}}, dy::Array{Tuple{Tuple{…}, Float64}, 4})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:282
 [15] product_pullback
    @ ~/.julia/dev/Zygote/src/lib/array.jl:295 [inlined]
 [16] #3284#back
    @ ~/.julia/dev/ZygoteRules/src/adjoint.jl:73 [inlined]
 [17] #17
    @ ./REPL[8]:2 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Int64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Int64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
 [20] gradient(f::Function, args::Matrix{Int64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:154
 [21] top-level scope
    @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.

julia> show(err)
1-element ExceptionStack:
MethodError: Cannot `convert` an object of type Float64 to an object of type ChainRulesCore.ZeroTangent
The function `convert` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  convert(::Type{T}, ::T) where T
   @ Base Base.jl:126

Stacktrace:
  [1] cvt1
    @ ./essentials.jl:612 [inlined]
  [2] ntuple
    @ ./ntuple.jl:49 [inlined]
  [3] convert(::Type{Tuple{ChainRulesCore.ZeroTangent, Float64}}, x::Tuple{Float64, Float64})
    @ Base ./essentials.jl:614
  [4] setindex!
    @ ./array.jl:994 [inlined]
  [5] setindex!
    @ ./multidimensional.jl:704 [inlined]
  [6] macro expansion
    @ ./reducedim.jl:289 [inlined]
  [7] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [8] _mapreducedim!(f::Zygote.StaticGetter{1}, op::typeof(Zygote.accum), R::Array{Tuple{ChainRulesCore.ZeroTangent, Float64}, 4}, A::Array{Tuple{Tuple{ChainRulesCore.Thunk{ChainRules.var"#382#416"{Float64, Int64, Int64, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, Int64}}, Float64}, Float64}, 4})
    @ Base ./reducedim.jl:287
  [9] mapreducedim!
    @ ./reducedim.jl:296 [inlined]
 [10] _mapreduce_dim
    @ ./reducedim.jl:340 [inlined]
 [11] mapreduce
    @ ./reducedim.jl:329 [inlined]
 [12] #742
    @ ~/.julia/dev/Zygote/src/lib/array.jl:287 [inlined]
 [13] map
    @ ./tuple.jl:406 [inlined]
 [14] productfunc(xs::Tuple{Base.Iterators.Enumerate{Matrix{Int64}}, Matrix{Int64}}, dy::Array{Tuple{Tuple{ChainRulesCore.Thunk{ChainRules.var"#382#416"{Float64, Int64, Int64, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, Int64}}, Float64}, Float64}, 4})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:282
 [15] product_pullback
    @ ~/.julia/dev/Zygote/src/lib/array.jl:295 [inlined]
 [16] #3284#back
    @ ~/.julia/dev/ZygoteRules/src/adjoint.jl:73 [inlined]
...

The offending code is here:

https://github.com/FluxML/Zygote.jl/blob/1b914d994aea236bcb6d3d0cd6c099d86cede101/src/lib/array.jl#L286-L287

And the problem is that zero(::Thunk) isa ZeroTangent:

julia> using ChainRulesCore

julia> @thunk 1+1
Thunk(var"#21#22"())

julia> zero(ans)
ZeroTangent()

although it's not clear to me why this PR exposes that.
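
To illustrate why a zero that has a different type from the evaluated value leads to this kind of `convert` error, here is a small toy sketch. It assumes an accumulator whose element type comes from `zero` of the lazy value; whether the linked Zygote code does exactly this is not confirmed here. `ToyZero`, `ToyThunk` and `toy_unthunk` are hypothetical stand-ins, not the ChainRulesCore types.

```julia
# Hypothetical stand-ins, not the ChainRulesCore types.
struct ToyZero end                 # plays the role of ZeroTangent
struct ToyThunk{F}                 # plays the role of Thunk
    f::F
end
toy_unthunk(t::ToyThunk) = t.f()
Base.zero(::ToyThunk) = ToyZero()  # mirrors zero(::Thunk) isa ZeroTangent

thunks = [ToyThunk(() -> 4.0)]

# Accumulator seeded from zero of the lazy values: its element type is ToyZero.
acc = [zero(t) for t in thunks]

# Storing the evaluated Float64 needs convert(ToyZero, 4.0), which has no method.
failed = try
    acc[1] = toy_unthunk(thunks[1])
    false
catch err
    err isa MethodError
end
```

The array `R::Array{Tuple{ChainRulesCore.ZeroTangent, Float64}, 4}` in the stack trace above has the same flavour: its first slot is typed as ZeroTangent, while the value being written there is a Float64.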

Successfully merging this pull request may close these issues.

Gradient of reshape(::Array{Bool}, ...) does not handle thunks