Remove LogDensityProblemsAD Extension 2 #811

Merged: 1 commit, Feb 15, 2025

penelopeysm (Member) commented Feb 15, 2025

In a comment on #806, I wrote:

(1) ...

(2) Would it be possible to combine LogDensityFunction + LogDensityFunctionWithGrad into a single struct whose adtype field is a Union{Nothing,AbstractADType}? adtype=nothing would serve the same purpose as the current LogDensityFunction, and an adtype <: AbstractADType would replace LogDensityFunctionWithGrad. One could then have a setadtype function which would set the adtype and also perform the gradient preparation step. This would bring us one step closer to having adtypes be part of the model rather than the sampler, cf. TuringLang/AbstractMCMC.jl#158. The idea is that if the sampler itself carried an adtype, it would call setadtype to override the model's adtype with its own. We could then gradually remove this possibility by raising a warning, then a depwarn, then an error.

This PR implements this. Benchmarks to follow.
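
To make the shape of the proposal concrete, here is a rough, language-neutral sketch in Python. It is not DynamicPPL's Julia implementation: the string tag "finite-diff" stands in for an AbstractADType, the `prep` field and `_prepare_gradient` helper are hypothetical, and central finite differences stand in for a real AD backend's preparation step.

```python
from dataclasses import dataclass, replace
from typing import Callable, List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class LogDensityFunction:
    """One struct for both cases: adtype=None means no gradient support."""
    logp: Callable[[Sequence[float]], float]
    adtype: Optional[str] = None  # hypothetical stand-in for AbstractADType
    prep: Optional[Callable[[Sequence[float]], List[float]]] = None


def _prepare_gradient(logp, adtype):
    # Hypothetical preparation step. A real backend would build tapes or
    # caches here; this sketch just returns a finite-difference gradient.
    if adtype != "finite-diff":
        raise ValueError(f"unsupported adtype: {adtype!r}")

    def grad(x, h=1e-6):
        x = list(x)
        out = []
        for i in range(len(x)):
            up = x[:i] + [x[i] + h] + x[i + 1:]
            down = x[:i] + [x[i] - h] + x[i + 1:]
            out.append((logp(up) - logp(down)) / (2 * h))
        return out

    return grad


def setadtype(f: LogDensityFunction, adtype: Optional[str]) -> LogDensityFunction:
    """Return a copy with the new adtype, redoing gradient preparation."""
    prep = None if adtype is None else _prepare_gradient(f.logp, adtype)
    return replace(f, adtype=adtype, prep=prep)


def logdensity_and_gradient(f: LogDensityFunction, x) -> Tuple[float, List[float]]:
    if f.prep is None:
        raise ValueError("no adtype set; call setadtype first")
    return f.logp(x), f.prep(x)


# Usage: a standard-normal log density (up to a constant).
plain = LogDensityFunction(lambda x: -0.5 * sum(v * v for v in x))
with_grad = setadtype(plain, "finite-diff")
value, gradient = logdensity_and_gradient(with_grad, [1.0, -2.0])
```

In this sketch a sampler that carries its own adtype would simply call `setadtype(model_ldf, sampler_adtype)` before sampling, which is the override path the quoted comment describes.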

penelopeysm (Member, Author) commented Feb 15, 2025

Benchmarks look good: they are no worse than on #806, and if anything Mooncake is now a bit faster (maybe having fewer structs to differentiate through helps? I don't know):

model                                       adtype                          new4   old    new4/old     mean(new4/old)  stdev(new4/old)
demo_dot_assume_dot_observe                 AutoForwardDiff()               1.71   1.68   1.017857143  1.082413028     0.06451764
demo_assume_index_observe                   AutoForwardDiff()               1.24   1.16   1.068965517                  
demo_assume_multivariate_observe            AutoForwardDiff()               1.13   1.01   1.118811881                  
demo_dot_assume_observe_index               AutoForwardDiff()               1.88   1.57   1.197452229                  
demo_assume_dot_observe                     AutoForwardDiff()               0.78   0.75   1.04                         
demo_assume_multivariate_observe_literal    AutoForwardDiff()               1.14   1.03   1.106796117                  
demo_dot_assume_observe_index_literal       AutoForwardDiff()               1.82   1.54   1.181818182                  
demo_assume_dot_observe_literal             AutoForwardDiff()               0.86   0.79   1.088607595                  
demo_assume_observe_literal                 AutoForwardDiff()               0.81   0.75   1.08                         
demo_assume_submodel_observe_index_literal  AutoForwardDiff()               1.86   1.96   0.948979592                  
demo_dot_assume_observe_submodel            AutoForwardDiff()               2.39   2.34   1.021367521                  
demo_dot_assume_dot_observe_matrix          AutoForwardDiff()               1.83   1.69   1.082840237                  
demo_dot_assume_matrix_dot_observe_matrix   AutoForwardDiff()               1.69   1.51   1.119205298                  
demo_assume_matrix_dot_observe_matrix       AutoForwardDiff()               1.2    1.11   1.081081081                  
demo_dot_assume_dot_observe                 AutoMooncake{Nothing}(nothing)  9.71   11.25  0.863111111  0.933785169     0.055335369
demo_assume_index_observe                   AutoMooncake{Nothing}(nothing)  6.77   7.81   0.866837388                  
demo_assume_multivariate_observe            AutoMooncake{Nothing}(nothing)  6.56   8.01   0.81897628                   
demo_dot_assume_observe_index               AutoMooncake{Nothing}(nothing)  8.92   9.64   0.925311203                  
demo_assume_dot_observe                     AutoMooncake{Nothing}(nothing)  5.58   6.06   0.920792079                  
demo_assume_multivariate_observe_literal    AutoMooncake{Nothing}(nothing)  6.52   6.64   0.981927711                  
demo_dot_assume_observe_index_literal       AutoMooncake{Nothing}(nothing)  9.07   9.6    0.944791667                  
demo_assume_dot_observe_literal             AutoMooncake{Nothing}(nothing)  5.58   5.84   0.955479452                  
demo_assume_observe_literal                 AutoMooncake{Nothing}(nothing)  5.61   5.6    1.001785714                  
demo_assume_submodel_observe_index_literal  AutoMooncake{Nothing}(nothing)  9.39   9.82   0.956211813                  
demo_dot_assume_observe_submodel            AutoMooncake{Nothing}(nothing)  12.58  13.96  0.901146132                  
demo_dot_assume_dot_observe_matrix          AutoMooncake{Nothing}(nothing)  10.96  11.13  0.984725966                  
demo_dot_assume_matrix_dot_observe_matrix   AutoMooncake{Nothing}(nothing)  10.65  10.6   1.004716981                  
demo_assume_matrix_dot_observe_matrix       AutoMooncake{Nothing}(nothing)  7.89   8.33   0.947178872                  
demo_dot_assume_dot_observe                 AutoReverseDiff()               19.58  19.38  1.010319917  0.982803156     0.02957993
demo_assume_index_observe                   AutoReverseDiff()               21.67  22.88  0.947115385                  
demo_assume_multivariate_observe            AutoReverseDiff()               19.25  19.71  0.976661593                  
demo_dot_assume_observe_index               AutoReverseDiff()               21.46  21.25  1.009882353                  
demo_assume_dot_observe                     AutoReverseDiff()               15.38  16.08  0.956467662                  
demo_assume_multivariate_observe_literal    AutoReverseDiff()               19.71  20.08  0.981573705                  
demo_dot_assume_observe_index_literal       AutoReverseDiff()               21.5   21.62  0.994449584                  
demo_assume_dot_observe_literal             AutoReverseDiff()               15.96  16.29  0.979742173                  
demo_assume_observe_literal                 AutoReverseDiff()               16.17  16.21  0.997532387                  
demo_assume_submodel_observe_index_literal  AutoReverseDiff()               21.88  21.08  1.037950664                  
demo_dot_assume_observe_submodel            AutoReverseDiff()               21.42  22     0.973636364                  
demo_dot_assume_dot_observe_matrix          AutoReverseDiff()               20.92  20.71  1.010140029                  
demo_dot_assume_matrix_dot_observe_matrix   AutoReverseDiff()               22.62  23.71  0.954027836                  
demo_assume_matrix_dot_observe_matrix       AutoReverseDiff()               20.38  21.92  0.929744526                  
demo_dot_assume_dot_observe                 AutoReverseDiff(compile=true)   5.71   6.18   0.92394822   0.952204654     0.048199875
demo_assume_index_observe                   AutoReverseDiff(compile=true)   6.56   7.51   0.873501997                  
demo_assume_multivariate_observe            AutoReverseDiff(compile=true)   5.58   6.15   0.907317073                  
demo_dot_assume_observe_index               AutoReverseDiff(compile=true)   7.07   7.68   0.920572917                  
demo_assume_dot_observe                     AutoReverseDiff(compile=true)   4.98   5.58   0.892473118                  
demo_assume_multivariate_observe_literal    AutoReverseDiff(compile=true)   5.92   5.91   1.001692047                  
demo_dot_assume_observe_index_literal       AutoReverseDiff(compile=true)   7.21   7.29   0.989026063                  
demo_assume_dot_observe_literal             AutoReverseDiff(compile=true)   5.49   5.47   1.003656307                  
demo_assume_observe_literal                 AutoReverseDiff(compile=true)   5.63   5.65   0.996460177                  
demo_assume_submodel_observe_index_literal  AutoReverseDiff(compile=true)   7.49   7.28   1.028846154                  
demo_dot_assume_observe_submodel            AutoReverseDiff(compile=true)   6.19   6.69   0.925261584                  
demo_dot_assume_dot_observe_matrix          AutoReverseDiff(compile=true)   6.52   6.94   0.939481268                  
demo_dot_assume_matrix_dot_observe_matrix   AutoReverseDiff(compile=true)   7.67   7.76   0.988402062                  
demo_assume_matrix_dot_observe_matrix       AutoReverseDiff(compile=true)   5.82   6.19   0.940226171                  
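
The summary columns in the table can be reproduced from the new4 and old columns. The Python sketch below does this for the AutoForwardDiff() rows only; the timings are copied from the table, and the use of the sample standard deviation (n - 1 denominator) is an assumption that happens to agree with the reported figure.

```python
from statistics import mean, stdev

# (new4, old) pairs copied from the AutoForwardDiff() rows of the table.
forwarddiff = [
    (1.71, 1.68), (1.24, 1.16), (1.13, 1.01), (1.88, 1.57),
    (0.78, 0.75), (1.14, 1.03), (1.82, 1.54), (0.86, 0.79),
    (0.81, 0.75), (1.86, 1.96), (2.39, 2.34), (1.83, 1.69),
    (1.69, 1.51), (1.20, 1.11),
]

# Ratio above 1 means new4 took longer than old for that model.
ratios = [new / old for new, old in forwarddiff]

ratio_mean = mean(ratios)
ratio_stdev = stdev(ratios)  # sample standard deviation (n - 1)

print(f"mean(new4/old)  = {ratio_mean:.4f}")
print(f"stdev(new4/old) = {ratio_stdev:.4f}")
```

The same calculation applies to the other three adtype groups by swapping in their rows.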

penelopeysm force-pushed the py/no-ldp-ad-2 branch 2 times, most recently from 8c98a73 to 892b097 (February 15, 2025 15:17)
TuringLang deleted a comment from github-actions bot Feb 15, 2025

codecov bot commented Feb 15, 2025

Codecov Report

Attention: Patch coverage is 85.71429% with 5 lines in your changes missing coverage. Please review.

Project coverage is 85.83%. Comparing base (0e24d97) to head (e4cd311).
Report is 1 commit behind head on py/no-ldp-ad.

Files with missing lines    Patch %   Lines
src/logdensityfunction.jl   85.71%    5 missing

Additional details and impacted files:
@@               Coverage Diff                @@
##           py/no-ldp-ad     #811      +/-   ##
================================================
+ Coverage         85.79%   85.83%   +0.03%     
================================================
  Files                35       35              
  Lines              4197     4200       +3     
================================================
+ Hits               3601     3605       +4     
+ Misses              596      595       -1     


penelopeysm (Member, Author) commented

Since tests are passing and the performance is equivalent to or better than before, I'm going to merge this into the other PR.

penelopeysm marked this pull request as ready for review February 15, 2025 16:57
penelopeysm merged commit ba6ba7c into py/no-ldp-ad Feb 15, 2025
3 of 16 checks passed
penelopeysm deleted the py/no-ldp-ad-2 branch February 15, 2025 16:57
penelopeysm added a commit that referenced this pull request Feb 19, 2025
* Remove LogDensityProblemsAD

* Implement LogDensityFunctionWithGrad in place of ADgradient

* Dynamically decide whether to use closure vs constant

* Combine LogDensityFunction{,WithGrad} into one (#811)

* Warn if unsupported AD type is used

* Update changelog

* Update DI compat bound

Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>

* Don't store with_closure inside LogDensityFunction

Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>

* setadtype --> LogDensityFunction

* Re-add ForwardDiffExt (including tests)

* Add more tests for coverage

---------

Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>