infer named parameter layout from dict-type initial_state #463

Open · wants to merge 2 commits into main
157 changes: 88 additions & 69 deletions src/emcee/ensemble.py
@@ -2,7 +2,7 @@

import warnings
from itertools import count
from typing import Dict, List, Optional, Union
from typing import Dict, Iterable, List, Optional, Sequence, Union

import numpy as np

@@ -15,11 +15,45 @@

__all__ = ["EnsembleSampler", "walkers_independent"]

try:
    from collections.abc import Iterable
except ImportError:
    # for py2.7, will be an Exception in 3.8
    from collections import Iterable
ParameterNamesT = Union[
    Sequence[str], Dict[str, Union[slice, int, Sequence[int]]]
]


def infer_dict_mapping(state):
    i0 = 0
    param_slice_shape = {}
    for key, val in state.items():
        val = np.asarray(val)
        i1 = i0 + val.size
        slc = slice(i0, i1) if val.size > 1 else i0
        param_slice_shape[key] = slc, val.shape
        i0 = i1

    return param_slice_shape


def array_to_dict(ary, param_slice_shape):
    return {
        key: ary[:, slc].reshape((-1,) + shape)
        for key, (slc, shape) in param_slice_shape.items()
    }


def array_to_list_of_dicts(ary, param_slice_shape):
    # reshape adds a small amount of overhead; don't do it unless necessary
    return [
        {
            key: ary_i[slc].reshape(shape) if len(shape) > 1 else ary_i[slc]
            for key, (slc, shape) in param_slice_shape.items()
        }
        for ary_i in ary
    ]


def collapse_and_hstack(values, nwalkers=None):
    shape = (nwalkers, -1) if nwalkers is not None else -1
    return np.hstack([np.asarray(val).reshape(shape) for val in values])
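
Roughly how the first two helpers behave for a hypothetical parameter set, a scalar ``mu`` plus a 2x2 ``cov`` (the names and values here are made up, not from the diff):

import numpy as np

walker = {"mu": 0.5, "cov": np.zeros((2, 2))}  # one walker's parameters

infer_dict_mapping(walker)
# -> {"mu": (0, ()), "cov": (slice(1, 5), (2, 2))}, i.e. (index-or-slice, shape) per key

collapse_and_hstack(walker.values())  # flat length-5 vector for a single walker
collapse_and_hstack([np.zeros((8, 1)), np.zeros((8, 4))], nwalkers=8)  # stacks to an (8, 5) ensemble array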


class EnsembleSampler(object):
@@ -62,7 +96,8 @@ class EnsembleSampler(object):
            to accept a list of position vectors instead of just one. Note
            that ``pool`` will be ignored if this is ``True``.
            (default: ``False``)
        parameter_names (Optional[Union[List[str], Dict[str, List[int]]]]):
        parameter_names (Union[Sequence[str],
            Dict[str, Union[slice, int, Sequence[int]]]]):
            names of individual parameters or groups of parameters. If
            specified, the ``log_prob_fn`` will receive a dictionary of
            parameters, rather than a ``np.ndarray``.
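
As a sketch of the dict form (the ``log_density`` function, the names, and the grouping are invented for illustration; with ``ndim=5``, ``"mu"`` maps to index 0 and ``"cov"`` to the remaining four entries, so ``log_density`` receives a scalar and a length-4 array):

import numpy as np
from emcee import EnsembleSampler

def log_density(params):
    # params is a dict: {"mu": <scalar>, "cov": <array of the 4 grouped entries>}
    return -0.5 * params["mu"] ** 2 - 0.5 * np.sum(params["cov"] ** 2)

sampler = EnsembleSampler(
    nwalkers=32,
    ndim=5,
    log_prob_fn=log_density,
    parameter_names={"mu": 0, "cov": slice(1, 5)},
)
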
@@ -81,7 +116,7 @@ def __init__(
        backend=None,
        vectorize=False,
        blobs_dtype=None,
        parameter_names: Optional[Union[Dict[str, int], List[str]]] = None,
        parameter_names: Optional[ParameterNamesT] = None,
        # Deprecated...
        a=None,
        postargs=None,
@@ -163,48 +198,39 @@ def __init__(
        # ``args`` and ``kwargs`` pickleable.
        self.log_prob_fn = _FunctionWrapper(log_prob_fn, args, kwargs)

        # Save the parameter names
        self.params_are_named: bool = parameter_names is not None
        if self.params_are_named:
            assert isinstance(parameter_names, (list, dict))

            # Don't support vectorizing yet
            msg = "named parameters with vectorization unsupported for now"
            assert not self.vectorize, msg

            # Check for duplicate names
            dupes = set()
            uniq = []
            for name in parameter_names:
                if name not in dupes:
                    uniq.append(name)
                    dupes.add(name)
            msg = f"duplicate paramters: {dupes}"
            assert len(uniq) == len(parameter_names), msg

            if isinstance(parameter_names, list):
                # Check for all named
                msg = "name all parameters or set `parameter_names` to `None`"
                assert len(parameter_names) == ndim, msg
                # Convert a list to a dict
                parameter_names: Dict[str, int] = {
                    name: i for i, name in enumerate(parameter_names)
        if parameter_names is not None:
            if isinstance(parameter_names, Sequence):
                if len(parameter_names) != ndim:
                    raise ValueError(
                        f"`parameter_names` does not specify {ndim} names"
                    )
                parameter_names = dict(zip(parameter_names, range(ndim)))

            indices = np.arange(ndim)

            try:
                index_map = {
                    key: indices[slc] for key, slc in parameter_names.items()
                }
                indexed = collapse_and_hstack(index_map.values())
            except IndexError as err:
                msg = "`parameter_names` specifies out-of-bounds element(s)"
                raise ValueError(msg) from err

            if len(indexed) != ndim:
                raise ValueError(
                    "`parameter_names` does not specify indices for"
                    f" {ndim} parameters"
                )
            if set(indexed) != set(indices):
                raise ValueError(
                    "`parameter_names` does not specify indices"
                    f" 0 through {ndim-1}"
                )

            # Check not too many names
            msg = "too many names"
            assert len(parameter_names) <= ndim, msg

            # Check all indices appear
            values = [
                v if isinstance(v, list) else [v]
                for v in parameter_names.values()
            ]
            values = [item for sublist in values for item in sublist]
            values = set(values)
            msg = f"not all values appear -- set should be 0 to {ndim-1}"
            assert values == set(np.arange(ndim)), msg
            self.parameter_names = parameter_names
            self.param_slice_shape = infer_dict_mapping(index_map)
        else:
            self.param_slice_shape = None
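
To make the checks above concrete, with a hypothetical ``ndim = 4`` the following ``parameter_names`` layouts would pass or fail validation:

ok_names = ["a", "b", "c", "d"]                    # sequence form: one name per index
ok_groups = {"a": 0, "bc": slice(1, 3), "d": [3]}  # dict form: ints, slices, or index lists
bad_overlap = {"a": [0, 1], "b": [1, 2, 3]}        # index 1 is claimed twice -> ValueError
bad_bounds = {"a": slice(0, 3), "b": 4}            # index 4 is out of bounds -> ValueError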

    @property
    def random_state(self):
@@ -266,7 +292,8 @@ def sample(
"""Advance the chain as a generator

Args:
initial_state (State or ndarray[nwalkers, ndim]): The initial
initial_state (State or ndarray[nwalkers, ndim] or
dict[str, float | np.ndarray[nwalkers. ...]]): The initial
:class:`State` or positions of the walkers in the
parameter space.
iterations (Optional[int or NoneType]): The number of steps to generate.
@@ -302,6 +329,13 @@
        if iterations is None and store:
            raise ValueError("'store' must be False when 'iterations' is None")
        # Interpret the input as a walker state and check the dimensions.
        if isinstance(initial_state, dict):
            _state = {key: val[0] for key, val in initial_state.items()}
            self.param_slice_shape = infer_dict_mapping(_state)
            initial_state = collapse_and_hstack(
                initial_state.values(), self.nwalkers
            )

        state = State(initial_state, copy=True)
        state_shape = np.shape(state.coords)
        if state_shape != (self.nwalkers, self.ndim):
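
A sketch of the dict-shaped ``initial_state`` this block accepts (the names, shapes, and sampler setup are hypothetical; ``ndim`` would be 1 + 2*2 = 5, with the layout inferred from the dict itself):

import numpy as np

nwalkers = 32
initial = {
    "mu": np.random.randn(nwalkers),         # scalar parameter: one value per walker
    "cov": np.random.randn(nwalkers, 2, 2),  # array parameter, flattened to 4 columns internally
}
# Even without `parameter_names`, log_prob_fn then receives
# {"mu": <scalar>, "cov": <(2, 2) array>} for each walker.
for state in sampler.sample(initial, iterations=100):
    pass
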
@@ -472,8 +506,11 @@ def compute_log_prob(self, coords):
raise ValueError("At least one parameter value was NaN")

# If the parmaeters are named, then switch to dictionaries
if self.params_are_named:
p = ndarray_to_list_of_dicts(p, self.parameter_names)
if self.param_slice_shape:
if self.vectorize:
p = array_to_dict(p, self.param_slice_shape)
else:
p = array_to_list_of_dicts(p, self.param_slice_shape)

# Run the log-probability calculations (optionally in parallel).
if self.vectorize:
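
Concretely, for a hypothetical layout with a scalar ``mu``, four ``cov`` entries, and 8 walkers, the two branches produce:

import numpy as np

layout = {"mu": (0, ()), "cov": (slice(1, 5), (4,))}
coords = np.zeros((8, 5))               # (nwalkers, ndim)

array_to_list_of_dicts(coords, layout)  # non-vectorized path: a list of 8 dicts,
                                        #   each {"mu": <scalar>, "cov": <(4,) array>}
array_to_dict(coords, layout)           # vectorized path: one dict of stacked arrays,
                                        #   {"mu": shape (8,), "cov": shape (8, 4)}
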
@@ -664,21 +701,3 @@ def _scaled_cond(a):
        return np.inf
    c = b / bsum
    return np.linalg.cond(c.astype(float))


def ndarray_to_list_of_dicts(
    x: np.ndarray, key_map: Dict[str, Union[int, List[int]]]
) -> List[Dict[str, Union[np.number, np.ndarray]]]:
    """
    A helper function to convert a ``np.ndarray`` into a list
    of dictionaries of parameters. Used when parameters are named.

    Args:
        x (np.ndarray): parameter array of shape ``(N, n_dim)``, where
            ``N`` is an integer
        key_map (Dict[str, Union[int, List[int]]):

    Returns:
        list of dictionaries of parameters
    """
    return [{key: xi[val] for key, val in key_map.items()} for xi in x]