diff --git a/src/sample.jl b/src/sample.jl index dc951ca2..b949087f 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -410,6 +410,9 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Ensure that initial parameters are `nothing` or indexable + _init_params = _first_or_nothing(init_params, nchains) + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -466,10 +469,10 @@ function mcmcsample( # Return the new chain. return chain end - chains = if init_params === nothing + chains = if _init_params === nothing Distributed.pmap(sample_chain, pool, seeds) else - Distributed.pmap(sample_chain, pool, seeds, init_params) + Distributed.pmap(sample_chain, pool, seeds, _init_params) end finally # Stop updating the progress bar. @@ -499,6 +502,9 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Ensure that initial parameters are `nothing` or indexable + _init_params = _first_or_nothing(init_params, nchains) + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -519,10 +525,10 @@ function mcmcsample( ) end - chains = if init_params === nothing + chains = if _init_params === nothing map(sample_chain, 1:nchains, seeds) else - map(sample_chain, 1:nchains, seeds, init_params) + map(sample_chain, 1:nchains, seeds, _init_params) end # Concatenate the chains together.