Skip to content

Commit

Permalink
Use ITP as root-finding algorithm in planar layer (#343)
Browse files Browse the repository at this point in the history
* Use ITP as root-finding algorithm in planar layer

* Use suggested hyperparameter

* Fix format

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update Mooncake test

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
devmotion and github-actions[bot] authored Nov 6, 2024
1 parent e0f04fc commit 5c1feeb
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.14.0"
version = "0.14.1"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down Expand Up @@ -65,7 +65,7 @@ MappedArrays = "0.2.2, 0.3, 0.4"
Reexport = "0.2, 1"
Requires = "0.5, 1"
ReverseDiff = "1"
Roots = "1.3.4, 2"
Roots = "1.3.15, 2"
Statistics = "1"
Mooncake = "0.4.19"
Tracker = "0.2"
Expand Down
4 changes: 3 additions & 1 deletion src/bijectors/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real}
end

# Solve the root-finding problem
α0 = Roots.find_zero((lower, upper)) do α
# A value of `κ₁ = 0.2 / (upper - lower)` is suggested
# Ref: https://docs.rs/kurbo/0.11.1/kurbo/common/fn.solve_itp.html
α0 = Roots.find_zero((lower, upper), Roots.ITP(; κ₁=inv(10 * Δ))) do α
return α + wt_u_hat * tanh+ b) - wt_y
end

Expand Down
8 changes: 7 additions & 1 deletion test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,13 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
catch exc
# TODO(penelopeysm):
# @test_throws AssertionError (expr...) doesn't work, unclear why
@test exc isa AssertionError
# We use `isdefined` here since `hasproperty` for modules is not consistent with `getproperty`
# Ref https://github.com/JuliaLang/julia/issues/47150
if isdefined(Mooncake, :MooncakeRuleCompilationError)
@test exc isa getproperty(Mooncake, :MooncakeRuleCompilationError)
else
@test exc isa AssertionError
end
end
# TODO: The above @test_throws happens because of
# https://github.com/compintell/Mooncake.jl/issues/319. If that test
Expand Down

0 comments on commit 5c1feeb

Please sign in to comment.