Skip to content

Commit

Permalink
Write sampling state periodically
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Nov 7, 2024
1 parent 70f1cd2 commit 48504d5
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 13 deletions.
30 changes: 28 additions & 2 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Any,
Literal,
TypeAlias,
cast,
overload,
)

Expand All @@ -40,6 +41,7 @@
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol
from zarr.storage import MemoryStore

import pymc as pm

Expand All @@ -50,7 +52,7 @@
find_observations,
)
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
from pymc.backends.zarr import ZarrTrace
from pymc.backends.zarr import ZarrChain, ZarrTrace
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
Expand Down Expand Up @@ -1222,6 +1224,8 @@ def _iter_sample(
step.set_rng(rng)

point = start
if isinstance(trace, ZarrChain):
trace.link_stepper(step)

try:
step.tune = bool(tune)
Expand All @@ -1246,12 +1250,18 @@ def _iter_sample(

yield diverging
except KeyboardInterrupt:
if isinstance(trace, ZarrChain):
trace.record_sampling_state(step=step)
trace.close()
raise
except BaseException:
if isinstance(trace, ZarrChain):
trace.record_sampling_state(step=step)
trace.close()
raise
else:
if isinstance(trace, ZarrChain):
trace.record_sampling_state(step=step)
trace.close()


Expand Down Expand Up @@ -1310,6 +1320,19 @@ def _mp_sample(

# We did draws += tune in pm.sample
draws -= tune
zarr_chains: list[ZarrChain] | None = None
zarr_recording = False
if all(isinstance(trace, ZarrChain) for trace in traces):
if isinstance(cast(ZarrChain, traces[0])._posterior.store, MemoryStore):
warnings.warn(
"Parallel sampling with MemoryStore zarr store wont write the processes "
"step method sampling state. If you wish to be able to access the step "
"method sampling state, please use a different storage backend, e.g. "
"DirectoryStore or ZipStore"
)
else:
zarr_chains = cast(list[ZarrChain], traces)
zarr_recording = True

sampler = ps.ParallelSampler(
draws=draws,
Expand All @@ -1323,13 +1346,16 @@ def _mp_sample(
progressbar_theme=progressbar_theme,
blas_cores=blas_cores,
mp_ctx=mp_ctx,
zarr_chains=zarr_chains,
)
try:
try:
with sampler:
for draw in sampler:
strace = traces[draw.chain]
strace.record(draw.point, draw.stats)
if not zarr_recording:
# Zarr recording happens in each process
strace.record(draw.point, draw.stats)
log_warning_stats(draw.stats)
if draw.is_last:
strace.close()
Expand Down
47 changes: 47 additions & 0 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from collections import namedtuple
from collections.abc import Sequence
from typing import cast

import cloudpickle
import numpy as np
Expand All @@ -31,6 +32,7 @@
from rich.theme import Theme
from threadpoolctl import threadpool_limits

from pymc.backends.zarr import ZarrChain
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.util import CustomProgress, default_progress_theme
Expand Down Expand Up @@ -100,6 +102,9 @@ def __init__(
rng_state,
seed_seq: np.random.SeedSequence,
blas_cores,
chain: int,
zarr_chains: list[ZarrChain] | bytes | None = None,
zarr_chains_is_pickled: bool = False,
):
# For some strange reason, spawn multiprocessing doesn't copy the rng
# seed sequence, so we have to rebuild it from scratch
Expand All @@ -108,6 +113,15 @@ def __init__(
self._msg_pipe = msg_pipe
self._step_method = step_method
self._step_method_is_pickled = step_method_is_pickled
self.chain = chain
self._zarr_recording = False
self._zarr_chain: ZarrChain | None = None
if zarr_chains_is_pickled:
self._zarr_chain = cloudpickle.loads(zarr_chains)[self.chain]
elif zarr_chains is not None:
self._zarr_chain = cast(list[ZarrChain], zarr_chains)[self.chain]
self._zarr_recording = self._zarr_chain is not None

self._shared_point = shared_point
self._rng = rng
self._draws = draws
Expand All @@ -132,6 +146,7 @@ def run(self):
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
self._unpickle_step_method()
self._link_step_to_zarrchain()
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
Expand All @@ -145,6 +160,10 @@ def run(self):
finally:
self._msg_pipe.close()

def _link_step_to_zarrchain(self):
if self._zarr_recording:
self._zarr_chain.link_stepper(self._step_method)

def _wait_for_abortion(self):
while True:
msg = self._recv_msg()
Expand All @@ -167,6 +186,7 @@ def _recv_msg(self):
return self._msg_pipe.recv()

def _start_loop(self):
zarr_recording = self._zarr_recording
self._step_method.set_rng(self._rng)

draw = 0
Expand Down Expand Up @@ -196,6 +216,8 @@ def _start_loop(self):
if msg[0] == "abort":
raise KeyboardInterrupt()
elif msg[0] == "write_next":
if zarr_recording:
self._zarr_chain.record(point, stats)
self._write_point(point)
is_last = draw + 1 == self._draws + self._tune
self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
Expand All @@ -222,6 +244,8 @@ def __init__(
start: dict[str, np.ndarray],
blas_cores,
mp_ctx,
zarr_chains: list[ZarrChain] | None = None,
zarr_chains_pickled: bytes | None = None,
):
self.chain = chain
process_name = f"worker_chain_{chain}"
Expand All @@ -247,6 +271,16 @@ def __init__(
self._readable = True
self._num_samples = 0

zarr_chains_send: list[ZarrChain] | bytes | None = None
if zarr_chains_pickled is not None:
zarr_chains_send = zarr_chains_pickled
elif zarr_chains is not None:
if mp_ctx.get_start_method() == "spawn":
raise ValueError(
"please provide a pre-pickled zarr_chains when multiprocessing start method is 'spawn'"
)
zarr_chains_send = zarr_chains

if step_method_pickled is not None:
step_method_send = step_method_pickled
else:
Expand All @@ -272,6 +306,9 @@ def __init__(
rng.bit_generator.state,
rng.bit_generator.seed_seq,
blas_cores,
self.chain,
zarr_chains_send,
zarr_chains_pickled is not None,
),
)
self._process.start()
Expand Down Expand Up @@ -394,6 +431,7 @@ def __init__(
progressbar_theme: Theme | None = default_progress_theme,
blas_cores: int | None = None,
mp_ctx=None,
zarr_chains: list[ZarrChain] | None = None,
):
if any(len(arg) != chains for arg in [rngs, start_points]):
raise ValueError(f"Number of rngs and start_points must be {chains}.")
Expand All @@ -414,8 +452,15 @@ def __init__(
mp_ctx = multiprocessing.get_context(mp_ctx)

step_method_pickled = None
zarr_chains_pickled = None
self.zarr_recording = False
if zarr_chains is not None:
assert all(isinstance(zarr_chain, ZarrChain) for zarr_chain in zarr_chains)
self.zarr_recording = True
if mp_ctx.get_start_method() != "fork":
step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)
if zarr_chains is not None:
zarr_chains_pickled = cloudpickle.dumps(zarr_chains, protocol=-1)

self._samplers = [
ProcessAdapter(
Expand All @@ -428,6 +473,8 @@ def __init__(
start,
blas_cores,
mp_ctx,
zarr_chains=zarr_chains,
zarr_chains_pickled=zarr_chains_pickled,
)
for chain, rng, start in zip(range(chains), rngs, start_points)
]
Expand Down
90 changes: 79 additions & 11 deletions tests/backends/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import pytest
import xarray as xr
import zarr

from arviz import InferenceData
Expand Down Expand Up @@ -62,9 +63,9 @@ def model():
return model


@pytest.fixture(params=[True, False])
@pytest.fixture(params=["include_transformed", "discard_transformed"])
def include_transformed(request):
return request.param
return request.param == "include_transformed"


@pytest.fixture(params=["frequent_writes", "sparse_writes"])
Expand Down Expand Up @@ -94,7 +95,7 @@ def model_step(request, model):


def test_record(model, model_step, include_transformed, draws_per_chunk):
store = zarr.MemoryStore()
store = zarr.TempStore()
trace = ZarrTrace(
store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
)
Expand Down Expand Up @@ -353,27 +354,31 @@ def test_split_warmup(tune, model, model_step, include_transformed):
assert trace.root["warmup_sample_stats"][var_name].shape[1] == tune


@pytest.fixture(scope="function", params=[True, False])
@pytest.fixture(scope="function", params=["discard_tuning", "keep_tuning"])
def discard_tuned_samples(request):
return request.param
return request.param == "discard_tuning"


@pytest.fixture(scope="function", params=[True, False])
@pytest.fixture(scope="function", params=["return_idata", "return_zarr"])
def return_inferencedata(request):
return request.param
return request.param == "return_idata"


@pytest.fixture(scope="function", params=[True, False])
@pytest.fixture(
scope="function", params=[True, False], ids=["keep_warning_stat", "discard_warning_stat"]
)
def keep_warning_stat(request):
return request.param


@pytest.fixture(scope="function", params=[True, False])
@pytest.fixture(
scope="function", params=[True, False], ids=["parallel_sampling", "sequential_sampling"]
)
def parallel(request):
return request.param


@pytest.fixture(scope="function", params=[True, False])
@pytest.fixture(scope="function", params=[True, False], ids=["compute_loglike", "no_loglike"])
def log_likelihood(request):
return request.param

Expand All @@ -393,7 +398,7 @@ def test_sample(
pytest.skip(
reason="log_likelihood is only computed if an inference data object is returned"
)
store = zarr.MemoryStore()
store = zarr.TempStore()
trace = ZarrTrace(
store=store, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
)
Expand Down Expand Up @@ -460,3 +465,66 @@ def test_sample(
for name, v in out_trace.posterior.arrays()
if name not in dimensions
)

# Assert that the trace has valid sampling state stored for each chain
for step_method_state in trace._sampling_state.sampling_state[:]:
# We have no access to the actual step method that was used by each chain in pymc.sample
# The best way to see if the step method state is valid is by trying to set
# the model_step sampling state to the one stored in the trace.
model_step.sampling_state = step_method_state


def test_sampling_consistency(
    model,
    model_step,
    include_transformed,
    draws_per_chunk,
):
    """Check that parallel and sequential sampling produce the same results.

    ``pm.sample`` is run twice with the same random seed and the same initial
    step method state, once with ``cores=chains`` and once with ``cores=1``,
    each time recording into its own ``ZarrTrace``. The sampling state stored
    for each chain and the returned posterior groups must be equal.

    ``include_transformed`` has to be requested as a fixture argument so that
    a boolean, and not the fixture function object of the same name, is
    forwarded to ``ZarrTrace``.
    """
    store1 = zarr.TempStore()
    parallel_trace = ZarrTrace(
        store=store1, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
    )
    store2 = zarr.TempStore()
    sequential_trace = ZarrTrace(
        store=store2, include_transformed=include_transformed, draws_per_chunk=draws_per_chunk
    )
    chains = 2
    # Everything except ``cores`` and ``trace`` is shared by both runs.
    shared_sample_kwargs = {
        "draws": 3,
        "tune": 2,
        "chains": chains,
        "step": model_step,
        "discard_tuned_samples": True,
        "return_inferencedata": True,
        "keep_warning_stat": False,
        "idata_kwargs": {"log_likelihood": False},
        "random_seed": 12345,
    }
    initial_step_state = model_step.sampling_state
    with model:
        parallel_idata = pm.sample(
            cores=chains,
            trace=parallel_trace,
            **shared_sample_kwargs,
        )
        # Restore the step method state so the second run starts from the
        # same state as the first one.
        model_step.sampling_state = initial_step_state
        sequential_idata = pm.sample(
            cores=1,
            trace=sequential_trace,
            **shared_sample_kwargs,
        )
    for chain in range(chains):
        assert equal_sampling_states(
            parallel_trace._sampling_state.sampling_state[chain],
            sequential_trace._sampling_state.sampling_state[chain],
        )
    xr.testing.assert_equal(parallel_idata.posterior, sequential_idata.posterior)

0 comments on commit 48504d5

Please sign in to comment.