Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some interface functions to support the new Gibbs sampler in Turing #144

Closed
wants to merge 56 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
dcf1da9
very incomplete draft
sunxd3 Jul 12, 2024
cdaa663
update `getparams`
sunxd3 Jul 12, 2024
57275f5
Upstream `condition` and `decondition` from `AbstractPPL`
sunxd3 Jul 18, 2024
26027ea
remove `condition` and `decondition`
sunxd3 Jul 22, 2024
6ebab49
add Compat to make new interface functions public
sunxd3 Jul 22, 2024
e1099f9
bump minor version
sunxd3 Jul 22, 2024
95d781b
bump minor version instead
sunxd3 Jul 22, 2024
f05f293
unfinished gibbs example
sunxd3 Aug 6, 2024
590d37f
some updates
sunxd3 Aug 14, 2024
3afc232
more progress; still need to deal with w being on simplex
sunxd3 Aug 15, 2024
55dbab5
bit of format
sunxd3 Aug 15, 2024
67ff8e8
results is wrong
sunxd3 Aug 15, 2024
f758a4c
Apply suggestions from code review
sunxd3 Aug 15, 2024
7d0ba7c
add hierarchical normal problem
sunxd3 Aug 22, 2024
1ab6dd9
some updates; add doc
sunxd3 Aug 23, 2024
923c116
move folder into test
sunxd3 Aug 23, 2024
63028d3
setup as a test
sunxd3 Aug 23, 2024
44de81c
add to doc
sunxd3 Aug 23, 2024
be43178
format
sunxd3 Aug 23, 2024
1a6e0d5
bump patch version
sunxd3 Aug 23, 2024
6b60b72
reverse version bump -- already done
sunxd3 Aug 23, 2024
c58b39a
remove dep on `Compat`
sunxd3 Aug 23, 2024
ac0ce7a
updates to doc
sunxd3 Aug 23, 2024
280eaf1
update gibbs to add to the src folder
sunxd3 Sep 8, 2024
b262ea9
update mh code
sunxd3 Sep 8, 2024
c47ade4
update code further
sunxd3 Sep 8, 2024
8d29ad3
fix test errors
sunxd3 Sep 9, 2024
c28a75a
format
sunxd3 Sep 9, 2024
1382054
fix doctest error
sunxd3 Sep 9, 2024
8962d40
tidy up
sunxd3 Sep 9, 2024
dc6001c
updates
sunxd3 Sep 17, 2024
e194108
Update test/gibbs_example/mh.jl
sunxd3 Sep 17, 2024
64eb0e4
fix error
sunxd3 Sep 17, 2024
9361c39
typo fix
sunxd3 Sep 17, 2024
39c4d87
Update src/gibbs.jl
sunxd3 Sep 18, 2024
7f889cf
rename gibbs test file to prepare for moving
sunxd3 Sep 20, 2024
62a2332
move gibbs.jl
sunxd3 Sep 20, 2024
6132f0c
update code
sunxd3 Sep 20, 2024
af208bc
updates
sunxd3 Sep 22, 2024
fd472df
rework the code; still not type stable
sunxd3 Sep 22, 2024
4306aee
fix test
sunxd3 Sep 22, 2024
b798b2e
update doc -- need proofread
sunxd3 Sep 22, 2024
3ed5cb3
fix 1.6 struct field splatting compat issue
sunxd3 Sep 22, 2024
6fde198
update code and doc
sunxd3 Sep 27, 2024
c7f577d
relax test error
sunxd3 Sep 28, 2024
8f11a15
rename gibbs markdown file
sunxd3 Sep 28, 2024
48a160d
change title
sunxd3 Sep 28, 2024
8d74889
update code and note
sunxd3 Oct 1, 2024
bceb510
fix doc example
sunxd3 Oct 1, 2024
c177271
try to fix doc example error
sunxd3 Oct 1, 2024
bdba893
fix doc deps
sunxd3 Oct 1, 2024
e7e2870
fix more doc example error
sunxd3 Oct 1, 2024
80df187
minor update
sunxd3 Oct 1, 2024
076e431
Apply suggestions from code review
sunxd3 Oct 3, 2024
4293868
Update docs/src/state_interface.md
sunxd3 Oct 3, 2024
1cee0ab
Update docs/src/state_interface.md
sunxd3 Oct 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "5.2.0"
version = "5.3.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -20,22 +21,27 @@ TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"

