Skip to content

Commit

Permalink
Add ZarrTrace
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Oct 16, 2024
1 parent 5352798 commit 7dcce58
Show file tree
Hide file tree
Showing 4 changed files with 463 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ jobs:
- |
tests/backends/test_mcbackend.py
tests/backends/test_zarr.py
tests/distributions/test_truncated.py
tests/logprob/test_abstract.py
tests/logprob/test_basic.py
Expand Down Expand Up @@ -284,6 +285,7 @@ jobs:
- |
tests/backends/test_arviz.py
tests/backends/test_zarr.py
tests/variational/test_updates.py
fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down
301 changes: 301 additions & 0 deletions pymc/backends/zarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,301 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any

import numcodecs
import numpy as np
import zarr

from pytensor.tensor.variable import TensorVariable
from zarr.storage import BaseStore
from zarr.sync import Synchronizer

from pymc.backends.arviz import (
coords_and_dims_for_inferencedata,
find_constants,
find_observations,
)
from pymc.backends.base import BaseTrace
from pymc.model.core import Model, modelcontext
from pymc.step_methods.compound import (
BlockedStep,
CompoundStep,
StatsBijection,
get_stats_dtypes_shapes_from_steps,
)
from pymc.util import get_default_varnames


class ZarrChain(BaseTrace):
    """Single-chain writer that records draws and sampler stats into zarr groups.

    The constructor opens the ``posterior``, ``sample_stats`` and
    ``_sampling_state`` groups of ``store`` in append mode. The arrays inside
    those groups are indexed by name in ``record`` and ``record_sampling_state``
    and are therefore expected to exist already when those methods are called.
    """

    def __init__(
        self,
        store: BaseStore | MutableMapping,
        stats_bijection: StatsBijection,
        synchronizer: Synchronizer | None = None,
        model: Model | None = None,
        vars: Sequence[TensorVariable] | None = None,
        test_point: Sequence[dict[str, np.ndarray]] | None = None,
    ):
        """Open the zarr groups this chain writes to.

        Parameters
        ----------
        store
            Zarr store (or mutable mapping) holding the trace groups.
        stats_bijection
            Object whose ``map`` method turns the per-step stats sequence into
            a single ``{name: value}`` mapping (used in ``record``).
        synchronizer
            Optional zarr synchronizer forwarded to every ``zarr.open_group`` call.
        model, vars, test_point
            Forwarded unchanged to ``BaseTrace.__init__``.
        """
        super().__init__(name="zarr", model=model, vars=vars, test_point=test_point)
        # Position along the draw axis where the next call to ``record`` writes.
        self.draw_idx = 0

        def _open(path):
            # Append mode: the group is opened if present and created otherwise.
            return zarr.open_group(store, synchronizer=synchronizer, path=path, mode="a")

        self._posterior = _open("posterior")
        self._sample_stats = _open("sample_stats")
        self._sampling_state = _open("_sampling_state")
        self.stats_bijection = stats_bijection

    def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None):  # type: ignore[override]
        """Remember which chain index this object writes to.

        ``draws`` and ``sampler_vars`` are accepted but not used here.
        """
        self.chain = chain

    def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
        """Write one draw and its sampler stats at ``(chain, draw_idx)``.

        ``self.fn`` and ``self.varnames`` are not defined in this class
        (presumably set up by ``BaseTrace`` -- confirm there); ``self.fn(draw)``
        is expected to yield one value per entry of ``self.varnames``.
        After writing, ``draw_idx`` is advanced by one.
        """
        # Both groups are written at the same (chain, draw) position.
        where = (self.chain, self.draw_idx)

        traced_values = self.fn(draw)
        for name, value in zip(self.varnames, traced_values):
            self._posterior[name].set_orthogonal_selection(where, value)

        flat_stats = self.stats_bijection.map(stats)
        for name, value in flat_stats.items():
            self._sample_stats[name].set_orthogonal_selection(where, value)

        self.draw_idx += 1

    def record_sampling_state(self, step):
        """Store ``step.sampling_state`` and the current ``draw_idx`` for this chain."""
        state_group = self._sampling_state
        # Wrap the state in a length-1 object array so it is stored as a single
        # element of the object-typed ``sampling_state`` array.
        boxed_state = np.array([step.sampling_state], dtype="object")
        state_group.sampling_state.set_coordinate_selection(self.chain, boxed_state)
        state_group.draw_idx.set_coordinate_selection(self.chain, self.draw_idx)


