Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down Expand Up @@ -64,6 +65,7 @@ DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.39.1"
EllipticalSliceSampling = "0.5, 1, 2"
FlexiChains = "0.3.1"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.3"
LinearAlgebra = "1"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ links = InterLinks(
"AbstractMCMC" => "https://turinglang.org/AbstractMCMC.jl/stable/",
"ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/",
"AdvancedVI" => "https://turinglang.org/AdvancedVI.jl/stable/",
"FlexiChains" => "https://pysm.dev/FlexiChains.jl/stable/",
"DistributionsAD" => "https://turinglang.org/DistributionsAD.jl/stable/",
"OrderedCollections" => "https://juliacollections.github.io/OrderedCollections.jl/stable/",
"Distributions" => "https://juliastats.org/Distributions.jl/stable/",
Expand Down
23 changes: 10 additions & 13 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@

## Module-wide re-exports

Turing.jl directly re-exports the entire public API of the following packages:

- [Distributions.jl](https://juliastats.org/Distributions.jl)
- [MCMCChains.jl](https://turinglang.org/MCMCChains.jl)

Please see the individual packages for their documentation.
Turing.jl directly re-exports the entire public API of [Distributions.jl](https://juliastats.org/Distributions.jl).
Please see its documentation for more details.

## Individual exports and re-exports

Expand Down Expand Up @@ -47,13 +43,14 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu

### Inference

| Exported symbol | Documentation | Description |
|:----------------- |:------------------------------------------------------------------------- |:----------------------------------------- |
| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model |
| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism |
| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from `MCMCChains.Chains` |
| Exported symbol | Documentation | Description |
|:----------------- |:------------------------------------------------------------------------- |:----------------------------------- |
| `sample` | [`StatsBase.sample`](https://turinglang.org/docs/usage/sampling-options/) | Sample from a model |
| `MCMCThreads` | [`AbstractMCMC.MCMCThreads`](@extref) | Run MCMC using multiple threads |
| `MCMCDistributed` | [`AbstractMCMC.MCMCDistributed`](@extref) | Run MCMC using multiple processes |
| `MCMCSerial` | [`AbstractMCMC.MCMCSerial`](@extref) | Run MCMC using without parallelism |
| `loadstate` | [`Turing.Inference.loadstate`](@ref) | Load saved state from an MCMC chain |
| `VNChain` | n/a | Alias for `FlexiChain{VarName}` |

### Samplers

Expand Down
7 changes: 5 additions & 2 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Reexport, ForwardDiff
using DistributionsAD, Bijectors, StatsFuns, SpecialFunctions
using Statistics, LinearAlgebra
using Libtask
@reexport using Distributions, MCMCChains
@reexport using Distributions
using Compat: pkgversion

using AdvancedVI: AdvancedVI
Expand All @@ -16,6 +16,7 @@ using Accessors: Accessors
using StatsAPI: StatsAPI
using StatsBase: StatsBase
using AbstractMCMC
using FlexiChains

using Accessors: Accessors

Expand Down Expand Up @@ -172,6 +173,8 @@ export
MAP,
MLE,
# Chain save/resume
loadstate
loadstate,
# FlexiChains re-export
VNChain

end
4 changes: 2 additions & 2 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ using DynamicPPL:
DefaultContext
using Distributions, Libtask, Bijectors
using DistributionsAD: VectorOfMultivariate
using FlexiChains: FlexiChains, VNChain
using LinearAlgebra
using ..Turing: PROGRESS, Turing
using StatsFuns: logsumexp
Expand All @@ -46,7 +47,6 @@ import Accessors
import EllipticalSliceSampling
import LogDensityProblems
import Random
import MCMCChains
import StatsBase: predict

export Hamiltonian,
Expand Down Expand Up @@ -78,7 +78,7 @@ export Hamiltonian,
# Generic AbstractMCMC methods dispatch #
#########################################

const DEFAULT_CHAIN_TYPE = MCMCChains.Chains
const DEFAULT_CHAIN_TYPE = VNChain
include("abstractmcmc.jl")

####################
Expand Down
2 changes: 2 additions & 0 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using MCMCChains: MCMCChains

# TODO: Implement additional checks for certain samplers, e.g.
# HMC not supporting discrete parameters.
function _check_model(model::DynamicPPL.Model)
Expand Down
35 changes: 21 additions & 14 deletions src/mcmc/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,27 @@ function AbstractMCMC.step(
return transition, newstate
end

function AbstractMCMC.bundle_samples(
samples::Vector{<:Vector},
model::AbstractModel,
spl::Emcee,
state::EmceeState,
chain_type::Type{MCMCChains.Chains};
kwargs...,
)
n_walkers = _get_n_walkers(spl)
chains = map(1:n_walkers) do i
this_walker_samples = [s[i] for s in samples]
AbstractMCMC.bundle_samples(
this_walker_samples, model, spl, state, chain_type; kwargs...
# Have to define methods for both to avoid method ambiguities (as opposed to a single
# `::Type{T<:AbstractMCMC.AbstractChains})` since default `bundle_samples` takes
# `samples::AbstractVector`.
for Tchain in (:(MCMCChains.Chains), :(FlexiChains.VNChain))
@eval begin
function AbstractMCMC.bundle_samples(
samples::Vector{<:Vector},
model::DynamicPPL.Model,
spl::Emcee,
state::EmceeState,
::Type{$Tchain};
kwargs...,
)
n_walkers = _get_n_walkers(spl)
chains = map(1:n_walkers) do i
this_walker_samples = [s[i] for s in samples]
AbstractMCMC.bundle_samples(
this_walker_samples, model, spl, state, $Tchain; kwargs...
)
end
return AbstractMCMC.chainscat(chains...)
end
end
return AbstractMCMC.chainscat(chains...)
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
FlexiChains = "4a37a8b9-6e57-4b92-8664-298d46e639f7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
Loading
Loading