From 927ee27e6010c64900a10f3425db47d1336d07d1 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 30 Jul 2022 18:31:17 -0400
Subject: [PATCH] fix nothing in mutable struct getindex/getfield

---
Review notes (text in this section is ignored by `git am`):
 * src/lib/lib.jl: the tuple `getindex` / `literal_getindex` pullbacks and the
   `literal_getfield` pullback still call `accum_param`, but no longer return
   early when that call yields `nothing`.
 * The inner pullback of `literal_getfield` is renamed from `back` to
   `literal_getfield_back`.
 * test/features.jl: one `@test_broken` becomes `@test` with an updated
   expected value, and a "mutable accum_param bugs" testset is added.
 * NOTE(review): line breaks and indentation in this copy were reconstructed
   from the hunk headers and a 2-space indent convention; confirm whitespace
   against the upstream commit before applying.

 src/lib/lib.jl   | 10 ++++-----
 test/features.jl | 57 +++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index f11a74214..5d36ef8b2 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -107,7 +107,7 @@ using Base: tail
 @adjoint function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i}
   val = xs[i]
   function back(Δ)
-    accum_param(__context__, val, Δ) === nothing && return
+    accum_param(__context__, val, Δ)
     return ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing
   end
   val, back
@@ -116,7 +116,7 @@ end
 @adjoint function getindex(xs::NTuple{N,Any}, i::Integer) where N
   val = xs[i]
   function back(Δ)
-    accum_param(__context__, val, Δ) === nothing && return
+    accum_param(__context__, val, Δ)
     return ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing
   end
   return val, back
@@ -228,8 +228,8 @@ end
 
 @adjoint function literal_getfield(x, ::Val{f}) where f
   val = getfield(x, f)
-  function back(Δ)
-    accum_param(__context__, val, Δ) === nothing && return
+  function literal_getfield_back(Δ)
+    accum_param(__context__, val, Δ)
     if isimmutable(x)
       dx = (; nt_nothing(x)..., pair(Val(f), Δ, x)...)
       (_project(x, dx), nothing)
@@ -239,7 +239,7 @@ end
       return (dx,nothing)
     end
   end
-  unwrap(val), back
+  unwrap(val), literal_getfield_back
 end
 
 _pullback(cx::AContext, ::typeof(getfield), x, field_name::Symbol) =
diff --git a/test/features.jl b/test/features.jl
index cdfe7329e..8d668de71 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -476,7 +476,7 @@ end
   @test_broken gradient(x -> abs2(x[1].x) + 7 * x[1].x.re, [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],)
   @test_broken gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) # worked on 0.6.0, 0.6.20
 
-  @test_broken gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = 9.0 + 2.0im,),) # gives nothing, same in 0.6.0
+  @test gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = (x = 9.0 + 2.0im,),),) # gave `nothing` from 0.6.0 to 0.6.41
 
   # Array of mutables:
   @test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
@@ -490,6 +490,61 @@ end
   @test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
 end
 
+@testset "mutable accum_param bugs" begin
+
+  mutable struct Mut{T}; x::T; end
+  struct Imm{T}; x::T; end
+
+  # Indexing a tuple containing a mutable struct gave `nothing`
+  x1 = (Mut(3.0),)
+  x2 = (Imm(3.0),)
+  x3 = (Ref(3.0),)
+  @test gradient(x -> x[1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x[1].x^2, x2)[1] == ((x = 6.0,),)
+  @test gradient(x -> x[1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+  i1 = 1
+  @test gradient(x -> x[i1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x[i1].x^2, x2)[1] == ((x = 6.0,),)
+  @test gradient(x -> x[i1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+
+  @test gradient(x -> x[1][1].x^2, [x1])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x[1][1].x^2, [x2])[1] == [((x = 6.0,),)]
+  @test gradient(x -> x[1][1].x^2, [x3])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41
+
+  # When `getfield` returns a mutable struct, it gave `nothing`:
+  x4 = Imm(Mut(4.0))
+  x5 = Mut(Mut(4.0))
+  x6 = Imm(Imm(4.0))
+  @test gradient(x -> x.x.x^3, x4)[1] == (x = (x = 48.0,),) # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x.x.x^3, x5)[1] == (x = (x = 48.0,),) # fails on v0.6.0
+  @test gradient(x -> x.x.x^3, x6)[1] == (x = (x = 48.0,),) # fails on v0.6.41
+
+  @test gradient(x -> x[2].x.x^3, [x4, x4])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0 v0.6.41
+  @test gradient(x -> x[2].x.x^3, [x4, x5])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0
+  @test gradient(x -> x[2].x.x^3, [x4, x6])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.41
+
+  # Check when using implicit parameters, Params cases used to pass:
+  y1 = [3.0]
+  y2 = (Mut(y1),)
+  y3 = (Imm(y1),)
+  @test gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41
+  @test gradient(() -> sum(y2[1].x)^2, Params([y1]))[y1] == [6.0]
+  @test gradient(x -> sum(x[1].x)^2, y3)[1] == ((x = [6.0],),)
+  @test gradient(() -> sum(y3[1].x)^2, Params([y1]))[y1] == [6.0]
+
+  @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41
+  @test gradient(() -> sum(y2[1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0]
+  @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),)
+  @test gradient(() -> sum(y3[1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0]
+
+  i1 = 1
+  @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41
+  @test gradient(() -> sum(y2[i1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0]
+  @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),)
+  @test gradient(() -> sum(y3[i1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0]
+
+end
+
 @testset "NamedTuples" begin
   @test gradient(x -> x.a, (a=1, b=2)) == ((a = 1, b = nothing),)
   @test gradient(x -> x[1].a, [(a=1, b=2)]) == ([(a = 1, b = nothing)],)