Skip to content

Allow for pymc native samplers to resume sampling from ZarrTrace #7687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from

Conversation

lucianopaz
Copy link
Member

@lucianopaz lucianopaz commented Feb 21, 2025

Description

Big PR approaching! This finishes adding the ability of pymc native step methods to resume sampling from an existing trace (as long as it's a ZarrTrace!). This means that you can now continue tuning or sampling from a pre-existing sample run. For example

with model:
    # First tuning run
    pm.sample(tune=400, draws=0, trace=trace)

    # Do whatever to decide if you want to continue tuning   
    pm.sample(tune=800, draws=0, trace=trace)

    # Switch to sampling
    pm.sample(tune=800, draws=1000, trace=trace)

Another thing is that the chunks_per_draw from ZarrTrace along with its persistent storage backends (like ZipStore or DirectoryStore) makes the sampling store the results and final sampling state periodically, so in case of a crash during sampling, you can use the existing store to load the trace using ZarrTrace.from_store and then resume sampling from there.

The only thing that I haven't tested for yet is to add an Op that makes pm.sample crash to see if I can reload the partial results from the store and resume sampling. @ricardoV94 gave me some pointers to that, but I won't be working on this for the rest of the month and I thought it best to open a draft PR to kick off any discussion you have or collect feedback

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7687.org.readthedocs.build/en/7687/

@lucianopaz lucianopaz added enhancements trace-backend Traces and ArviZ stuff major Include in major changes release notes section labels Feb 21, 2025
@lucianopaz lucianopaz changed the title Zarr continue Allow for pymc native samplers to resume sampling from ZarrTrace Feb 21, 2025
vars=trace_vars,
test_point=initial_point,
)
except TraceAlreadyInitialized:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just InitializedTrace? Seems a little verbose!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds fine to me, it's an internal thing

