Proposal: Interest in faster logpdf for BernoulliLogit (logistic regression) #1890

Closed
svilupp opened this issue Oct 5, 2022 · 9 comments
@svilupp

svilupp commented Oct 5, 2022

Hi everyone,

One of the learnings from the benchmarking against Numpyro was that they use an interesting re-parametrization of logpdf for Bernoulli from Tensorflow (reference)

A quick test suite shows that it's equivalent with very high precision.
Benchmarks suggest it's c. 4x faster (see MWE), which would translate into faster logistic regression sampling (see benchmark results in the linked post above).

Would there be interest if I opened a PR to update BernoulliLogit to have this implementation?

Note: the final implementation might be a bit slower than 4x because of the dispatch etc (not sure all would be inlined).
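
For reference, the two forms agree exactly on paper, not only numerically. A short derivation, writing the logit as x, the outcome as y in {0, 1}, and σ for the logistic function (this notation is assumed here, it is not taken from the thread):

```latex
\begin{aligned}
\log p(y \mid x) &= y \log \sigma(x) + (1 - y) \log\bigl(1 - \sigma(x)\bigr) \\
&= y \bigl(x - \log(1 + e^{x})\bigr) - (1 - y) \log(1 + e^{x}) \\
&= x y - \log(1 + e^{x}) \\
&= -\bigl(\max(0, x) + \log(1 + e^{-|x|}) - x y\bigr)
\end{aligned}
```

The last step uses the identity log(1 + e^x) = max(0, x) + log(1 + e^{-|x|}), which keeps the exponent non-positive so it cannot overflow.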


MWE:

## Define models
using Turing
using BenchmarkTools  # provides @btime
using LogExpFunctions: logit, logistic, log1pexp
using Test

# custom implementation
# from https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
jax_logpdf(x,y)=-(max(0,x) +log1pexp(-abs(x)) -x*y)

# Benchmark logprob call itself
logits=randn(10000)
y=convert.(Int,(rand(10000).>0.5))
@btime sum(logpdf.(BernoulliLogit.($logits),$y));
# 467.792 μs (2 allocations: 78.17 KiB)
@btime sum(jax_logpdf.($logits,$y));
# 12.833 μs (1 allocation: 7.94 KiB)

# single run
@btime logpdf(BernoulliLogit($logits[1]),$y[1]);
# 58.749 ns (0 allocations: 0 bytes)
@btime jax_logpdf($logits[1],$y[1]);
# 13.402 ns (0 allocations: 0 bytes)

@testset "custom logpdf for BernoulliLogit" begin
    for (logit_val,y) in Iterators.product(range(-3,3,length=100), [0.,1.])
        @test jax_logpdf(logit_val,y) ≈ logpdf(BernoulliLogit(logit_val),y) atol=1e-14
    end
end
# Test Summary:                    | Pass  Total  Time
# custom logpdf for BernoulliLogit |  200    200  0.0s
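
The test above only covers logits in [-3, 3]. A minimal sketch of how the same formula behaves at extreme logits, compared with a textbook evaluation through the probability, might look like the following. The names `stable_logpdf` and `naive_logpdf` are hypothetical, and `log1pexp(-abs(x))` is written out as `log1p(exp(-abs(x)))` so that only Base is needed:

```julia
# Hypothetical names; uses only Base so it runs without extra packages.

# The formula from the post, with log1pexp(-abs(x)) spelled out.
stable_logpdf(x, y) = -(max(0, x) + log1p(exp(-abs(x))) - x * y)

# Textbook form: go through the probability p = 1 / (1 + exp(-x)) first.
function naive_logpdf(x, y)
    p = 1 / (1 + exp(-x))
    return y == 1 ? log(p) : log(1 - p)
end

# Print both forms side by side for moderate and extreme logits.
for x in (-800.0, -3.0, 0.0, 3.0, 800.0), y in (0, 1)
    println((x = x, y = y, naive = naive_logpdf(x, y), stable = stable_logpdf(x, y)))
end
```

At moderate logits the two columns agree. At x = 800 with y = 0, or x = -800 with y = 1, the probability rounds to exactly 1 or 0 and the textbook form returns -Inf, while the logit-based form stays finite.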


@yebai
Member

yebai commented Oct 5, 2022

> Would there be interest if I opened a PR to update BernoulliLogit to have this implementation?

@svilupp Many thanks for the benchmarks and for sharing your experiences on PPLs. Of course, we would appreciate a PR.

@devmotion
Member

IMO such distributions should be defined in Distributions, not in Turing.

@yebai
Member

yebai commented Oct 5, 2022

Seconded - I suggest that we update the implementation in Turing and submit a PR to Distributions. The PR merge cycle is much faster on the Turing side than in Distributions.

@svilupp
Author

svilupp commented Oct 5, 2022

Noted. I’ll draft the PR today or tomorrow and also open the discussion on Distributions.jl

@devmotion
Member

I already have something locally (did that when seeing your benchmarks on discourse 😄), I'll push it to a branch and open a WIP PR to Distributions this afternoon.

@devmotion
Member

I opened JuliaStats/Distributions.jl#1623.

@devmotion
Member

I reran the benchmarks above with JuliaStats/Distributions.jl#1623:

julia> using Distributions

julia> using BenchmarkTools

julia> using LogExpFunctions: logit, logistic,log1pexp

julia> using Test

julia> jax_logpdf(x, y) = -(max(0,x) + log1pexp(-abs(x)) - x*y)
jax_logpdf (generic function with 1 method)

julia> logits=randn(10000);

julia> y = convert.(Int, (rand(10000) .> 0.5));

julia> @btime sum(logpdf.(BernoulliLogit.($logits), $y));
  151.728 μs (2 allocations: 78.17 KiB)

julia> @btime sum(jax_logpdf.($logits, $y));
  158.650 μs (2 allocations: 78.17 KiB)

julia> @btime logpdf(BernoulliLogit($logits[1]), $y[1]);
  15.309 ns (0 allocations: 0 bytes)

julia> @btime jax_logpdf($logits[1], $y[1]);
  16.569 ns (0 allocations: 0 bytes)

julia> @testset "custom logpdf for BernoulliLogit" begin
           for (logit_val,y) in Iterators.product(range(-3,3,length=100), [0., 1.])
               @test jax_logpdf(logit_val, y) ≈ logpdf(BernoulliLogit(logit_val), y) atol=1e-14
           end
       end
Test Summary:                    | Pass  Total  Time
custom logpdf for BernoulliLogit |  200    200  0.0s

The version in the Distributions PR is a bit faster since it avoids max, abs, and the multiplication.
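
The PR's code is not reproduced in this thread, so the following is only a sketch of what a form without `max`, `abs`, and the multiplication could look like, assuming it branches on the outcome. The names `softplus` and `branch_logpdf` are hypothetical, and `softplus` is a Base-only stand-in for `log1pexp` that still uses one comparison internally:

```julia
# Hypothetical helper names; only Base is used so the sketch runs without packages.

# softplus(x) = log(1 + exp(x)), arranged so that exp never overflows.
softplus(x) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

# Branch on the outcome instead of multiplying the logit by it.
branch_logpdf(x, y::Bool) = y ? -softplus(-x) : -softplus(x)

# The form benchmarked in this thread, with log1pexp spelled out via Base.
jax_logpdf(x, y) = -(max(0, x) + log1p(exp(-abs(x))) - x * y)

# Both forms should agree to within rounding on moderate logits.
for x in range(-3, 3, length=7), y in (false, true)
    @assert isapprox(branch_logpdf(x, y), jax_logpdf(x, y); atol=1e-14)
end
```

Whether this matches the PR line for line is an assumption; it only illustrates that selecting one of two log-probabilities by the outcome gives the same values as the single-expression form.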

@svilupp
Author

svilupp commented Oct 11, 2022

That is awesome! Thank you.

I have a somewhat naive question: if we have runtime conditional statements, can we still use it with ReverseDiff with tape compilation?

e.g., logpdf(d::BernoulliLogit, x::Bool) = x ? logsuccprob(d) : logfailprob(d) # taken from the Distributions.jl PR

@yebai
Member

yebai commented Nov 12, 2022

Fixed by #1892

@yebai yebai closed this as completed Nov 12, 2022