Skip to content

Commit 1fd9f9e

Browse files
committed
Use Tangent in _with_ladj_on_mapped pullback
1 parent 534a4f2 commit 1fd9f9e

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

src/with_ladj.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,18 @@ function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(
8686
(y, ladj)
8787
end
8888

89-
function _with_ladj_on_mapped_pullback(thunked_ΔΩ)
89+
90+
# Need to use a type for this, type inference fails when using a pullback
91+
# closure over YLT in the rrule, resulting in bad performance:
92+
struct WithLadjOnMappedPullback{YLT} <: Function end
93+
function (::WithLadjOnMappedPullback{YLT})(thunked_ΔΩ) where YLT
9094
ys, ladj = unthunk(thunked_ΔΩ)
91-
return NoTangent(), NoTangent(), tuple.(ys, ladj)
95+
return NoTangent(), NoTangent(), broadcast((y, l) -> Tangent{YLT}(y, l), ys, ladj)
9296
end
9397

9498
function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
95-
return _with_ladj_on_mapped(map_or_bc, y_with_ladj), _with_ladj_on_mapped_pullback
99+
YLT = eltype(y_with_ladj)
100+
return _with_ladj_on_mapped(map_or_bc, y_with_ladj), WithLadjOnMappedPullback{YLT}()
96101
end
97102

98103
function with_logabsdet_jacobian(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}, X)

test/test_with_ladj.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using Test
55

66
using LinearAlgebra
77

8-
using ChangesOfVariables: test_with_logabsdet_jacobian, _with_ladj_on_mapped
8+
using ChangesOfVariables
9+
using ChangesOfVariables: test_with_logabsdet_jacobian
910
using ChainRulesCore
1011

1112
include("getjacobian.jl")
@@ -63,10 +64,7 @@ include("getjacobian.jl")
6364

6465
@testset "rrules" begin
6566
for map_or_bc in (map, broadcast)
66-
x = [(1, 2), (3, 4), (5, 6)]
67-
y, back = rrule(_with_ladj_on_mapped, map_or_bc, x)
68-
@test y == ([1, 3, 5], 12) == _with_ladj_on_mapped(map_or_bc, x)
69-
@test back(@thunk ([7, 8, 9], 12)) == (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), [(7, 12), (8, 12), (9, 12)])
67+
test_rrule(ChangesOfVariables._with_ladj_on_mapped, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)])
7068
end
7169
end
7270
end

0 commit comments

Comments
 (0)