WIP: use Generator instead of RandomState #528

Open · wants to merge 4 commits into base: main
Changes from 1 commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 28, 2024
commit 6cafae1b736d81f1386d52a4158c62dc8425da0a
7 changes: 3 additions & 4 deletions src/emcee/backends/hdf.py
@@ -2,9 +2,9 @@

from __future__ import division, print_function

+import json
import os
from tempfile import NamedTemporaryFile
-import json

import numpy as np

@@ -202,7 +202,7 @@ def accepted(self):
    def random_state(self):
        with self.open() as f:
            try:
-                dct = json.loads(f[self.name].attrs['random_state'])
+                dct = json.loads(f[self.name].attrs["random_state"])
            except KeyError:
                return None
        return dct
@@ -269,8 +269,7 @@ def save_step(self, state, accepted):
            g["accepted"][:] += accepted

            g.attrs["random_state"] = json.dumps(
-                state.random_state,
-                cls=NumpyEncoder
+                state.random_state, cls=NumpyEncoder
            )

            g.attrs["iteration"] = iteration + 1
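Note on the serialization above: the value stored in the HDF attribute is a nested dict that, depending on the bit generator, can contain NumPy scalars or arrays, so a NumPy-aware JSON encoder is required. The NumpyEncoder referenced in the diff is emcee's own helper; as a rough illustration of what such an encoder has to do (not the project's actual implementation), a minimal version might look like this:

import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    # Convert NumPy scalars and arrays into plain Python objects that the
    # standard json module can serialize.
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)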
12 changes: 8 additions & 4 deletions src/emcee/ensemble.py
@@ -222,16 +222,18 @@ def random_state(self):
        so silently.

        """
+
        def rng_dict(rng):
            bg_state = rng.bit_generator.state
            ss = rng.bit_generator.seed_seq
            ss_dict = dict(
                entropy=ss.entropy,
                spawn_key=ss.spawn_key,
                pool_size=ss.pool_size,
-                n_children_spawned=ss.n_children_spawned
+                n_children_spawned=ss.n_children_spawned,
            )
            return dict(bg_state=bg_state, seed_seq=ss_dict)
+
        return rng_dict(self._random)
        # return self._random.bit_generator.state

@@ -242,13 +244,15 @@ def random_state(self, state):
        if it doesn't work. Don't say I didn't warn you...

        """
+
        def _rng_fromdict(d):
-            bg_state = d['bg_state']
-            ss = np.random.SeedSequence(**d['seed_seq'])
-            bg = getattr(np.random, bg_state['bit_generator'])(ss)
+            bg_state = d["bg_state"]
+            ss = np.random.SeedSequence(**d["seed_seq"])
+            bg = getattr(np.random, bg_state["bit_generator"])(ss)
            bg.state = bg_state
            rng = np.random.Generator(bg)
            return rng
+
        try:
            self._random = _rng_fromdict(state)
            # self._random.bit_generator = state
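To see how the getter and setter above fit together, here is a self-contained round-trip sketch (illustrative only, mirroring the names in the diff rather than any public emcee API): dump a Generator's bit-generator state plus its SeedSequence metadata into a plain dict, then rebuild an equivalent Generator from that dict.

import numpy as np

def rng_to_dict(rng):
    # Capture the bit generator's internal state and the SeedSequence
    # metadata needed to recreate it later.
    ss = rng.bit_generator.seed_seq
    return dict(
        bg_state=rng.bit_generator.state,
        seed_seq=dict(
            entropy=ss.entropy,
            spawn_key=ss.spawn_key,
            pool_size=ss.pool_size,
            n_children_spawned=ss.n_children_spawned,
        ),
    )

def rng_from_dict(d):
    # Rebuild the bit generator by name, then restore its exact state.
    bg_state = d["bg_state"]
    ss = np.random.SeedSequence(**d["seed_seq"])
    bg = getattr(np.random, bg_state["bit_generator"])(ss)
    bg.state = bg_state
    return np.random.Generator(bg)

rng1 = np.random.default_rng(42)
rng2 = rng_from_dict(rng_to_dict(rng1))
assert rng1.standard_normal() == rng2.standard_normal()  # same stream resumes

Because some bit generators (e.g. MT19937) keep NumPy arrays in their state dict, serializing this structure to JSON is what motivates the NumPy-aware encoder shown earlier.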
4 changes: 3 additions & 1 deletion src/emcee/moves/de.py
@@ -53,7 +53,9 @@ def get_proposal(self, s, c, random):
        diffs = np.diff(c[pairs], axis=1).squeeze(axis=1)  # (ns, ndim)

        # Sample a gamma value for each walker following Nelson et al. (2013)
-        gamma = self.g0 * (1 + self.sigma * random.standard_normal((ns, 1)))  # (ns, 1)
+        gamma = self.g0 * (
+            1 + self.sigma * random.standard_normal((ns, 1))
+        )  # (ns, 1)

        # In this way, sigma is the standard deviation of the distribution of gamma,
        # instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006).
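For readers unfamiliar with the move being touched here: the differential-evolution proposal perturbs each walker along the difference of two walkers from the complementary ensemble, scaled by the gamma drawn above. A standalone sketch of that step (illustrative shapes and constants, not emcee's exact code):

import numpy as np

rng = np.random.default_rng(0)
ns, ndim = 10, 3
s = rng.standard_normal((ns, ndim))       # walkers being updated
diffs = rng.standard_normal((ns, ndim))   # stand-in for pair differences from the complement
g0, sigma = 2.38 / np.sqrt(2 * ndim), 1.0e-5
gamma = g0 * (1 + sigma * rng.standard_normal((ns, 1)))  # (ns, 1), as in the diff
proposal = s + gamma * diffs              # proposed positions, shape (ns, ndim)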
8 changes: 6 additions & 2 deletions src/emcee/moves/gaussian.py
@@ -87,7 +87,9 @@ def get_factor(self, rng):
        return np.exp(rng.uniform(-self._log_factor, self._log_factor))

    def get_updated_vector(self, rng, x0):
-        return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape))
+        return x0 + self.get_factor(rng) * self.scale * rng.standard_normal(
+            (x0.shape)
+        )

    def __call__(self, x0, rng):
        nw, nd = x0.shape
@@ -106,7 +108,9 @@ def __call__(self, x0, rng):

class _diagonal_proposal(_isotropic_proposal):
    def get_updated_vector(self, rng, x0):
-        return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape))
+        return x0 + self.get_factor(rng) * self.scale * rng.standard_normal(
+            (x0.shape)
+        )


class _proposal(_isotropic_proposal):
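A side note on the API migration this PR targets: rng here is now a numpy.random.Generator, whose standard_normal(shape) takes a shape tuple, in contrast to the legacy RandomState.randn(*dims). A quick comparison (the two objects produce different streams even with the same seed):

import numpy as np

legacy = np.random.RandomState(1)
modern = np.random.default_rng(1)
a = legacy.randn(4, 2)              # legacy RandomState API
b = modern.standard_normal((4, 2))  # Generator equivalent; note the shape tuple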
6 changes: 3 additions & 3 deletions src/emcee/tests/unit/test_backends.py
@@ -124,7 +124,7 @@ def test_backend(backend, dtype, blobs):
    last2 = sampler2.get_last_sample()
    assert np.allclose(last1.coords, last2.coords)
    assert np.allclose(last1.log_prob, last2.log_prob)
-    assert last1.random_state['bg_state'] == last2.random_state['bg_state']
+    assert last1.random_state["bg_state"] == last2.random_state["bg_state"]
    if blobs:
        _custom_allclose(last1.blobs, last2.blobs)
    else:
@@ -137,7 +137,7 @@ def test_backend(backend, dtype, blobs):

@pytest.mark.parametrize("backend,dtype", product(other_backends, dtypes))
def test_reload(backend, dtype):
-    with (backend() as backend1):
+    with backend() as backend1:
        run_sampler(backend1, dtype=dtype)

    # Test the state
@@ -192,7 +192,7 @@ def test_restart(backend, dtype):
    last2 = sampler2.get_last_sample()
    assert np.allclose(last1.coords, last2.coords)
    assert np.allclose(last1.log_prob, last2.log_prob)
-    assert last1.random_state['bg_state'] == last2.random_state['bg_state']
+    assert last1.random_state["bg_state"] == last2.random_state["bg_state"]
    _custom_allclose(last1.blobs, last2.blobs)

    a = sampler1.acceptance_fraction
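The bg_state comparison in these tests works because a bit generator's state is a plain dictionary that compares equal exactly when two generators sit at the same point in their stream, for example:

import numpy as np

rng_a = np.random.default_rng(7)
rng_b = np.random.default_rng(7)
assert rng_a.bit_generator.state == rng_b.bit_generator.state  # identical state
rng_a.standard_normal()
assert rng_a.bit_generator.state != rng_b.bit_generator.state  # streams have diverged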
4 changes: 3 additions & 1 deletion src/emcee/tests/unit/test_sampler.py
@@ -137,7 +137,9 @@ def run_sampler(
):
    rng = np.random.default_rng(seed)
    coords = rng.standard_normal((nwalkers, ndim))
-    sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, rng=rng, backend=backend)
+    sampler = EnsembleSampler(
+        nwalkers, ndim, normal_log_prob, rng=rng, backend=backend
+    )
    sampler.run_mcmc(
        coords,
        nsteps,
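Taken together, the test above suggests the user-facing pattern this WIP branch is aiming for: seed a Generator with np.random.default_rng and hand it to EnsembleSampler through an rng keyword. A hypothetical usage sketch (the keyword and its behavior may still change before the PR is merged):

import numpy as np
import emcee

def log_prob(x):
    # Standard-normal log density, in the spirit of the test's normal_log_prob helper
    return -0.5 * np.sum(x ** 2)

nwalkers, ndim, nsteps = 32, 5, 100
rng = np.random.default_rng(1234)
coords = rng.standard_normal((nwalkers, ndim))

sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, rng=rng)
sampler.run_mcmc(coords, nsteps)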