Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
simeonschaub committed Oct 12, 2021
1 parent a3f8dc4 commit 3d56f79
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ end
g = Gaussian(randn(3), randn(3, 3))
y, back = @inferred pullback(x -> x.m, g)
@test y == getfield(g, :m)
# This type instability is due to the handling of non-bitstypes in `accum_param`
@test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
@test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)

Expand All @@ -175,10 +176,10 @@ end
g = Gaussian(randn(3), randn(3, 3))
y, back = @inferred pullback(x -> x.m, g)

Zygote._pullback(::typeof(getproperty), g::Gaussian, s::Symbol) = 3getfield(g, s), Δ -> (nothing, (; ((:m, :P) .=> nothing)..., s => 3Δ), nothing)
Zygote._pullback(::Zygote.AContext, ::typeof(getproperty), g::Gaussian, s::Symbol) = 3getfield(g, s), Δ -> (nothing, (; ((:m, :P) .=> nothing)..., s => 3Δ), nothing)
y, back = pullback(x -> x.m, g)
@test_broken y == 3getfield(g, :m)
@test_broken back([1., 0, 0]) == ((m = [3.0, 0.0, 0.0], P = nothing),)
@test y == 3getfield(g, :m)
@test back([1., 0, 0]) == ((m = [3.0, 0.0, 0.0], P = nothing),)


Gaussian = _Gaussian(:rrule)
Expand All @@ -189,6 +190,12 @@ end
y, back = pullback(x -> x.m, g)
@test y == 4getfield(g, :m)
@test back([1., 0, 0]) == ((m = [4.0, 0.0, 0.0], P = nothing),)

Gaussian = _Gaussian(:bitstype)
g = Gaussian(randn(), randn())
y, back = @inferred pullback(x -> x.m, g)
@test y == getfield(g, :m)
@test @inferred(back(1.0)) == ((m = 1.0, P = nothing),)
end

# issue 897
Expand Down

0 comments on commit 3d56f79

Please sign in to comment.