Skip to content

Add custom rrule to speed up AD with mapped functions #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 10, 2021
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
name = "ChangesOfVariables"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.1"
version = "0.1.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRulesCore = "1"
julia = "1"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[targets]
test = ["Documenter", "ForwardDiff"]
test = ["ChainRulesTestUtils", "Documenter", "ForwardDiff"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ changes for functions that perform a change of variables (like coordinate
transformations).

`ChangesOfVariables` is a very lightweight package and has no dependencies
beyond `Base`, `LinearAlgebra` and `Test`.
beyond `Base`, `LinearAlgebra`, `Test` and `ChainRulesCore`.

## Documentation

Expand Down
1 change: 1 addition & 0 deletions src/ChangesOfVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ transformations).
"""
module ChangesOfVariables

using ChainRulesCore
using LinearAlgebra
using Test

Expand Down
27 changes: 20 additions & 7 deletions src/with_ladj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,30 @@ export with_logabsdet_jacobian
end


@inline _get_y(y_with_ladj::NTuple{2,Any,}) = y_with_ladj[1]
@inline _get_ladj(y_with_ladj::NTuple{2,Any}) = y_with_ladj[2]

_with_ladj_on_mapped(map_or_bc::Function, y_with_ladj::Tuple{Any,Real}) = y_with_ladj
function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj::Tuple{Any,Real}) where {F<:Union{typeof(map),typeof(broadcast)}}
return y_with_ladj
end

function _with_ladj_on_mapped(map_or_bc::Function, y_with_ladj)
y = map_or_bc(_get_y, y_with_ladj)
ladj = sum(map_or_bc(_get_ladj, y_with_ladj))
function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
y = map_or_bc(first, y_with_ladj)
ladj = sum(last, y_with_ladj)
(y, ladj)
end


# Need to use a type for this, type inference fails when using a pullback
# closure over YLT in the rrule, resulting in bad performance:
struct WithLadjOnMappedPullback{YLT} <: Function end
function (::WithLadjOnMappedPullback{YLT})(thunked_ΔΩ) where YLT
ys, ladj = unthunk(thunked_ΔΩ)
return NoTangent(), NoTangent(), map(y -> Tangent{YLT}(y, ladj), ys)
end

function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
YLT = eltype(y_with_ladj)
return _with_ladj_on_mapped(map_or_bc, y_with_ladj), WithLadjOnMappedPullback{YLT}()
end

function with_logabsdet_jacobian(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}, X)
map_or_bc = mapped_f.f
f = mapped_f.x
Expand Down
8 changes: 8 additions & 0 deletions test/test_with_ladj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ using Test

using LinearAlgebra

using ChangesOfVariables
using ChangesOfVariables: test_with_logabsdet_jacobian
using ChainRulesTestUtils

include("getjacobian.jl")

Expand Down Expand Up @@ -59,4 +61,10 @@ include("getjacobian.jl")
test_with_logabsdet_jacobian(f, x, getjacobian)
end
end

@testset "rrules" begin
for map_or_bc in (map, broadcast)
test_rrule(ChangesOfVariables._with_ladj_on_mapped, map_or_bc, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)])
end
end
end