Comment on lines 1161 to 1169
if isinstance(trace, ZarrChain):
progress_manager.set_initial_state(*trace.completed_draws_and_divergences())
progress_manager._progress.update(
progress_manager.tasks[i],
draws=progress_manager.completed_draws
if progress_manager.combined_progress
else progress_manager.draws,
divergences=progress_manager.divergences,
refresh=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't like this abstraction leaking elsewhere, just provide a default to the Ndarray backend that makes it work for either method. In that case I suppose start everything at zero

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to improve the abstraction to prevent most of the custom ZarrChain checks

if isinstance(trace, ZarrChain):
trace.link_stepper(step)
stored_draw_idx = trace._sampling_state.draw_idx[chain]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here all this logic including the old link_stepper can have a sensible default in the base trace class so you don't need to worry about what kind of trace you have here. Just make link_stepper a no op and stored_draw_idx to be zero by default?

Comment on lines +201 to +211
if stored_draw_idx > 0:
if stored_sampling_state is not None:
self._step_method.sampling_state = stored_sampling_state
else:
raise RuntimeError(
"Cannot use the supplied ZarrTrace to restart sampling because "
"it has no sampling_state information stored. You will have to "
"resample from scratch."
)
draw = stored_draw_idx
self._write_point(trace.get_mcmc_point())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicated logic, should be a property of the backend object?

@@ -491,6 +509,10 @@ def __init__(
progressbar=progressbar,
progressbar_theme=progressbar_theme,
)
if self.zarr_recording:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

abstraction leaking

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the new functionality, I am deeply against all the if isinstance(..., ZarrTrace) in the codebase. Either our code is supposed to allow different trace backends or it is not, this suggests you want to drop the Ndarray altogether, which fine if you do.

Otherwise all these cases seem like they could be handled by the BaseTrace having sensible default for these methods. We used to have continuation of traces in the past with Ndarray, I don't see anything that fundamentally needs ZarrTrace other than dev interest in it? So just make it raise NotImplementedErrors or make them no-ops and adjust the external code appropriately

I stopped half-way so it was not an extensive review. I think this is a bigger design point that needs decision before settling on the details of the PR.

@lucianopaz lucianopaz marked this pull request as ready for review May 14, 2025 12:47
Copy link

codecov bot commented May 14, 2025

Codecov Report

❌ Patch coverage is 91.61426% with 40 lines in your changes missing coverage. Please review.
✅ Project coverage is 92.92%. Comparing base (0960323) to head (440ca46).

Files with missing lines Patch % Lines
pymc/sampling/parallel.py 50.00% 16 Missing ⚠️
pymc/backends/zarr.py 96.64% 9 Missing ⚠️
pymc/step_methods/state.py 88.46% 6 Missing ⚠️
pymc/backends/mcbackend.py 81.81% 2 Missing ⚠️
pymc/sampling/population.py 91.66% 2 Missing ⚠️
pymc/step_methods/metropolis.py 66.66% 2 Missing ⚠️
pymc/backends/base.py 94.11% 1 Missing ⚠️
pymc/backends/ndarray.py 90.00% 1 Missing ⚠️
pymc/sampling/mcmc.py 95.65% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7687      +/-   ##
==========================================
- Coverage   92.94%   92.92%   -0.03%     
==========================================
  Files         116      116              
  Lines       18851    19183     +332     
==========================================
+ Hits        17521    17825     +304     
- Misses       1330     1358      +28     
Files with missing lines Coverage Δ
pymc/backends/__init__.py 93.75% <100.00%> (+1.06%) ⬆️
pymc/progress_bar.py 93.63% <100.00%> (+0.25%) ⬆️
pymc/step_methods/compound.py 97.88% <100.00%> (+<0.01%) ⬆️
pymc/step_methods/hmc/base_hmc.py 92.30% <100.00%> (+0.05%) ⬆️
pymc/step_methods/hmc/quadpotential.py 84.69% <100.00%> (ø)
pymc/step_methods/step_sizes.py 80.95% <100.00%> (ø)
pymc/backends/base.py 89.06% <94.11%> (+0.37%) ⬆️
pymc/backends/ndarray.py 80.83% <90.00%> (+0.83%) ⬆️
pymc/sampling/mcmc.py 91.47% <95.65%> (+0.09%) ⬆️
pymc/backends/mcbackend.py 97.93% <81.81%> (-1.33%) ⬇️
... and 5 more
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@@ -201,6 +201,42 @@ def _slice(self, idx: slice) -> "IBaseTrace":
def point(self, idx: int) -> dict[str, np.ndarray]:
return self._chain.get_draws_at(idx, [v.name for v in self._chain.variables.values()])

def completed_draws_and_divergences(self, chain_specific: bool = True) -> tuple[int, int]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question why do we have special handling for divergences, instead of being sampler agnostic?

Comment on lines +1426 to +1423
# We only need to pass the traces when zarr_recording is happening because
# it's the only backend that can resume sampling
traces=traces if zarr_recording else None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Someone could implement a different trace outside of PyMC so why assume this?

@@ -194,6 +195,24 @@ def _start_loop(self):

draw = 0
tuning = True
if self._zarr_recording:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a method of the trace, still zarr specific

traces_send: Sequence[IBaseTrace] | bytes | None = None
if traces_pickled is not None:
traces_send = traces_pickled
elif traces is not None:
if mp_ctx.get_start_method() == "spawn":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this will be the default soon in linux as well (fork is being deprecated). In case we are not automating something we should be automating

return compatible_dataclass_values(self, other)


def resolve_typehint(hint: Any, anchor: object = None) -> Any:
Copy link
Member

@ricardoV94 ricardoV94 Jul 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These new functions seems like a big code smell. Can we approach it differently?

with pytest.raises(
AssertionError, match="The supplied state is incompatible with the current sampling state."
):
b1.sampling_state = b5.sampling_state
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, I still dislike magic property. A method call is much more obvious that some checks may be done, and you're not just doing a random assignment. It also gives you a hatch to disable them if you're confident it's correct and want to avoid costly checks, because you're allowed to have kwargs then

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And you can't refactor properties into methods with back-compat. We had stuff with model.logp and it was a pita for saving two parenthesis back in the day

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes look okay, I still have some reservations on code modularity and complexity. I don't like the extra work we're doing to validate states are compatible. If it's that hard just don't do it.

Some places still treat Zarr specially without much reason for it.

I'm okay with approving this next iteration / if you disagree, so this doesn't hang for ever

@@ -7,4 +7,5 @@ pytensor>=2.30.2,<2.31
rich>=13.7.1
scipy>=1.4.1
threadpoolctl>=3.1.0,<4.0.0
typeguard>=4.4.2,<5.0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

needs to be added to conda-envs as well

@schlich
Copy link

schlich commented Aug 4, 2025

Hello friends, I will be working on closing out this issue/PR today until it is done. I have questions for y'alls preferred review workflow:

  1. Should I open a PR against this branch? Or a separate one against main? Or other?
  2. I'm going to start with refactoring these tests to be a little bit more modular and structured. If i push these refactorings to the PR branch as I go, would you like to review them as they are pushed? Or would you like me to save all pushes until i'm done with the refactor and/or rework?f
  3. Assuming it's in scope for this PR, i'd like to complete the zarr v3 upgrade as it is completely blocking me from this feature for my setup. I plan to also incorporate appropriate error handling etc to accommodate v2, which should not be a problem since there are clear corollaries between the Store objects in each version for our purposes.

In the meantime I can push my changes to my own fork regardless, will share a link after first commit. Feel free to examine if needed.

Thanks in advance!

@lucianopaz
Copy link
Member Author

Thanks @schlich! I think that maybe I can rebase to fix the conflicts, merge the PR and you can take it from there with a fresh new PR on top of main. I don't think that there was much left to do to get this merged and I wont be able to make significant improvements before the end of August. zarr v3 support would be a great addition. Be sure to mention #7752 when you open your PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements feature request major Include in major changes release notes section request discussion samplers trace-backend Traces and ArviZ stuff
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ENH: Add checkpoints during sampling
5 participants