# Union of the scalar types that ``get_fill_value_and_codec`` declares for the
# fill-value element of its return tuple; ``None`` is one of the members.
FILL_VALUE_TYPE = float | int | bool | str | np.datetime64 | np.timedelta64 | None


def get_fill_value_and_codec(
    dtype: Any,
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
    """Pick a fill value and object codec for a zarr array of the given dtype.

    Parameters
    ----------
    dtype
        Anything accepted by ``np.dtype``.

    Returns
    -------
    fill_value
        ``nan`` for floating dtypes, ``-1_000_000`` for integer dtypes,
        ``False`` for booleans, ``""`` for strings, a zero-valued
        ``datetime64``/``timedelta64`` for the time dtypes and ``None`` for
        everything else.
    dtype
        The normalised ``np.dtype`` instance.
    object_codec
        ``numcodecs.Pickle()`` for dtypes not matched by any of the cases
        above, otherwise ``None``.
    """
    _dtype = np.dtype(dtype)
    if np.issubdtype(_dtype, np.floating):
        return (np.nan, _dtype, None)
    if np.issubdtype(_dtype, "timedelta64"):
        # ``np.timedelta64`` is a subclass of ``np.signedinteger``, so this
        # check has to run before the generic integer check below; otherwise
        # it is never reached and timedeltas get the integer fill value.
        return (np.timedelta64(0, "Y"), _dtype, None)
    if np.issubdtype(_dtype, np.integer):
        return (-1_000_000, _dtype, None)
    if np.issubdtype(_dtype, "bool"):
        return (False, _dtype, None)
    if np.issubdtype(_dtype, "str"):
        return ("", _dtype, None)
    if np.issubdtype(_dtype, "datetime64"):
        return (np.datetime64(0, "Y"), _dtype, None)
    # Anything else falls back to being stored through pickle.
    return (None, _dtype, numcodecs.Pickle())


class ZarrTrace:
    """Container that lays out sampling results as groups inside a zarr store.

    Construction creates the root group (overwriting what is in ``store``),
    compiles a function that evaluates the traced variables and records their
    dtypes and shapes at the model's initial point. ``init_trace`` then creates
    the ``constant_data``, ``observed_data``, ``posterior``, ``sample_stats``
    and ``_sampling_state`` groups and one ``ZarrChain`` per chain.
    """

    def __init__(
        self,
        store: BaseStore | MutableMapping | None = None,
        synchronizer: Synchronizer | None = None,
        model: Model | None = None,
        vars: Sequence[TensorVariable] | None = None,
        include_transformed: bool = False,
    ):
        """Create the root group and work out which variables are traced.

        Parameters
        ----------
        store
            Store passed to ``zarr.group``; existing content is overwritten.
        synchronizer
            Optional zarr synchronizer, also handed to every ``ZarrChain``.
        model
            Model to trace; resolved through ``modelcontext`` when ``None``.
        vars
            Variables to trace. Defaults to ``model.unobserved_value_vars``.
        include_transformed
            Forwarded to ``get_default_varnames`` when selecting variable names.

        Raises
        ------
        Exception
            If any of the variables has no name.
        """
        model = modelcontext(model)
        self.model = model

        self.synchronizer = synchronizer
        self.root = zarr.group(
            store=store,
            overwrite=True,
            synchronizer=synchronizer,
        )
        self.coords, self.vars_to_dims = coords_and_dims_for_inferencedata(model)

        if vars is None:
            vars = model.unobserved_value_vars

        unnamed_vars = {var for var in vars if var.name is None}
        if unnamed_vars:
            raise Exception(f"Can't trace unnamed variables: {unnamed_vars}")
        self.varnames = get_default_varnames(
            [var.name for var in vars], include_transformed=include_transformed
        )
        self.vars = [var for var in vars if var.name in self.varnames]

        self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore")

        # Get variable shapes. Most backends will need this
        # information.
        test_point = model.initial_point()
        var_values = list(zip(self.varnames, self.fn(test_point)))
        self.var_dtype_shapes = {var: (value.dtype, value.shape) for var, value in var_values}
        self._is_base_setup = False

    @property
    def posterior(self):
        """The ``posterior`` group of the root group."""
        return self.root.posterior

    @property
    def sample_stats(self):
        """The ``sample_stats`` group of the root group."""
        return self.root.sample_stats

    @property
    def constant_data(self):
        """The ``constant_data`` group of the root group."""
        return self.root.constant_data

    @property
    def observed_data(self):
        """The ``observed_data`` group of the root group."""
        return self.root.observed_data

    @property
    def _sampling_state(self):
        """The ``_sampling_state`` group of the root group."""
        return self.root._sampling_state

    def init_trace(self, chains: int, draws: int, step: BlockedStep | CompoundStep):
        """Create all groups and the per-chain writers for a sampling run.

        Parameters
        ----------
        chains
            Number of chains; size of the leading axis of every sampled array.
        draws
            Number of draws per chain; size of the second axis.
        step
            Step method. Its stats dtypes and shapes define the
            ``sample_stats`` arrays, and ``step.stats_dtypes`` is used to build
            the ``StatsBijection`` given to each ``ZarrChain``.
        """
        self.create_group(
            name="constant_data",
            data_dict=find_constants(self.model),
        )

        self.create_group(
            name="observed_data",
            data_dict=find_observations(self.model),
        )

        self.init_group_with_empty(
            group=self.root.create_group(name="posterior", overwrite=True),
            var_dtype_and_shape=self.var_dtype_shapes,
            chains=chains,
            draws=draws,
        )
        stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
            [step] if isinstance(step, BlockedStep) else step.methods
        )
        self.init_group_with_empty(
            group=self.root.create_group(name="sample_stats", overwrite=True),
            var_dtype_and_shape=stats_dtypes_shapes,
            chains=chains,
            draws=draws,
        )

        self.init_sampling_state_group(chains=chains)

        # One writer per chain: the arrays above are allocated with a leading
        # axis of length ``chains``, so every chain index needs its own
        # ``ZarrChain`` to be set up below.
        self.straces = [
            ZarrChain(
                store=self.root.store,
                synchronizer=self.synchronizer,
                model=self.model,
                vars=self.vars,
                test_point=None,
                stats_bijection=StatsBijection(step.stats_dtypes),
            )
            for _ in range(chains)
        ]
        for chain, strace in enumerate(self.straces):
            strace.setup(draws=draws, chain=chain, sampler_vars=None)

    def consolidate(self):
        """Consolidate the store's metadata and rebind ``self.root`` to the result."""
        self.root = zarr.consolidate_metadata(self.root.store)

    def close(self):
        """Consolidate metadata, then close the underlying store."""
        self.consolidate()
        self.root.store.close()

    def init_sampling_state_group(self, chains):
        """Create the ``_sampling_state`` group with its per-chain arrays.

        Three arrays with a single ``chain`` dimension are created:
        ``sampling_state`` (object dtype, stored with ``numcodecs.Pickle``),
        ``draw_idx`` (integers initialised to zero) and the ``chain`` coordinate.
        """
        state = self.root.create_group(name="_sampling_state", overwrite=True)
        sampling_state = state.empty(
            name="sampling_state",
            overwrite=True,
            shape=(chains,),
            chunks=(1,),
            dtype="object",
            object_codec=numcodecs.Pickle(),
        )
        sampling_state.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
        draw_idx = state.array(
            name="draw_idx",
            overwrite=True,
            data=np.zeros(chains, dtype="int"),
            chunks=(1,),
            dtype="int",
            fill_value=-1,
        )
        draw_idx.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})
        chain = state.array(name="chain", data=range(chains))
        chain.attrs.update({"_ARRAY_DIMENSIONS": ["chain"]})

    def init_group_with_empty(self, group, var_dtype_and_shape, chains, draws):
        """Fill ``group`` with ``(chains, draws, *shape)`` arrays and coordinates.

        Parameters
        ----------
        group
            Zarr group to populate.
        var_dtype_and_shape
            Mapping from variable name to a ``(dtype, shape)`` pair.
        chains, draws
            Sizes of the two leading axes of every array.

        Returns
        -------
        The populated ``group``.
        """
        group_coords = {"chain": range(chains), "draw": range(draws)}
        for name, (dtype, shape) in var_dtype_and_shape.items():
            fill_value, dtype, object_codec = get_fill_value_and_codec(dtype)
            shape = shape or ()
            # One chunk per (chain, draw) pair, holding the whole variable value.
            array = group.full(
                name=name,
                dtype=dtype,
                fill_value=fill_value,
                object_codec=object_codec,
                shape=(chains, draws, *shape),
                chunks=(1, 1, *shape),
            )
            try:
                dims = self.vars_to_dims[name]
                for dim in dims:
                    group_coords[dim] = self.coords[dim]
            except KeyError:
                # No named dims (or no coordinate values for one of them):
                # fall back to generated ``<name>_dim_<i>`` dims with integer
                # range coordinates.
                dims = []
                for i, shape_i in enumerate(shape):
                    dim = f"{name}_dim_{i}"
                    dims.append(dim)
                    group_coords[dim] = list(range(shape_i))
            dims = ("chain", "draw", *dims)
            array.attrs.update({"_ARRAY_DIMENSIONS": dims})
        # Coordinates collected above are stored as 1-d arrays named after
        # their dimension.
        for dim, coord in group_coords.items():
            array = group.array(name=dim, data=coord, fill_value=None)
            array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
        return group

    def create_group(self, name, data_dict):
        """Create group ``name`` under the root and store ``data_dict`` in it.

        Parameters
        ----------
        name
            Name of the group to create (an existing one is overwritten).
        data_dict
            Mapping from variable name to array value. When it is empty no
            group is created.

        Returns
        -------
        The new group, or ``None`` when ``data_dict`` is empty.
        """
        if data_dict:
            group_coords = {}
            group = self.root.create_group(name=name, overwrite=True)
            for var_name, var_value in data_dict.items():
                fill_value, dtype, object_codec = get_fill_value_and_codec(var_value.dtype)
                array = group.array(
                    name=var_name,
                    data=var_value,
                    fill_value=fill_value,
                    dtype=dtype,
                    object_codec=object_codec,
                )
                try:
                    dims = self.vars_to_dims[var_name]
                    for dim in dims:
                        group_coords[dim] = self.coords[dim]
                except KeyError:
                    # Same fallback as in ``init_group_with_empty``: generated
                    # dim names with integer range coordinates.
                    dims = []
                    for i in range(var_value.ndim):
                        dim = f"{var_name}_dim_{i}"
                        dims.append(dim)
                        group_coords[dim] = list(range(var_value.shape[i]))
                array.attrs.update({"_ARRAY_DIMENSIONS": dims})
            for dim, coord in group_coords.items():
                array = group.array(name=dim, data=coord, fill_value=None)
                array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
            return group
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ rich>=13.7.1
scipy>=1.4.1
threadpoolctl>=3.1.0,<4.0.0
typing-extensions>=3.7.4
zarr>=2.5.0,<3
Loading

0 comments on commit 7dcce58

Please sign in to comment.