Skip to content

Commit

Permalink
use _init_parmas for MCMCThreads and MCMCDistributed too
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde authored Sep 13, 2023
1 parent d7c549f commit 6f5ac5a
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,9 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -466,10 +469,10 @@ function mcmcsample(
# Return the new chain.
return chain
end
chains = if init_params === nothing
chains = if _init_params === nothing
Distributed.pmap(sample_chain, pool, seeds)
else
Distributed.pmap(sample_chain, pool, seeds, init_params)
Distributed.pmap(sample_chain, pool, seeds, _init_params)
end
finally
# Stop updating the progress bar.
Expand Down Expand Up @@ -499,6 +502,9 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand All @@ -519,10 +525,10 @@ function mcmcsample(
)
end

chains = if init_params === nothing
chains = if _init_params === nothing
map(sample_chain, 1:nchains, seeds)
else
map(sample_chain, 1:nchains, seeds, init_params)
map(sample_chain, 1:nchains, seeds, _init_params)
end

# Concatenate the chains together.
Expand Down

0 comments on commit 6f5ac5a

Please sign in to comment.