Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix remaining method ambiguities #2304

Merged
merged 6 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.33.3"
version = "0.34.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
8 changes: 5 additions & 3 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ end
# Extended in contrib/inference/abstractmcmc.jl
getstats(t) = nothing

struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
abstract type AbstractTransition end

struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition
θ :: T
lp :: F # TODO: merge `lp` with `stat`
stat :: S
Expand Down Expand Up @@ -409,7 +411,7 @@ getlogevidence(transitions, sampler, state) = missing
# Default MCMCChains.Chains constructor.
# This is type piracy (at least for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector,
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
state,
Expand Down Expand Up @@ -472,7 +474,7 @@ end

# This is type piracy (for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector,
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
state,
Expand Down
4 changes: 2 additions & 2 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
SMC(space::Symbol...) = SMC(space)
SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space)

struct SMCTransition{T,F<:AbstractFloat}
struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
Expand Down Expand Up @@ -222,7 +222,7 @@ end

const CSMC = PG # type alias of PG as Conditional SMC

struct PGTransition{T,F<:AbstractFloat}
struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ function SGLD(
return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype)
end

struct SGLDTransition{T,F<:Real}
struct SGLDTransition{T,F<:Real} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample."
Expand Down
8 changes: 4 additions & 4 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@

"""
Base.get(m::ModeResult, var_symbol::Symbol)
Base.get(m::ModeResult, var_symbols)
Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
sunxd3 marked this conversation as resolved.
Show resolved Hide resolved

Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
argument should be either a `Symbol` or an iterator of `Symbol`s.
argument should be either a `Symbol` or a vector of `Symbol`s.
"""
function Base.get(m::ModeResult, var_symbols)
function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})

Check warning on line 286 in src/optimisation/Optimisation.jl

View check run for this annotation

Codecov / codecov/patch

src/optimisation/Optimisation.jl#L286

Added line #L286 was not covered by tests
log_density = m.f
# Get all the variable names in the model. This is the same as the list of keys in
# m.values, but they are more convenient to filter when they are VarNames rather than
Expand All @@ -304,7 +304,7 @@
return (; zip(var_symbols, value_vectors)...)
end

Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol])

Check warning on line 307 in src/optimisation/Optimisation.jl

View check run for this annotation

Codecov / codecov/patch

src/optimisation/Optimisation.jl#L307

Added line #L307 was not covered by tests

"""
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
Expand Down
5 changes: 3 additions & 2 deletions test/Aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module AquaTests
using Aqua: Aqua
using Turing

# TODO(mhauru) We skip testing for method ambiguities because it catches a lot of problems
# in dependencies. Would like to check it for just Turing.jl itself though.
# We test ambiguities separately because it catches a lot of problems
# in dependencies but we test it for Turing.
Aqua.test_ambiguities([Turing])
Aqua.test_all(Turing; ambiguities=false)

end
Loading