Proposal: Interest in faster logpdf for BernoulliLogit (logistic regression) #1890
@svilupp Many thanks for the benchmarks and for sharing your experiences with PPLs. Of course, we would appreciate a PR.
IMO such distributions should not be defined in Turing but in Distributions.
Seconded. I suggest that we update the implementation in Turing and also submit a PR to Distributions, since PRs get merged much faster on the Turing side than in Distributions.
Noted. I’ll draft the PR today or tomorrow and also open the discussion on Distributions.jl.
I already have something locally (I wrote it after seeing your benchmarks on Discourse 😄). I'll push it to a branch and open a WIP PR to Distributions this afternoon.
I opened JuliaStats/Distributions.jl#1623.
I reran the benchmarks above with JuliaStats/Distributions.jl#1623:

```julia
julia> using Distributions

julia> using BenchmarkTools

julia> using LogExpFunctions: logit, logistic, log1pexp

julia> using Test

julia> jax_logpdf(x, y) = -(max(0, x) + log1pexp(-abs(x)) - x * y)
jax_logpdf (generic function with 1 method)

julia> logits = randn(10000);

julia> y = convert.(Int, (rand(10000) .> 0.5));

julia> @btime sum(logpdf.(BernoulliLogit.($logits), $y));
  151.728 μs (2 allocations: 78.17 KiB)

julia> @btime sum(jax_logpdf.($logits, $y));
  158.650 μs (2 allocations: 78.17 KiB)

julia> @btime logpdf(BernoulliLogit($logits[1]), $y[1]);
  15.309 ns (0 allocations: 0 bytes)

julia> @btime jax_logpdf($logits[1], $y[1]);
  16.569 ns (0 allocations: 0 bytes)

julia> @testset "custom logpdf for BernoulliLogit" begin
           for (logit_val, y) in Iterators.product(range(-3, 3, length=100), [0., 1.])
               @test jax_logpdf(logit_val, y) ≈ logpdf(BernoulliLogit(logit_val), y) atol=1e-14
           end
       end
Test Summary:                    | Pass  Total  Time
custom logpdf for BernoulliLogit |  200    200  0.0s
```

The version in the Distributions PR is a bit faster since it avoids …
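For reference, here is a short derivation of why `jax_logpdf` is the Bernoulli log-probability. It takes x to be the logit, σ the logistic function and y ∈ {0, 1}, and relies only on the identities log(1 + e^{-x}) = log(1 + e^{x}) - x and log(1 + e^{x}) = max(0, x) + log(1 + e^{-|x|}):

```math
\begin{aligned}
\log p(y \mid x) &= y \log \sigma(x) + (1 - y) \log\bigl(1 - \sigma(x)\bigr), \qquad \sigma(x) = \frac{1}{1 + e^{-x}} \\
&= -y \log\bigl(1 + e^{-x}\bigr) - (1 - y) \log\bigl(1 + e^{x}\bigr) \\
&= x y - \log\bigl(1 + e^{x}\bigr) \\
&= -\Bigl(\max(0, x) + \log\bigl(1 + e^{-|x|}\bigr) - x y\Bigr).
\end{aligned}
```

The last line is exactly `jax_logpdf(x, y)` with `log1pexp(z)` = log(1 + e^z). Because e^{-|x|} ≤ 1, none of the terms can overflow. The second line also shows the two cases separately (y = 1 gives -log(1 + e^{-x}), y = 0 gives -log(1 + e^{x})), which is one possible branch-based way to evaluate the same quantity.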
That is awesome, thank you! A somewhat naive question: if the implementation contains runtime conditional statements, can we still use it with ReverseDiff with tape compilation?
Fixed by #1892.
Hi everyone,
One of the lessons from benchmarking against NumPyro was that it uses an interesting reparametrization of the Bernoulli logpdf, taken from TensorFlow (reference).
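To illustrate why this reparametrization is attractive beyond speed, here is a minimal sketch using only Julia Base. It contrasts the reparametrized form with a naive formula that first computes σ(x) and then takes logs. `naive_logpdf`, `naive_sigmoid` and `stable_logpdf` are hypothetical names for illustration, and `log1p(exp(-abs(x)))` stands in for `LogExpFunctions.log1pexp(-abs(x))`. For a logit of ±800, the naive version returns -Inf, while the reparametrized one returns the exact value -800.0.

```julia
# Hypothetical naive baseline: compute p = σ(x) first, then take logs.
# For large |x|, p rounds to exactly 0.0 or 1.0, and the log becomes -Inf.
naive_sigmoid(x) = 1 / (1 + exp(-x))

function naive_logpdf(x, y)
    p = naive_sigmoid(x)
    return y * log(p) + (1 - y) * log1p(-p)
end

# Reparametrized form: log p(y | x) = -(max(0, x) + log(1 + e^{-|x|}) - x*y).
# exp(-abs(x)) lies in (0, 1], so this term can never overflow.
stable_logpdf(x, y) = -(max(zero(x), x) + log1p(exp(-abs(x))) - x * y)

for (x, y) in ((-800.0, 1), (800.0, 0), (0.5, 1), (0.5, 0))
    println("x = $x, y = $y: naive = $(naive_logpdf(x, y)), stable = $(stable_logpdf(x, y))")
end
```

For moderate logits such as 0.5, both versions agree to within rounding error. The difference only shows up when σ(x) saturates in floating point.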
A quick test suite shows that it agrees with the current implementation to very high precision.
A benchmark suggests it is roughly 4x faster (see MWE), which would translate into faster sampling for logistic regression (see the benchmark results in the post linked above).
Would there be interest in a PR that updates BernoulliLogit to use this implementation?
Note: the speedup of the final implementation might be smaller than 4x because of dispatch overhead etc. (I'm not sure everything would be inlined).
MWE: