Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of step_warmup #117

Merged
merged 42 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
0987f5f
added step_warmup which is can be overloaded when convenient
torfjelde Mar 9, 2023
30c9f12
added step_warmup to docs
torfjelde Mar 9, 2023
7faa73f
Update src/interface.jl
torfjelde Mar 9, 2023
bd0bdc7
introduce new kwarg `num_warmup` to `sample` which uses `step_warmup`
torfjelde Mar 10, 2023
c620cca
updated docs
torfjelde Mar 10, 2023
572a286
allow combination of discard_initial and num_warmup
torfjelde Mar 10, 2023
6b842ee
added docstring for mcmcsample
torfjelde Mar 10, 2023
ca03832
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
0441773
Apply suggestions from code review
torfjelde Mar 10, 2023
ea369ff
Apply suggestions from code review
torfjelde Mar 10, 2023
8e0ca53
Update src/sample.jl
torfjelde Mar 10, 2023
6877978
removed docstring and deferred description of keyword arguments to th…
torfjelde Mar 10, 2023
b3b3148
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
ddc5254
Update src/sample.jl
torfjelde Mar 10, 2023
ffbd32f
Update src/sample.jl
torfjelde Mar 10, 2023
87480ff
added num_warmup to common keyword arguments docs
torfjelde Mar 10, 2023
76f2f23
also allow step_warmup for the initial step
torfjelde Mar 10, 2023
c00d0c9
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
ef09c19
simplify logic for discarding fffinitial samples
torfjelde Mar 10, 2023
49b8406
Apply suggestions from code review
torfjelde Mar 10, 2023
f005746
also report progress for the discarded samples
torfjelde Mar 10, 2023
9dccd8a
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
ff00e6e
Apply suggestions from code review
torfjelde Mar 10, 2023
7ce9f6b
move progress-report to end of for-loop for discard samples
torfjelde Mar 10, 2023
3a217b2
move step_warmup to the inner while loops too
torfjelde Mar 13, 2023
de9bb2c
Update src/sample.jl
torfjelde Mar 13, 2023
85d938f
Apply suggestions from code review
torfjelde Apr 19, 2023
0a667a4
reverted to for-loop
torfjelde Apr 19, 2023
91f5a10
Update src/sample.jl
torfjelde Apr 19, 2023
7603171
added accidentanly removed comment
torfjelde Apr 19, 2023
ef68d04
Update src/sample.jl
torfjelde Apr 19, 2023
25afc66
Merge branch 'master' into torfjelde/step-warmup
torfjelde Oct 24, 2023
1886fa8
Merge branch 'master' into torfjelde/step-warmup
torfjelde Oct 25, 2023
0ea293a
fixed formatting
torfjelde Oct 26, 2023
6e8f88e
fix typo
torfjelde Oct 26, 2023
44c55bb
Merge branch 'master' into torfjelde/step-warmup
torfjelde Oct 4, 2024
3b4f6db
Apply suggestions from code review
torfjelde Oct 4, 2024
f9142a6
Added testing of warmup steps
torfjelde Oct 4, 2024
295fdc1
Added checks as @devmotion requested
torfjelde Oct 4, 2024
e6acb1f
Removed unintended change in previous commit
torfjelde Oct 4, 2024
2e9fa5c
Bumped patch version
torfjelde Oct 4, 2024
366fceb
Bump minor version instead of patch version since this is a new feature
torfjelde Oct 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.4.0"
version = "4.4.1"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
9 changes: 9 additions & 0 deletions docs/src/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ the sampling step of the inference method.
AbstractMCMC.step
```

If one also has some special handling of the warmup-stage of sampling, then this can be specified by overloading

```@docs
AbstractMCMC.step_warmup
```

which will be used for the first `num_warmup`, as specified as a keyword argument to [`AbstractMCMC.sample`](@ref).
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
Note that this is optional; by default it simply calls [`AbstractMCMC.step`](@ref) from above.

## Collecting samples

!!! note
Expand Down
16 changes: 16 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ current `state` of the sampler.
"""
function step end

"""
step_warmup(rng, model, sampler[, state; kwargs...])

Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`.

