Define adjoints for inverse of PlanarLayer (#160)
devmotion authored Jan 11, 2021
1 parent 827b80a commit 3e0d765
Showing 13 changed files with 100 additions and 30 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,9 +1,10 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.8.12"
version = "0.8.13"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -19,6 +20,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ArgCheck = "1, 2"
ChainRulesCore = "0.9"
Compat = "3"
Distributions = "0.23.3, 0.24"
MappedArrays = "0.2.2, 0.3"
2 changes: 2 additions & 0 deletions src/Bijectors.jl
@@ -36,6 +36,7 @@ using MappedArrays
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular
import NonlinearSolve
import ChainRulesCore

export TransformDistribution,
PositiveDistribution,
@@ -243,6 +244,7 @@ end

include("utils.jl")
include("interface.jl")
include("chainrules.jl")

# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
11 changes: 7 additions & 4 deletions src/bijectors/planar_layer.jl
@@ -151,13 +151,16 @@ D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows.
arXiv:1505.05770
"""
function find_alpha(wt_y::Real, wt_u_hat::Real, b::Real)
# Compute the initial bracket
_wt_y, _wt_u_hat, _b = promote(wt_y, wt_u_hat, b)
initial_bracket = (_wt_y - abs(_wt_u_hat), _wt_y + abs(_wt_u_hat))
# avoid promotions in root-finding algorithm and simplify AD dispatches
return find_alpha(promote(wt_y, wt_u_hat, b)...)
end
function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real}
# Compute the initial bracket (see above).
initial_bracket = (wt_y - abs(wt_u_hat), wt_y + abs(wt_u_hat))

# Try to solve the root-finding problem, i.e., compute a final bracket
prob = NonlinearSolve.NonlinearProblem{false}(initial_bracket) do α, _
α + _wt_u_hat * tanh(α + _b) - _wt_y
α + wt_u_hat * tanh(α + b) - wt_y
end
sol = NonlinearSolve.solve(prob, NonlinearSolve.Falsi())
if sol.retcode === NonlinearSolve.MAXITERS_EXCEED
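For context, `find_alpha(wt_y, wt_u_hat, b)` returns the root α of α + wt_u_hat * tanh(α + b) = wt_y, the one-dimensional equation that the inverse of `PlanarLayer` has to solve. Since |tanh| ≤ 1, the root always lies in the bracket (wt_y - |wt_u_hat|, wt_y + |wt_u_hat|) computed above. A minimal standalone sketch of the same idea with plain bisection (illustrative only, not the package implementation; `find_alpha_bisect` is a made-up name, and wt_u_hat is assumed to be > -1 so the residual is monotone):

function find_alpha_bisect(wt_y::Real, wt_u_hat::Real, b::Real; iters=100)
    # Same residual as in the NonlinearProblem above
    f(α) = α + wt_u_hat * tanh(α + b) - wt_y
    # Initial bracket: f(lo) ≤ 0 ≤ f(hi) because |tanh| ≤ 1
    lo, hi = wt_y - abs(wt_u_hat), wt_y + abs(wt_u_hat)
    for _ in 1:iters
        mid = (lo + hi) / 2
        f(mid) < 0 ? (lo = mid) : (hi = mid)
    end
    return (lo + hi) / 2
end

find_alpha_bisect(2.0, 0.5, 0.1)  # should agree with Bijectors.find_alpha(2.0, 0.5, 0.1)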
8 changes: 8 additions & 0 deletions src/chainrules.jl
@@ -0,0 +1,8 @@
# differentiation rule for the iterative algorithm in the inverse of `PlanarLayer`
ChainRulesCore.@scalar_rule(
find_alpha(wt_y::Real, wt_u_hat::Real, b::Real),
@setup(
x = inv(1 + wt_u_hat * sech+ b)^2),
),
(x, - tanh+ b) * x, x - 1),
)
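The partials in this rule follow from implicit differentiation of α + wt_u_hat * tanh(α + b) = wt_y: writing x = 1 / (1 + wt_u_hat * sech(α + b)^2), one obtains ∂α/∂wt_y = x, ∂α/∂wt_u_hat = -tanh(α + b) * x, and ∂α/∂b = -wt_u_hat * sech(α + b)^2 * x = x - 1 (in `@scalar_rule` notation Ω denotes the primal output, i.e. α). A quick sanity check against central finite differences, with arbitrary inputs (a sketch, not part of the commit):

using Bijectors

wt_y, wt_u_hat, b = 1.2, 0.7, -0.3
α = Bijectors.find_alpha(wt_y, wt_u_hat, b)
x = inv(1 + wt_u_hat * sech(α + b)^2)

fd(f; ε=1e-5) = (f(ε) - f(-ε)) / (2ε)  # central finite difference at 0

isapprox(fd(h -> Bijectors.find_alpha(wt_y + h, wt_u_hat, b)), x; rtol=1e-3)                 # ∂α/∂wt_y
isapprox(fd(h -> Bijectors.find_alpha(wt_y, wt_u_hat + h, b)), -tanh(α + b) * x; rtol=1e-3)  # ∂α/∂wt_u_hat
isapprox(fd(h -> Bijectors.find_alpha(wt_y, wt_u_hat, b + h)), x - 1; rtol=1e-3)             # ∂α/∂b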
13 changes: 13 additions & 0 deletions src/compat/reversediff.jl
@@ -181,5 +181,18 @@ lower(A::TrackedMatrix) = track(lower, A)
return lower(Ad), Δ -> (lower(Δ),)
end

function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal}
return track(find_alpha, wt_y, wt_u_hat, b)
end
@grad function find_alpha(wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal)
α = find_alpha(data(wt_y), data(wt_u_hat), data(b))

∂wt_y = inv(1 + wt_u_hat * sech(α + b)^2)
∂wt_u_hat = - tanh(α + b) * ∂wt_y
∂b = ∂wt_y - 1
find_alpha_pullback(Δ::Real) = (Δ * ∂wt_y, Δ * ∂wt_u_hat, Δ * ∂b)

return α, find_alpha_pullback
end

end
14 changes: 14 additions & 0 deletions src/compat/tracker.jl
@@ -446,3 +446,17 @@ _link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)

return z, pullback_link_chol_lkj
end

function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal}
return track(find_alpha, wt_y, wt_u_hat, b)
end
@grad function find_alpha(wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal)
α = find_alpha(data(wt_y), data(wt_u_hat), data(b))

∂wt_y = inv(1 + wt_u_hat * sech(α + b)^2)
∂wt_u_hat = - tanh(α + b) * ∂wt_y
∂b = ∂wt_y - 1
find_alpha_pullback(Δ::Real) = (Δ * ∂wt_y, Δ * ∂wt_u_hat, Δ * ∂b)

return α, find_alpha_pullback
end
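The ReverseDiff and Tracker definitions (here and in src/compat/reversediff.jl) register the same pullback by hand for their tracked number types. A rough usage sketch with Tracker, assuming Tracker is loaded so this compat code is active (ReverseDiff works analogously):

using Tracker, Bijectors

∂wt_y, ∂wt_u_hat, ∂b = Tracker.gradient(Bijectors.find_alpha, 1.2, 0.7, -0.3)
# The three (tracked) gradients should match (x, -tanh(α + b) * x, x - 1)
# from the ChainRules definition above.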
2 changes: 1 addition & 1 deletion src/interface.jl
@@ -51,7 +51,7 @@ has a closed-form implementation.
Most bijectors have closed-form evaluations, but there are cases where
this is not the case. For example the *inverse* evaluation of `PlanarLayer`
requires an iterative procedure to evaluate and thus is not differentiable.
requires an iterative procedure to evaluate.
"""
isclosedform(b::Bijector) = true

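The docstring tweak reflects that the inverse of `PlanarLayer` is still not closed form but is now differentiable. Roughly (a sketch; it assumes `isclosedform` is specialised to return `false` for the inverse, as the docstring describes):

using Bijectors

layer = PlanarLayer(randn(2), randn(2), randn(1))
isclosedform(layer)       # true: the forward evaluation has a closed form
isclosedform(inv(layer))  # false: the inverse is evaluated iteratively via find_alpha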
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -1,4 +1,5 @@
[deps]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
@@ -12,6 +13,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesTestUtils = "0.5"
Combinatorics = "1.0.2"
DistributionsAD = "0.6.3"
FiniteDifferences = "0.11, 0.12"
10 changes: 10 additions & 0 deletions test/ad/chainrules.jl
@@ -0,0 +1,10 @@
@testset "chainrules" begin
x, Δx, x̄ = randn(3)
y, Δy, ȳ = randn(3)
z, Δz, z̄ = randn(3)
Δu = randn()

ỹ = expm1(y)
frule_test(Bijectors.find_alpha, (x, Δx), (ỹ, Δy), (z, Δz); rtol=1e-3, atol=1e-3)
rrule_test(Bijectors.find_alpha, Δu, (x, x̄), (ỹ, ȳ), (z, z̄); rtol=1e-3, atol=1e-3)
end
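Presumably `ỹ = expm1(y)` is used because `expm1` maps into `(-1, ∞)`, so the `wt_u_hat` argument respects the planar layer's invertibility constraint and keeps the denominator `1 + wt_u_hat * sech(Ω + b)^2` of the rule strictly positive. A small illustrative check of that bound (not part of the test suite):

wt_u_hat = expm1(randn())  # always > -1
all(1 + wt_u_hat * sech(t)^2 > 0 for t in range(-5, 5; length=101))  # true, since 0 < sech(t)^2 ≤ 1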
25 changes: 25 additions & 0 deletions test/ad/flows.jl
@@ -0,0 +1,25 @@
@testset "PlanarLayer" begin
# logpdf of a flow with a planar layer and two-dimensional inputs
test_ad(randn(7)) do θ
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
flow = transformed(MvNormal(2, 1), layer)
return logpdf_forward(flow, θ[6:7])
end
test_ad(randn(11)) do θ
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
flow = transformed(MvNormal(2, 1), layer)
return sum(logpdf_forward(flow, reshape(θ[6:end], 2, :)))
end

# logpdf of a flow with the inverse of a planar layer and two-dimensional inputs
test_ad(randn(7)) do θ
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
flow = transformed(MvNormal(2, 1), inv(layer))
return logpdf_forward(flow, θ[6:7])
end
test_ad(randn(11)) do θ
layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
flow = transformed(MvNormal(2, 1), inv(layer))
return sum(logpdf_forward(flow, reshape(θ[6:end], 2, :)))
end
end
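Outside the test harness, the pattern these tests exercise looks roughly like the following (a sketch built only from calls that appear above):

using Bijectors, Distributions

layer = PlanarLayer(randn(2), randn(2), randn(1))
flow  = transformed(MvNormal(2, 1), layer)        # flow applying the planar layer
iflow = transformed(MvNormal(2, 1), inv(layer))   # flow applying its inverse

x = randn(2)
logpdf_forward(flow, x)    # closed-form forward evaluation
logpdf_forward(iflow, x)   # goes through find_alpha and is now differentiable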
32 changes: 10 additions & 22 deletions test/bijectors/utils.jl
@@ -131,14 +131,10 @@ function test_bijector(
test_bijector_reals(b, x_true, y_true, logjac_true; kwargs...)

# Test AD
if isclosedform(b)
test_ad(x -> b(first(x)), [x_true, ])
end
test_ad(x -> b(first(x)), [x_true, ])

if isclosedform(ib)
y = b(x_true)
test_ad(x -> ib(first(x)), [y, ])
end
y = b(x_true)
test_ad(x -> ib(first(x)), [y, ])

test_ad(x -> logabsdetjac(b, first(x)), [x_true, ])
end
@@ -167,28 +163,20 @@ function test_bijector(
test_bijector_arrays(b, collect(x_true), collect(y_true), logjac_true; kwargs...)

# Test AD
if isclosedform(b)
test_ad(x -> sum(b(x)), collect(x_true))
end
if isclosedform(ib)
y = b(x_true)
test_ad(x -> sum(ib(x)), y)
end
test_ad(x -> sum(b(x)), collect(x_true))
y = b(x_true)
test_ad(x -> sum(ib(x)), y)

test_ad(x -> logabsdetjac(b, x), x_true)
end
end

function test_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix; tol=1e-6)
if isclosedform(b)
logjac_ad = [logabsdet(ForwardDiff.jacobian(b, x))[1] for x in eachcol(xs)]
@test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol
end
logjac_ad = [logabsdet(ForwardDiff.jacobian(b, x))[1] for x in eachcol(xs)]
@test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol
end

function test_logabsdetjac(b::Bijector{0}, xs::AbstractVector; tol=1e-6)
if isclosedform(b)
logjac_ad = [log(abs(ForwardDiff.derivative(b, x))) for x in xs]
@test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol
end
logjac_ad = [log(abs(ForwardDiff.derivative(b, x))) for x in xs]
@test mean(logabsdetjac(b, xs) - logjac_ad) ≤ tol
end
4 changes: 2 additions & 2 deletions test/norm_flows.jl
@@ -27,8 +27,8 @@ end
our_method = sum(forward(flow, z).logabsdetjac)

@test our_method ≈ forward_diff
@test inv(flow)(flow(z)) ≈ z rtol=0.25
@test (inv(flow) ∘ flow)(z) ≈ z rtol=0.25
@test inv(flow)(flow(z)) ≈ z
@test (inv(flow) ∘ flow)(z) ≈ z
end

w = ones(10)
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -1,5 +1,6 @@
using Bijectors

using ChainRulesTestUtils
using Combinatorics
using DistributionsAD
using FiniteDifferences
@@ -35,6 +36,8 @@ if GROUP == "All" || GROUP == "Interface"
end

if GROUP == "All" || GROUP == "AD"
include("ad/chainrules.jl")
include("ad/flows.jl")
include("ad/distributions.jl")
end

2 comments on commit 3e0d765

@devmotion
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/27784

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.13 -m "<description of version>" 3e0d76562d0aa8b1e2137d622c0ae95b3325310e
git push origin v0.8.13
