Add some interface functions to support the new Gibbs sampler in Turing #144
# On `AbstractMCMC` Interface Supporting `Gibbs`

This was written on October 1st, 2024. The package versions described in this passage are:

* `Turing.jl`: 0.34.1

In this passage, `Gibbs` refers to `Experimental.Gibbs`.

## Current Implementation of `Gibbs` in `Turing`

Here I describe the current implementation of `Gibbs` in `Turing` and the interface it requires from its sampler states.

### Interface 1: `getparams`

From the [definition of `GibbsState`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/experimental/gibbs.jl#L244-L248), we can see that a `vi::DynamicPPL.AbstractVarInfo` field is used to keep track of the names and values of the parameters and the log density. The `states` field collects the sampler-specific *state*s.
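
For orientation, here is a simplified sketch of that struct (field types and details elided; the linked source is authoritative):

```julia
using DynamicPPL

# Simplified sketch of Turing's experimental GibbsState; see the linked
# source for the authoritative definition.
struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
    vi::V       # tracks parameter names, values, and the log density
    states::S   # sampler-specific states, one per variable group
end
```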

(The *link*ing of *varinfo*s is omitted in this discussion.)
A local `VarInfo` is initially created with `DynamicPPL.subset(::VarInfo, ::Vector{<:VarName})` to make the conditioned model. After the Gibbs step, an updated `varinfo` is obtained by calling `Turing.Inference.varinfo` on the sampler state.

For samplers and their states defined in `Turing` (including `DynamicHMC`, since `DynamicNUTSState` is defined by `Turing` in a package extension), we (i.e., the `Turing.jl` package) assume that the *state*s all have a field called `vi`. Then `varinfo` is simply `varinfo(state) = state.vi` (defined in [`src/mcmc/gibbs.jl`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/gibbs.jl#L97)). (`GibbsState` conforms to this assumption.)

For external samplers, we currently only support `AdvancedHMC` and `AdvancedMH`. The mechanism is as follows: at the end of the `step` call with an external sampler, [`transition_to_turing` and `state_to_turing` are called](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/abstractmcmc.jl#L147). These two functions then call `getparams` on the sampler state of the external sampler. `getparams` for `AdvancedHMC.HMCState` and `AdvancedMH.Transition` (`AdvancedMH` uses `Transition` as its state) is defined in `abstractmcmc.jl`.
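
As a rough sketch (the concrete field accesses below are assumptions about the two packages' internals, not verbatim from `abstractmcmc.jl`):

```julia
using AdvancedHMC, AdvancedMH

# Sketch of getparams for external sampler states; the field accesses are
# assumptions about AdvancedHMC/AdvancedMH internals.
getparams(state::AdvancedHMC.HMCState) = state.transition.z.θ
getparams(transition::AdvancedMH.Transition) = transition.params
```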

Thus, the first interface emerges: `getparams`. As `getparams` is designed to be implemented by samplers that work with the `LogDensityProblems` interface, it makes sense for `getparams` to return a vector of `Real`s. The `logdensity_problem` should then be responsible for performing the transformation between its underlying representation and the vector of `Real`s.

It's worth noting that:

* `getparams` is not a function specific to `Gibbs`. It is required for the current support of external samplers.
* There is another [`getparams`](https://github.com/TuringLang/Turing.jl/blob/3c91eec43176d26048b810aae0f6f2fac0686cfa/src/mcmc/Inference.jl#L328-L351) in `Turing.jl` that takes a *model* and a *varinfo*, then returns a `NamedTuple`.

### Interface 2: `recompute_logp!!`

Consider a model with multiple groups of variables, say $\theta_1, \theta_2, \ldots, \theta_k$. At the beginning of the $t$-th Gibbs step, the model parameters in the `GibbsState` have typically been updated and differ from those at the $(t-1)$-th step. The `GibbsState` maintains $k$ sub-states, one for each variable group, denoted as $\text{state}_{t,1}, \text{state}_{t,2}, \ldots, \text{state}_{t,k}$.

The parameter values in each sub-state, i.e., $\theta_{t,i}$ in $\text{state}_{t,i}$, are always in sync with the corresponding values in the `GibbsState`. At the end of the $t$-th Gibbs step, $\text{state}_{t,i}$ will store the log density of the $i$-th variable group conditioned on all other variable groups at their values from step $t$, denoted as $\log p(\theta_{t,i} \mid \theta_{t,-i})$. Up to an additive constant that does not depend on $\theta_{t,i}$, this equals the joint log density of the whole model evaluated at the current parameter values $(\theta_{t,1}, \ldots, \theta_{t,k})$.
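
To see why the joint can stand in for the conditional, note that the two differ only by a term that is constant in $\theta_{t,i}$:

$$\log p(\theta_{t,i} \mid \theta_{t,-i}) = \log p(\theta_{t,1}, \ldots, \theta_{t,k}) - \log p(\theta_{t,-i}),$$

and $\log p(\theta_{t,-i})$ does not depend on $\theta_{t,i}$, so a sampler updating the $i$-th group can use the joint log density directly.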

However, the log density stored in each sub-state is in general not equal to the log density needed for the next Gibbs step at $t+1$, i.e., $\log p(\theta_{t,i} \mid \theta_{t+1,-i})$. This is because the values of the other variable groups $\theta_{-i}$ will have been updated in the Gibbs step from $t$ to $t+1$, changing the conditioning set. Therefore, the log density typically needs to be recomputed at each Gibbs step to account for the updated values of the conditioning variables.

Only in certain special cases can the recomputation be skipped. For example, in a Metropolis-Hastings step where the proposal is rejected for all other variable groups, i.e., $\theta_{t+1,-i} = \theta_{t,-i}$, the log density $\log p(\theta_{t,i} \mid \theta_{t,-i})$ remains valid and doesn't need to be recomputed.

The `recompute_logp!!` function in `abstractmcmc.jl` handles this recomputation. It takes the updated conditioned log density function $\log p(\cdot \mid \theta_{t+1,-i})$ and the parameter values $\theta_{t,i}$ stored in $\text{state}_{t,i}$, and computes the updated log density $\log p(\theta_{t,i} \mid \theta_{t+1,-i})$.
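
A minimal sketch of this contract, assuming a hypothetical state type `DemoState` that stores a flat parameter vector and a cached log density (the real method lives in `abstractmcmc.jl`):

```julia
using LogDensityProblems

# Hypothetical state type, used only for illustration.
struct DemoState
    params::Vector{Float64}  # flat parameter vector θ_{t,i}
    logp::Float64            # cached log density
end

# Sketch of recompute_logp!!: evaluate the updated conditioned log density
# function at the stored parameter values and return a state carrying the
# fresh value. The !! convention signals that the state may be replaced.
function recompute_logp!!(logdensity_function, state::DemoState)
    logp = LogDensityProblems.logdensity(logdensity_function, state.params)
    return DemoState(state.params, logp)
end
```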

## Proposed Interface

The two functions `getparams` and `recompute_logp!!` form a minimal interface to support the `Gibbs` implementation. However, there are concerns about introducing them directly into `AbstractMCMC`. The main reason is that `AbstractMCMC` is a root dependency of the `Turing` packages, so we want to be very careful with new releases.

Here, we propose alternative functions that achieve the same functionality as `getparams` and `recompute_logp!!`, but without introducing new interface functions.

For `getparams`, we can use `Base.vec`. It is a `Base` function, so there's no need to export anything from `AbstractMCMC`, and since `getparams` should return a vector, using `vec` makes sense. The concern is that, officially, `Base.vec` is defined for `AbstractArray`, so it remains an open question whether we should overload `vec` for state types that implement no other part of the `AbstractArray` interface.
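
With the hypothetical `DemoState` from the sketch above, opting into this convention is a one-line method:

```julia
# Expose the state's flat parameter vector through Base.vec.
Base.vec(state::DemoState) = state.params
```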

> **Review comment:** I'd much prefer an explicit method in AbstractMCMC (uncertain if we want to export it 🤷 but probably make it …
>
> **Reply:** I can resonate, issue with …

For `recompute_logp!!`, we could overload `LogDensityProblems.logdensity(logdensity_model::AbstractMCMC.LogDensityModel, state::State; recompute_logp=true)` to compute the log probability. If `recompute_logp` is `true`, it should recompute the log probability of the state; otherwise, it could use the log probability stored in the state. To allow updating the log probability stored in the state, samplers should define an outer constructor for their state types, `StateType(state::StateType, logdensity=logp)`, that takes an existing `state` and a log probability value `logp` and returns a new state of the same type with the updated log probability.
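
A sketch of this proposal for the hypothetical `DemoState`, following the keyword and outer-constructor pattern described above:

```julia
using AbstractMCMC, LogDensityProblems

# Proposed overload (sketch): dispatch on the state rather than a raw vector.
function LogDensityProblems.logdensity(
    model::AbstractMCMC.LogDensityModel, state::DemoState; recompute_logp=true
)
    recompute_logp || return state.logp  # reuse the cached value
    return LogDensityProblems.logdensity(model.logdensity, vec(state))
end

# Outer constructor returning a new state of the same type with an updated
# log probability.
DemoState(state::DemoState, logp) = DemoState(state.params, logp)
```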

While overloading `LogDensityProblems.logdensity` to take a state object instead of a vector as the second argument somewhat deviates from the interface in `LogDensityProblems`, it provides a clean and extensible solution for handling log probability recomputation within the existing interface.

> **Review comment (on lines +51 to +53):** But here we're introducing kwargs, etc., which is really not a part of the … Why would we do this vs. just using …
>
> **Reply:** mainly to not make any changes to the interface, so these are just "recommendations" I think

An example demonstrating these interfaces is provided in `src/state_interface.md`.

## A More Standalone `Gibbs` Implementation

`AbstractMCMC.Gibbs` should not manage a `variable name → sampler` mapping but rather a `range → sampler` mapping, i.e., it maintains a flat vector of parameter values, while a higher-level interface such as `AbstractPPL` / `DynamicPPL` manages both the names and the transformations.
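
A rough sketch of this shape (all names here are hypothetical):

```julia
# Hypothetical standalone Gibbs keyed by index ranges into a flat parameter
# vector; variable names and transformations would be handled by a
# higher-level interface such as AbstractPPL/DynamicPPL.
struct RangeGibbs{S<:Tuple}
    ranges::Vector{UnitRange{Int}}  # one index range per variable group
    samplers::S                     # one sampler per range
end
```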

> **Review comment:** Fair, but if we now make a release where we assume that certain functionality is overloaded, then that seems strictly worse, no?