[compat]
AbstractPPL = "0.8"
BangBang = "0.3.19, 0.4"
ConsoleProgressMonitor = "0.1"
FillArrays = "1"
LogDensityProblems = "2"
LoggingExtras = "0.4, 0.5, 1"
MCMCChains = "6"
ProgressLogging = "0.1"
StatsBase = "0.32, 0.33, 0.34"
TerminalLoggers = "0.1"
Transducers = "0.4.30"
julia = "1.6"

[extras]
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["FillArrays", "IJulia", "Statistics", "Test"]
test = ["AbstractPPL","FillArrays", "Distributions", "IJulia", "MCMCChains", "Statistics", "Test"]
59 changes: 59 additions & 0 deletions design_notes/on_gibbs_implementation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# On `AbstractMCMC` Interface Supporting `Gibbs`

This is written at Oct 1st, 2024. Version of packages described in this passage are:

* `Turing.jl`: 0.34.1

In this passage, `Gibbs` refers to `Experimental.Gibbs`.

## Current Implementation of `Gibbs` in `Turing`

Here I describe the current implementation of `Gibbs` in `Turing` and the interface it requires from its sampler states.

### Interface 1: `getparams`

From the [definition of `GibbsState`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/experimental/gibbs.jl#L244-L248), we can see that a `vi::DynamicPPL.AbstractVarInfo` field is used to keep track of the names and values of parameters and the log density. The `states` field collects the sampler-specific *state*s.

(The *link*ing of *varinfo*s is omitted in this discussion.)
A local `VarInfo` is initially created with `DynamicPPL.subset(::VarInfo, ::Vector{<:VarName})` to make the conditioned model. After the Gibbs step, an updated `varinfo` is obtained by calling `Turing.Inference.varinfo` on the sampler state.

For samplers and their states defined in `Turing` (including `DynamicHMC`, as `DynamicNUTSState` is defined by `Turing` in the package extension), we (à la `Turing.jl` package) assume that the *state*s all have a field called `vi`. Then `varinfo(_some_sampler_state_)` is simply `varinfo(state) = state.vi` (defined in [`src/mcmc/gibbs.jl`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/gibbs.jl#L97)). (`GibbsState` conforms to this assumption.)

For `ExternalSamplers`, we currently only support `AdvancedHMC` and `AdvancedMH`. The mechanism is as follows: at the end of the `step` call with an external sampler, [`transition_to_turing` and `state_to_turing` are called](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/abstractmcmc.jl#L147). These two functions then call `getparams` on the sampler state of the external samplers. `getparams` for `AdvancedHMC.HMCState` and `AdvancedMH.Transition` (`AdvancedMH` uses `Transition` as state) are defined in `abstractmcmc.jl`.

Thus, the first interface emerges: `getparams`. As `getparams` is designed to be implemented by a sampler that works with the `LogDensityProblems` interface, it makes sense for `getparams` to return a vector of `Real`s. The `logdensity_problem` should then be responsible for performing the transformation between its underlying representation and the vector of `Real`s.

It's worth noting that:

* `getparams` is not a function specific for `Gibbs`. It is required for the current support of external samplers.
* There is another [`getparams`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/Inference.jl#L328-L351) in `Turing.jl` that takes *model* and *varinfo*, then returns a `NamedTuple`.

### Interface 2: `recompute_logp!!`

Consider a model with multiple groups of variables, say $\theta_1, \theta_2, \ldots, \theta_k$. At the beginning of the $t$-th Gibbs step, the model parameters in the `GibbsState` are typically updated and different from the $(t-1)$-th step. The `GibbsState` maintains $k$ sub-states, one for each variable group, denoted as $\text{state}_{t,1}, \text{state}_{t,2}, \ldots, \text{state}_{t,k}$.

The parameter values in each sub-state, i.e., $\theta_{t,i}$ in $\text{state}_{t,i}$, are always in sync with the corresponding values in the `GibbsState`. At the end of the $t$-th Gibbs step, $\text{state}_{t,i}$ will store the log density of the $i$-th variable group conditioned on all other variable groups at their values from step $t$, denoted as $\log p(\theta_{t,i} \mid \theta_{t,-i})$. This log density is equal to the joint log density of the whole model evaluated at the current parameter values $(\theta_{t,1}, \ldots, \theta_{t,k})$.

However, the log density stored in each sub-state is in general not equal to the log density needed for the next Gibbs step at $t+1$, i.e., $\log p(\theta_{t,i} \mid \theta_{t+1,-i})$. This is because the values of the other variable groups $\theta_{-i}$ will have been updated in the Gibbs step from $t$ to $t+1$, changing the conditioning set. Therefore, the log density typically needs to be recomputed at each Gibbs step to account for the updated values of the conditioning variables.

Only in certain special cases, the recomputation can be skipped. For example, in a Metropolis-Hastings step where the proposal is rejected for all other variable groups, i.e., $\theta_{t+1,-i} = \theta_{t,-i}$, the log density $\log p(\theta_{t,i} \mid \theta_{t,-i})$ remains valid and doesn't need to be recomputed.

The `recompute_logp!!` function in `abstractmcmc.jl` handles this recomputation. It takes an updated conditioned log density function $\log p(\cdot \mid \theta_{t+1,j})$ and the parameter values $\theta_{t,i}$ stored in $\text{state}_{t,i}$ to compute the updated log density $\log p(\theta_{t,i} \mid \theta_{t+1,j})$.

## Proposed Interface

The two functions `getparams` and `recompute_logp!!` form a minimal interface to support the `Gibbs` implementation. However, there are concerns about introducing them directly into `AbstractMCMC`. The main reason is that `AbstractMCMC` is a root dependency of the `Turing` packages, so we want to be very careful with new releases.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason is that AbstractMCMC is a root dependency of the Turing packages, so we want to be very careful with new releases.

Fair, but if we now make a release where we assume that certain functionality is overloaded, then that seems strictly worse, no?


Here, some alternative functions that achieve the same functionality as `getparams` and `recompute_logp!!` are proposed, but without introducing new interface functions.

For `getparams`, we can use `Base.vec`. It is a `Base` function, so there's no need to export anything from `AbstractMCMC`. Since `getparams` should return a vector, using `vec` makes sense. The concern is that, officially, `Base.vec` is defined for `AbstractArray`, so it remains a question whether we should only introduce `vec` in the absence of other `AbstractArray` interfaces.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd much prefer an explicit method in AbstractMCMC (uncertain if we want to export it 🤷 but probably make it public). Anyone implementing this interface already has AbstractMCMC loaded, so really doesn't cost anything + avoids misuse of Base.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can resonate, issue with public is they still count as public interface, unsure if we need to make minor release


For `recompute_logp!!`, we could overload `LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::State; recompute_logp=true)` to compute the log probability. If `recompute_logp` is `true`, it should recompute the log probability of the state. Otherwise, it could use the log probability stored in the state. To allow updating the log probability stored in the state, samplers should define outer constructor for their state types `StateType(state::StateType, logdensity=logp)` that takes an existing `state` and a log probability value `logp`, and returns a new state of the same type with the updated log probability.

While overloading `LogDensityProblems.logdensity` to take a state object instead of a vector for the second argument somewhat deviates from the interface in `LogDensityProblems`, it provides a clean and extensible solution for handling log probability recomputation within the existing interface.
Comment on lines +51 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But here we're introducing kwargs, etc. which is really not a part of the LogDensityProblems.logdensity interface. It would also mean we would have to depend on LogDensityProblems.jl, which we're currently not doing (AFIAK).

Why would we do this vs. just using recompute_logp!! for this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mainly to not make any changes to the interface, so these are just "recommendations"

I think AbstractMCMC depends on LogDensityProblems


An example demonstrating these interfaces is provided in `src/state_interface.md`.

## A More Standalone `Gibbs` Implementation

`AbstractMCMC.Gibbs` should not manage a `variable name → sampler` but rather `range → sampler`, i.e. it maintains a vector of parameter values, while a higher-level interface like `AbstractPPL` / `DynamicPPL` should manage both the name and transformations.
4 changes: 4 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ makedocs(;
sitename="AbstractMCMC",
format=Documenter.HTML(),
modules=[AbstractMCMC],
pages=["Home" => "index.md", "api.md", "design.md"],
pages=["Home" => "index.md", "api.md", "design.md", "state_interface.md"],
checkdocs=:exports,
)

Expand Down
Loading
Loading