Skip to content

Commit

Permalink
Make AdvancedMH compatible with AbstractMCMC 5
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Oct 27, 2023
1 parent 8ddb81e commit cb4d4a4
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.7.6"
version = "0.7.7"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -23,18 +23,18 @@ AdvancedMHMCMCChainsExt = "MCMCChains"
AdvancedMHStructArraysExt = "StructArrays"

[compat]
AbstractMCMC = "4, 5"
AbstractMCMC = "5"
DiffResults = "1"
Distributions = "0.20 - 0.25"
LinearAlgebra = "1.6 - 1.11"
Random = "1.6 - 1.11"
Distributions = "0.25"
FillArrays = "1"
ForwardDiff = "0.10"
LogDensityProblems = "2"
MCMCChains = "5, 6"
MCMCChains = "6.0.4"
Requires = "1"
StructArrays = "0.6"
julia = "1.6"
LinearAlgebra = "1.6"
Random = "1.6"

[extras]
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,21 +138,21 @@ AdvancedMH.jl implements the interface of [AbstractMCMC](https://github.com/Turi

```julia
# Sample 4 chains from the posterior serially, without thread or process parallelism.
chain = sample(model, RWMH(init_params), MCMCSerial(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
chain = sample(model, spl, MCMCSerial(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)

# Sample 4 chains from the posterior using multiple threads.
chain = sample(model, RWMH(init_params), MCMCThreads(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
chain = sample(model, spl, MCMCThreads(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)

# Sample 4 chains from the posterior using multiple processes.
chain = sample(model, RWMH(init_params), MCMCDistributed(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
chain = sample(model, spl, MCMCDistributed(), 100000, 4; param_names=["μ","σ"], chain_type=Chains)
```

## Metropolis-adjusted Langevin algorithm (MALA)

AdvancedMH.jl also offers an implementation of [MALA](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm) if the `ForwardDiff` and `DiffResults` packages are available.

A `MALA` sampler can be constructed by `MALA(proposal)` where `proposal` is a function that
takes the gradient computed at the current sample. It is required to specify an initial sample `init_params` when calling `sample`.
takes the gradient computed at the current sample. It is required to specify an initial sample `initial_params` when calling `sample`.

```julia
# Import the package.
Expand Down Expand Up @@ -180,7 +180,7 @@ model = DensityModel(density)
spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))

# Sample from the posterior.
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
chain = sample(model, spl, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```

### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)
Expand All @@ -192,5 +192,5 @@ Using our implementation of the `LogDensityProblems.jl` interface above:
```julia
using LogDensityProblemsAD
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity())
sample(model_with_ad, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
sample(model_with_ad, spl, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```
6 changes: 3 additions & 3 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ spl = MetropolisHastings(proposal)
When using `MetropolisHastings` with the function `sample`, the following keyword
arguments are allowed:
- `init_params` defines the initial parameterization for your model. If
- `initial_params` defines the initial parameterization for your model. If
none is given, the initial parameters will be drawn from the sampler's proposals.
- `param_names` is a vector of strings to be assigned to parameters. This is only
used if `chain_type=Chains`.
Expand Down Expand Up @@ -77,10 +77,10 @@ function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModelOrLogDensityModel,
sampler::MHSampler;
init_params=nothing,
initial_params=nothing,
kwargs...
)
params = init_params === nothing ? propose(rng, sampler, model) : init_params
params = initialparams === nothing ? propose(rng, sampler, model) : initial_params
transition = AdvancedMH.transition(sampler, model, params)
return transition, transition
end
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ include("util.jl")
val = [0.4, 1.2]

# Sample from the posterior.
chain1 = sample(model, spl1, 10, init_params = val)
chain1 = sample(model, spl1, 10, initial_params = val)

@test chain1[1].params == val
end
Expand Down Expand Up @@ -265,7 +265,7 @@ include("util.jl")
spl1 = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))

# Sample from the posterior with initial parameters.
chain1 = sample(model, spl1, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
chain1 = sample(model, spl1, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])

@test mean(chain1.μ) 0.0 atol=0.1
@test mean(chain1.σ) 1.0 atol=0.1
Expand All @@ -276,7 +276,7 @@ include("util.jl")
admodel,
spl1,
100000;
init_params=ones(2),
initial_params=ones(2),
chain_type=StructArray,
param_names=["μ", "σ"]
)
Expand Down

0 comments on commit cb4d4a4

Please sign in to comment.