When sampling using [`sample`](@ref), this takes the place of [`step`](@ref) in the first
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
`num_warmup` number of iterations, as specified by the `num_warmup` keyword to [`sample`](@ref).
This is useful if the sampler has a "warmup"-stage initial stage that is different from the
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
standard iteration.

By default, this simply calls [`step`](@ref.)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
"""
function step_warmup(rng, model, sampler, state; kwargs...)
return step(rng, model, sampler, state; kwargs...)
end

"""
samples(sample, model, sampler[, N; kwargs...])

Expand Down
174 changes: 145 additions & 29 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,29 @@ function StatsBase.sample(
end

# Default implementations of regular and parallel sampling.

"""
mcmcsample(rng, model, sampler, N_or_is_done; kwargs...)

Default implementation of `sample` for a `model` and `sampler`.

# Arguments
- `rng::Random.AbstractRNG`: the random number generator to use.
- `model::AbstractModel`: the model to sample from.
- `sampler::AbstractSampler`: the sampler to use.
- `N::Integer`: the number of samples to draw.

# Keyword arguments
- `progress`: whether to display a progress bar. Defaults to `true`.
- `progressname`: the name of the progress bar. Defaults to `"Sampling"`.
- `callback`: a function that is called after each [`AbstractMCMC.step`](@ref).
Defaults to `nothing`.
- `num_warmup`: number of warmup samples to draw. Defaults to `0`.
- `discard_initial`: number of initial samples to discard. Defaults to `num_warmup`.
- `thinning`: number of samples to discard between samples. Defaults to `1`.
- `chain_type`: the type to pass to [`AbstractMCMC.bundle_samples`](@ref) at the
end of sampling to wrap up the resulting samples nicely. Defaults to `Any`.
- `kwargs...`: Additional keyword arguments to pass on to [`AbstractMCMC.step`](@ref).
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think nobody will look up the docstring for the unexported mcmcsample function, so it feels listing and explaining keyword arguments in https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments is the better approach? And possibly extending the docstring of sample?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aaah I was totally unaware!

So I removed this, and then I've just added a section to sample to tell people where to find docs on the default arguments. I personally rarely go to the docs of a package unless I "have" to, so I think it's at least nice to tell the user where to find the info. I'm even partial to putting the stuff about common keywords in the actual docstrings of sample but I'll leave as is for now.

function mcmcsample(
rng::Random.AbstractRNG,
model::AbstractModel,
Expand All @@ -100,14 +122,24 @@ function mcmcsample(
progress=PROGRESS[],
progressname="Sampling",
callback=nothing,
discard_initial=0,
num_warmup::Int=0,
discard_initial::Int=num_warmup,
thinning=1,
chain_type::Type=Any,
kwargs...,
)
# Check the number of requested samples.
N > 0 || error("the number of samples must be ≥ 1")
discard_initial >= 0 || throw(ArgumentError("number of discarded samples must be non-negative"))
num_warmup >= 0 || throw(ArgumentError("number of warm-up samples must be non-negative"))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
Ntotal = thinning * (N - 1) + discard_initial + 1
Ntotal >= num_warmup || throw(ArgumentError("number of warm-up samples exceeds the total number of samples"))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

# Determine how many samples to drop from `num_warmup` and the
# main sampling process before we start saving samples.
discard_from_warmup = min(num_warmup, discard_initial)
keep_from_warmup = num_warmup - discard_from_warmup
discard_from_sample = max(discard_initial - discard_from_warmup, 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this seems to match what I wrote above.


# Start the timer
start = time()
Expand All @@ -124,34 +156,76 @@ function mcmcsample(
# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)
devmotion marked this conversation as resolved.
Show resolved Hide resolved

# Discard initial samples.
for i in 1:discard_initial
# Update the progress bar.
if progress && i >= next_update
ProgressLogging.@logprogress i / Ntotal
next_update = i + threshold
end

# Warmup sampling.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
for _ in 1:discard_from_warmup
# Obtain the next sample and state.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these should be accounted for in the progress logger as well (as done currently).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be good now 👍

sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = step_warmup(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...)
i = 1
if keep_from_warmup > 0
# Run callback.
callback === nothing ||
callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...)
samples = save!!(samples, sample, i, model, sampler; kwargs...)

# Step through remainder of warmup iterations and save.
i += 1
for _ in (discard_from_warmup + 1):num_warmup
# Update the progress bar.
if progress && i >= next_update
ProgressLogging.@logprogress i / Ntotal
next_update = i + threshold
end

# Obtain the next sample and state.
sample, state = step_warmup(rng, model, sampler, state; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)
# Run callback.
callback === nothing ||
callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = save!!(samples, sample, i, model, sampler; kwargs...)
i += 1
end
else
# Discard additional initial samples, if needed.
for _ in 1:discard_from_sample
# Update the progress bar.
if progress && i >= next_update
ProgressLogging.@logprogress i / Ntotal
next_update = i + threshold
end

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing ||
callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...)
samples = save!!(samples, sample, i, model, sampler; kwargs...)

# Increment iteration number.
i += 1
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be easier to do something along the lines of

        # Discard initial samples, if needed.
        for _ in 1:discard_initial
            # Update the progress bar.
            if progress && i >= next_update
                ProgressLogging.@logprogress i / Ntotal
                next_update = i + threshold
            end

            # Obtain the next sample and state.
            sample, state = if i <= num_warmup
                step_warmup(rng, model, sampler, state; kwargs...)
            else
                step(rng, model, sampler, state; kwargs...)
            end
        end

        ...

        # Increment iteration number.
        i += 1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well... Yes 🤦

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah actually, I now remember the reason for why I didn't do this: won't this make it type-unstable while the current implementation won't (if indeed step_warmup and step return the same type)?

Whether we care is another thing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If step_warmup and step return the same types, both approaches should be type-stable, no? And if not, you're in trouble in both cases I assume?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My worry was that too many if-staments involving runtime information could confuse the type-inference, even in the case where the return-types of each branch is the same, but from your comments I'm assuming this was an irrational fear:)

Also, I just read a bit more https://juliahub.com/blog/2016/04/inference-convergence/ and my fears are indeed irrational 👍 I guess as long as each of the if-statement returns the same two types, i.e. the Union doesn't change, we're fine. IIRC Julia does union-splitting up to unions of length 4.


# Update the progress bar.
itotal = 1 + discard_initial
itotal = i
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end

# Step through the sampler.
for i in 2:N
while i N
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason to switch to a while loop here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no! I'll revert it to for-loop 👍

# Discard thinned samples.
for _ in 1:(thinning - 1)
# Obtain the next sample and state.
Expand All @@ -174,6 +248,9 @@ function mcmcsample(
# Save the sample.
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)

# Increment iteration counter.
i += 1

torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# Update the progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
Expand Down Expand Up @@ -209,10 +286,16 @@ function mcmcsample(
progress=PROGRESS[],
progressname="Convergence sampling",
callback=nothing,
discard_initial=0,
num_warmup=0,
discard_initial=num_warmup,
thinning=1,
kwargs...,
)
# Determine how many samples to drop from `num_warmup` and the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the same/similar error checks as above?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

# main sampling process before we start saving samples.
discard_from_warmup = min(num_warmup, discard_initial)
keep_from_warmup = num_warmup - discard_from_warmup
discard_from_sample = max(discard_initial - discard_from_warmup, 0)

# Start the timer
start = time()
Expand All @@ -222,21 +305,54 @@ function mcmcsample(
# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)

# Discard initial samples.
for _ in 1:discard_initial
# Warmup sampling.
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
for _ in 1:discard_from_warmup
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = step_warmup(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...)
i = 1
if keep_from_warmup > 0
# Run callback.
callback === nothing ||
callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...)
samples = save!!(samples, sample, i, model, sampler; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...)
samples = save!!(samples, sample, 1, model, sampler; kwargs...)
# Step through remainder of warmup iterations and save.
i += 1
for _ in (discard_from_warmup + 1):num_warmup
# Obtain the next sample and state.
sample, state = step_warmup(rng, model, sampler, state; kwargs...)

# Step through the sampler until stopping.
i = 2
# Run callback.
callback === nothing ||
callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = save!!(samples, sample, i, model, sampler; kwargs...)
i += 1
end
else
# Discard additional initial samples, if needed.
for _ in 1:discard_from_sample
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing ||
callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...)
samples = save!!(samples, sample, i, model, sampler; kwargs...)

# Increment iteration number.
i += 1
end

while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...)
# Discard thinned samples.
Expand Down