ENH: add support for new dynesty api #950
@ColmTalbot could you post a diff of the old and new utils?
I've put it below, but it isn't very useful as it's basically a rewrite. Maybe you can see that the MCMC portion is unchanged.

Full diff:
1a2
> from collections import namedtuple
4c5
< from dynesty.nestedsamplers import MultiEllipsoidSampler, UnitCubeSampler
---
> from dynesty.sampling import InternalSampler
9d9
< from .base_sampler import _SamplingContainer
10a11,35
> EnsembleSamplerArgument = namedtuple(
> "EnsembleSamplerArgument",
> [
> "u",
> "loglstar",
> "live_points",
> "prior_transform",
> "loglikelihood",
> "rseed",
> "kwargs",
> ],
> )
> EnsembleAxisSamplerArgument = namedtuple(
> "EnsembleAxisSamplerArgument",
> [
> "u",
> "loglstar",
> "axes",
> "live_points",
> "prior_transform",
> "loglikelihood",
> "rseed",
> "kwargs",
> ],
> )
12,16d36
< class LivePointSampler(UnitCubeSampler):
< """
< Modified version of dynesty UnitCubeSampler that adapts the MCMC
< length in addition to the proposal scale, this corresponds to
< :code:`bound=live`.
18,20c38,43
< In order to support live-point based proposals, e.g., differential
< evolution (:code:`diff`), the live points are added to the
< :code:`kwargs` passed to the evolve method.
---
> class BaseEnsembleSampler(InternalSampler):
> def __init__(self, **kwargs):
> super().__init__(**kwargs)
> self.ncdim = kwargs.get("ncdim")
> self.sampler_kwargs["ncdim"] = self.ncdim
> self.sampler_kwargs["proposals"] = kwargs.get("proposals", ["diff"])
22,26c45,56
< Note that this does not perform ellipsoid clustering as with the
< :code:`bound=multi` option, if ellipsoid-based proposals are used, e.g.,
< :code:`volumetric`, consider using the
< :code:`MultiEllipsoidLivePointSampler` (:code:`sample=live-multi`).
< """
---
> def prepare_sampler(
> self,
> loglstar=None,
> points=None,
> axes=None,
> seeds=None,
> prior_transform=None,
> loglikelihood=None,
> nested_sampler=None,
> ):
> """
> Prepare the list of arguments for sampling.
28c58,75
< rebuild = False
---
> Parameters
> ----------
> loglstar : float
> Ln(likelihood) bound.
> points : `~numpy.ndarray` with shape (n, ndim)
> Initial sample points.
> axes : `~numpy.ndarray` with shape (ndim, ndim)
> Axes used to propose new points.
> seeds : `~numpy.ndarray` with shape (n,)
> Random number generator seeds.
> prior_transform : function
> Function transforming a sample from the a unit cube to the
> parameter space of interest according to the prior.
> loglikelihood : function
> Function returning ln(likelihood) given parameters as a 1-d
> `~numpy` array of length `ndim`.
> nested_sampler : `~dynesty.samplers.Sampler`
> The nested sampler object used to sample.
30c77,81
< def update_user(self, blob, update=True):
---
> Returns
> -------
> arglist:
> List of `SamplerArgument` objects containing the parameters
> needed for sampling.
31a83,114
> arg_list = []
> kwargs = self.sampler_kwargs
> self.nlive = nested_sampler.nlive
> for curp, curaxes, curseed in zip(points, axes, seeds):
> vals = dict(
> u=curp,
> loglstar=loglstar,
> live_points=nested_sampler.live_u,
> prior_transform=prior_transform,
> loglikelihood=loglikelihood,
> rseed=curseed,
> kwargs=kwargs,
> )
> if "volumetric" in kwargs["proposals"]:
> vals["axes"] = curaxes
> curarg = EnsembleAxisSamplerArgument(**vals)
> else:
> curarg = EnsembleSamplerArgument(**vals)
> arg_list.append(curarg)
> return arg_list
>
>
> class EnsembleWalkSampler(BaseEnsembleSampler):
> def __init__(self, **kwargs):
> super().__init__(**kwargs)
> self.walks = max(2, kwargs.get("walks", 25))
> self.sampler_kwargs["walks"] = self.walks
> self.naccept = kwargs.get("naccept", 10)
> self.maxmcmc = kwargs.get("maxmcmc", 5000)
>
> def tune(self, sampling_info, update=True):
> """
35,42c118,119
< There are a number of logical checks performed:
< - if the ACT tracking rwalk method is being used and any parallel
< process has an empty cache, set the :code:`rebuild` flag to force
< the cache to rebuild at the next call. This improves the efficiency
< when using parallelisation.
< - update the :code:`walks` parameter to asymptotically approach the
< desired number of accepted steps for the :code:`FixedRWalk` proposal.
< - update the ellipsoid scale if the ellipsoid proposals are being used.
---
> The :code:`walks` parameter to asymptotically approach the
> desired number of accepted steps.
44,50d120
< # do we need to trigger rebuilding the cache
< if blob.get("remaining", 0) == 1:
< self.rebuild = True
< if update:
< self.kwargs["rebuild"] = self.rebuild
< self.rebuild = False
<
52c122
< accept_prob = max(0.5, blob["accept"]) / self.kwargs["walks"]
---
> accept_prob = max(0.5, sampling_info["accept"]) / self.sampler_kwargs["walks"]
54,56c124,125
< n_target = getattr(_SamplingContainer, "naccept", 60)
< self.walks = (self.walks * delay + n_target / accept_prob) / (delay + 1)
< self.kwargs["walks"] = min(int(np.ceil(self.walks)), _SamplingContainer.maxmcmc)
---
> self.walks = (self.walks * delay + self.naccept / accept_prob) / (delay + 1)
> self.sampler_kwargs["walks"] = min(int(np.ceil(self.walks)), self.maxmcmc)
58c127
< self.scale = blob["accept"]
---
> self.scale = sampling_info["accept"]
60,62c129,130
< update_rwalk = update_user
<
< def propose_live(self, *args):
---
> @staticmethod
> def sample(args):
64,71c132,133
< We need to make sure the live points are passed to the proposal
< function if we are using live point-based proposals.
< """
< self.kwargs["nlive"] = self.nlive
< self.kwargs["live"] = self.live_u
< i = self.rstate.integers(self.nlive)
< u = self.live_u[i, :]
< return u, np.identity(self.ncdim)
---
> Return a new live point proposed by random walking away from an
> existing live point.
72a135,139
> Parameters
> ----------
> u : `~numpy.ndarray` with shape (ndim,)
> Position of the initial sample. **This is a copy of an existing
> live point.**
74,78c141,142
< class MultiEllipsoidLivePointSampler(MultiEllipsoidSampler):
< """
< Modified version of dynesty MultiEllipsoidSampler that adapts the MCMC
< length in addition to the proposal scale, this corresponds to
< :code:`bound=live-multi`.
---
> loglstar : float
> Ln(likelihood) bound.
80,82c144,147
< Additionally, in order to support live point-based proposals, e.g.,
< differential evolution (:code:`diff`), the live points are added to the
< :code:`kwargs` passed to the evolve method.
---
> axes : `~numpy.ndarray` with shape (ndim, ndim)
> Axes used to propose new points. For random walks new positions are
> proposed using the :class:`~dynesty.bounding.Ellipsoid` whose
> shape is defined by axes.
84,86c149,150
< When just using the :code:`diff` proposal method, consider using the
< :code:`LivePointSampler` (:code:`sample=live`).
< """
---
> scale : float
> Value used to scale the provided axes.
88c152,154
< rebuild = False
---
> prior_transform : function
> Function transforming a sample from the a unit cube to the
> parameter space of interest according to the prior.
90,94c156,158
< def update_user(self, blob, update=True):
< LivePointSampler.update_user(self, blob=blob, update=update)
< super(MultiEllipsoidLivePointSampler, self).update_rwalk(
< blob=blob, update=update
< )
---
> loglikelihood : function
> Function returning ln(likelihood) given parameters as a 1-d
> `~numpy` array of length `ndim`.
96c160,161
< update_rwalk = update_user
---
> kwargs : dict
> A dictionary of additional method-specific parameters.
98,105c163,166
< def propose_live(self, *args):
< """
< We need to make sure the live points are passed to the proposal
< function if we are using ensemble proposals.
< """
< self.kwargs["nlive"] = self.nlive
< self.kwargs["live"] = self.live_u
< return super(MultiEllipsoidLivePointSampler, self).propose_live(*args)
---
> Returns
> -------
> u : `~numpy.ndarray` with shape (ndim,)
> Position of the final proposed point within the unit cube.
106a168,169
> v : `~numpy.ndarray` with shape (ndim,)
> Position of the final proposed point in the target parameter space.
108,113c171,172
< class FixedRWalk:
< """
< Run the MCMC walk for a fixed length. This is nearly equivalent to
< :code:`bilby.sampling.sample_rwalk` except that different proposal
< distributions can be used.
< """
---
> logl : float
> Ln(likelihood) of the final proposed point.
115c174,180
< def __call__(self, args):
---
> nc : int
> Number of function calls used to generate the sample.
>
> sampling_info : dict
> Collection of ancillary quantities used to tune :data:`scale`.
>
> """
127,128d191
< accepted = list()
<
135d197
< accepted.append(0)
147,149d208
< accepted.append(1)
< else:
< accepted.append(0)
163c222
< blob = {
---
> sampling_info = {
166d224
< "scale": args.scale,
169c227
< return current_u, current_v, logl, ncall, blob
---
> return current_u, current_v, logl, ncall, sampling_info
172c230
< class ACTTrackingRWalk:
---
> class ACTTrackingEnsembleWalk(BaseEnsembleSampler):
188c246,247
< def __init__(self):
---
> def __init__(self, **kwargs):
> super().__init__(**kwargs)
190,191c249,254
< self.thin = getattr(_SamplingContainer, "nact", 2)
< self.maxmcmc = getattr(_SamplingContainer, "maxmcmc", 5000) * 50
---
> self.thin = kwargs.get("nact", 2)
> self.maxmcmc = kwargs.get("maxmcmc", 5000) * 50
> self.sampler_kwargs["rebuild"] = True
> self.sampler_kwargs["thin"] = self.thin
> self.sampler_kwargs["act"] = self.act
> self.sampler_kwargs["maxmcmc"] = self.maxmcmc
193,194c256,321
< def __call__(self, args):
< self.args = args
---
> def prepare_sampler(
> self,
> loglstar=None,
> points=None,
> axes=None,
> seeds=None,
> prior_transform=None,
> loglikelihood=None,
> nested_sampler=None,
> ):
> """
> Prepare the list of arguments for sampling.
>
> Parameters
> ----------
> loglstar : float
> Ln(likelihood) bound.
> points : `~numpy.ndarray` with shape (n, ndim)
> Initial sample points.
> axes : `~numpy.ndarray` with shape (ndim, ndim)
> Axes used to propose new points.
> seeds : `~numpy.ndarray` with shape (n,)
> Random number generator seeds.
> prior_transform : function
> Function transforming a sample from the a unit cube to the
> parameter space of interest according to the prior.
> loglikelihood : function
> Function returning ln(likelihood) given parameters as a 1-d
> `~numpy` array of length `ndim`.
> nested_sampler : `~dynesty.samplers.Sampler`
> The nested sampler object used to sample.
>
> Returns
> -------
> arglist:
> List of `SamplerArgument` objects containing the parameters
> needed for sampling.
> """
> arg_list = super().prepare_sampler(
> loglstar=loglstar,
> points=points,
> axes=axes,
> seeds=seeds,
> prior_transform=prior_transform,
> loglikelihood=loglikelihood,
> nested_sampler=nested_sampler,
> )
> self.sampler_kwargs["rebuild"] = False
> return arg_list
>
> def tune(self, sampling_info, update=True):
> """
> Update the proposal parameters based on the number of accepted steps
> and MCMC chain length.
>
> The :code:`walks` parameter to asymptotically approach the
> desired number of accepted steps.
> """
> if sampling_info.get("remaining", 0) == 0:
> self.sampler_kwargs["rebuild"] = True
> self.scale = sampling_info["accept"]
> self.sampler_kwargs["act"] = sampling_info["act"]
>
> @staticmethod
> def sample(args):
> cache = ACTTrackingEnsembleWalk._cache
196,201c323,333
< logger.debug("Force rebuilding cache")
< self.build_cache()
< while self.cache[0][2] < args.loglstar:
< self.cache.pop(0)
< current_u, current_v, logl, ncall, blob = self.cache.pop(0)
< blob["remaining"] = len(self.cache)
---
> logger.debug(f"Force rebuilding cache with {len(cache)} remaining")
> ACTTrackingEnsembleWalk.build_cache(args)
> elif len(cache) == 0:
> ACTTrackingEnsembleWalk.build_cache(args)
> while len(cache) > 0 and cache[0][2] < args.loglstar:
> state = cache.pop(0)
> if len(cache) == 0:
> current_u, current_v, logl, ncall, blob = state
> else:
> current_u, current_v, logl, ncall, blob = cache.pop(0)
> blob["remaining"] = len(cache)
204,213c336,337
< @property
< def cache(self):
< if len(self._cache) == 0:
< self.build_cache()
< else:
< logger.debug(f"Not rebuilding cache, remaining size {len(self._cache)}")
< return self._cache
<
< def build_cache(self):
< args = self.args
---
> @staticmethod
> def build_cache(args):
223c347
< check_interval = self.integer_act
---
> check_interval = ACTTrackingEnsembleWalk.integer_act(args.kwargs["act"])
243c367
< while iteration < min(target_nact * act, self.maxmcmc):
---
> while iteration < min(target_nact * act, args.kwargs["maxmcmc"]):
278c402
< act = self._calculate_act(
---
> act = ACTTrackingEnsembleWalk._calculate_act(
295c419
< self.act = self._calculate_act(
---
> act = ACTTrackingEnsembleWalk._calculate_act(
302,304c426,428
< blob = {"accept": accept, "reject": reject, "scale": args.scale}
< iact = self.integer_act
< thin = self.thin * iact
---
> blob = {"accept": accept, "reject": reject, "act": act}
> iact = ACTTrackingEnsembleWalk.integer_act(act)
> thin = args.kwargs["thin"] * iact
305a430,431
> cache = ACTTrackingEnsembleWalk._cache
>
313c439
< self._cache.append((u, v, logl, ncall, blob))
---
> cache.append((u, v, logl, ncall, blob))
318,320c444,446
< self._cache.append((current_u, current_v, logl, ncall, blob))
< elif (self.thin == -1) or (len(u_list) <= thin):
< self._cache.append((current_u, current_v, logl, ncall, blob))
---
> cache.append((current_u, current_v, logl, ncall, blob))
> elif (thin == -1) or (len(u_list) <= thin):
> cache.append((current_u, current_v, logl, ncall, blob))
331c457
< dict(accept=accept, reject=reject, fail=nfail, scale=args.scale)
---
> dict(accept=accept, reject=reject, fail=nfail, act=act)
333c459
< self._cache.extend(zip(u_list, v_list, logl_list, ncall_list, blob_list))
---
> cache.extend(zip(u_list, v_list, logl_list, ncall_list, blob_list))
335c461
< f"act: {self.act:.2f}, max failures: {most_failures}, thin: {thin}, "
---
> f"act: {act:.2f}, max failures: {most_failures}, thin: {thin}, "
339,340c465,466
< f"Finished building cache with length {len(self._cache)} after "
< f"{iteration} iterations with {ncall} likelihood calls and ACT={self.act:.2f}"
---
> f"Finished building cache with length {len(cache)} after "
> f"{iteration} iterations with {ncall} likelihood calls and ACT={act:.2f}"
362,365c488,491
< @property
< def integer_act(self):
< if np.isinf(self.act):
< return self.act
---
> @staticmethod
> def integer_act(act):
> if np.isinf(act):
> return act
367c493
< return int(np.ceil(self.act))
---
> return int(np.ceil(act))
370c496
< class AcceptanceTrackingRWalk:
---
> class AcceptanceTrackingRWalk(EnsembleWalkSampler):
384,386c510,514
< def __init__(self, old_act=None):
< self.maxmcmc = getattr(_SamplingContainer, "maxmcmc", 5000)
< self.nact = getattr(_SamplingContainer, "nact", 40)
---
> def __init__(self, **kwargs):
> super().__init__(**kwargs)
> self.nact = kwargs.get("nact", 40)
> self.sampler_kwargs["nact"] = self.nact
> self.sampler_kwargs["maxmcmc"] = self.maxmcmc
388c516,517
< def __call__(self, args):
---
> @staticmethod
> def sample(args):
395c524
< u = args.u
---
> current_u = args.u
403a533,534
> nact = args.kwargs["nact"]
> maxmcmc = args.kwargs["maxmcmc"]
406c537
< while iteration < self.nact * act:
---
> while iteration < nact * act:
410c541,543
< u_prop = proposal_funcs[prop](u, **common_kwargs, **proposal_kwargs[prop])
---
> u_prop = proposal_funcs[prop](
> current_u, **common_kwargs, **proposal_kwargs[prop]
> )
421,422c554,555
< u = u_prop
< v = v_prop
---
> current_u = u_prop
> current_v = v_prop
429,430c562,563
< if iteration > self.nact:
< act = self.estimate_nmcmc(
---
> if iteration > nact:
> act = AcceptanceTrackingRWalk.estimate_nmcmc(
433a567,568
> maxmcmc=maxmcmc,
> old_act=AcceptanceTrackingRWalk.old_act,
437c572
< if accept + reject > self.maxmcmc:
---
> if accept + reject > maxmcmc:
439c574
< f"Hit maximum number of walks {self.maxmcmc} with accept={accept},"
---
> f"Hit maximum number of walks {maxmcmc} with accept={accept},"
448,450c583,585
< u = rstate.uniform(size=len(u))
< v = args.prior_transform(u)
< logl = args.loglikelihood(v)
---
> current_u = rstate.uniform(size=len(current_u))
> current_v = args.prior_transform(current_u)
> logl = args.loglikelihood(current_v)
452c587
< blob = {"accept": accept, "reject": reject + nfail, "scale": args.scale}
---
> blob = {"accept": accept, "reject": reject + nfail}
456c591
< return u, v, logl, ncall, blob
---
> return current_u, current_v, logl, ncall, blob
458c593,594
< def estimate_nmcmc(self, accept_ratio, safety=5, tau=None):
---
> @staticmethod
> def estimate_nmcmc(accept_ratio, safety=5, tau=None, maxmcmc=5000, old_act=None):
485c621
< tau = self.maxmcmc / safety
---
> tau = maxmcmc / safety
488c624
< if self.old_act is None:
---
> if old_act is None:
491c627
< Nmcmc_exact = (1 + 1 / tau) * self.old_act
---
> Nmcmc_exact = (1 + 1 / tau) * old_act
495,497c631,633
< if self.old_act is not None:
< Nmcmc_exact = (1 - 1 / tau) * self.old_act + Nmcmc_exact / tau
< Nmcmc_exact = float(min(Nmcmc_exact, self.maxmcmc))
---
> if old_act is not None:
> Nmcmc_exact = (1 - 1 / tau) * old_act + Nmcmc_exact / tau
> Nmcmc_exact = float(min(Nmcmc_exact, maxmcmc))
527d662
< n_cluster = args.axes.shape[0]
529c664
< proposals = getattr(_SamplingContainer, "proposals", None)
---
> proposals = args.kwargs.get("proposals", None)
531a667,674
>
> n_cluster = args.kwargs.get("ncdim", None)
> if n_cluster is None:
> if hasattr(args, "live_points"):
> n_cluster = args.live_points.shape[1]
> elif hasattr(args, "axes"):
> n_cluster = args.axes.shape[0]
>
533c676
< live = args.kwargs.get("live", None)
---
> live = args.live_points
570c713
< def propose_differetial_evolution(
---
> def propose_differential_evolution(
725c868,870
< proposal_funcs = dict(diff=propose_differetial_evolution, volumetric=propose_volumetric)
---
> proposal_funcs = dict(
> diff=propose_differential_evolution, volumetric=propose_volumetric
> )
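The adaptive chain-length update visible in `EnsembleWalkSampler.tune` in the diff above can be illustrated in isolation. This is a hedged sketch, not bilby's implementation: `tune_walks` is a hypothetical standalone helper, and the `delay` smoothing constant is assumed as a fixed parameter here (the real class derives it from its own state).

```python
import math

# Sketch of the `walks` tuning rule from the diff above. The rule nudges
# `walks` toward the value that would yield `naccept` accepted steps per
# chain, with exponential smoothing so it converges asymptotically.
# `delay=9` is an illustrative assumption, not bilby's actual value.
def tune_walks(walks, accept, naccept=10, maxmcmc=5000, delay=9):
    # Acceptance fraction per step; flooring the accept count at 0.5 keeps
    # a chain with zero accepts from driving `walks` to infinity
    accept_prob = max(0.5, accept) / walks
    # Exponentially smoothed update toward naccept / accept_prob
    walks = (walks * delay + naccept / accept_prob) / (delay + 1)
    # The integer chain length actually used, capped at maxmcmc
    return walks, min(math.ceil(walks), maxmcmc)
```

For example, with `walks=25` and 5 accepted steps, the acceptance probability is 0.2, so the smoothed target moves toward 50 and the capped chain length rises to 28.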
This looks good to me. I think the duplication of the utils is probably a good idea as a temporary measure to treat them as separate entities.
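One way such a temporary duplication could be wired up is a small version switch at import time. This is a hypothetical sketch: the module names and the `>= 3` cutoff are assumptions for illustration, not bilby's actual layout.

```python
import importlib


def pick_utils_module(dynesty_version):
    """Return the name of the utils module for a dynesty version string.

    Hypothetical names: the new-API utils as `dynesty_utils`, the
    duplicated legacy module as `dynesty_utils_legacy`.
    """
    major = int(dynesty_version.split(".")[0])
    return "dynesty_utils" if major >= 3 else "dynesty_utils_legacy"


def load_utils(dynesty_version):
    """Import the selected module (assumes a bilby-like package layout)."""
    name = pick_utils_module(dynesty_version)
    return importlib.import_module(f"bilby.core.sampler.{name}")
```

Dropping support for the old dynesty then reduces to deleting the legacy module and this switch.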
@ColmTalbot is there any way we could test this in the CI by installing the current main from dynesty? |
I'll see what I can do. |
This revealed that some changes were needed. I had to make one change to the checkpoint pickle file we write; it should be backward compatible, and we don't guarantee backward compatibility of the pickle files anyway.
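The kind of backward-compatible checkpoint change described here can be sketched generically. This is not bilby's actual checkpoint code; the dict keys and defaults are assumptions for illustration. The point is that a newly added field falls back to a default when an older pickle lacks it.

```python
import io
import pickle


def write_checkpoint(stream, sampler_state, sampling_time=0.0):
    # Newer checkpoints always carry the `sampling_time` field
    pickle.dump(
        {"sampler_state": sampler_state, "sampling_time": sampling_time},
        stream,
    )


def read_checkpoint(stream):
    data = pickle.load(stream)
    # Older checkpoints predate the `sampling_time` key; defaulting it
    # here keeps the format change backward compatible
    data.setdefault("sampling_time", 0.0)
    return data
```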
mj-will left a comment:
Overall LGTM, just one comment on how we handle LVK review for this.
GregoryAshton left a comment:
Looks good to me other than minor comments.
self.sampling_time = self.sampler.kwargs.pop("sampling_time")
if hasattr(self.sampler, "_bilby_metadata"):
    extras = self.sampler._bilby_metadata
elif hasattr(self.sampler, "kwargs"):
I'm a bit confused here. Why can the sampler have either metadata or kwargs?
With the new API the dynesty sampler classes don't have a kwargs attribute that we can piggyback on.
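The fallback in the snippet under review can be illustrated with stand-in classes. These are hypothetical mocks, not dynesty's classes: old-API samplers exposed a `kwargs` dict that bilby could piggyback metadata on, while new-API samplers need a dedicated `_bilby_metadata` attribute, checked first.

```python
class NewApiSampler:
    # New dynesty classes have no `kwargs`; bilby attaches its own attribute
    _bilby_metadata = {"sampling_time": 12.3}


class OldApiSampler:
    # Old dynesty samplers carried a kwargs dict bilby piggybacked on
    kwargs = {"sampling_time": 4.5}


def get_run_metadata(sampler):
    """Read run metadata from whichever attribute this dynesty API has."""
    if hasattr(sampler, "_bilby_metadata"):
        return sampler._bilby_metadata
    elif hasattr(sampler, "kwargs"):
        return sampler.kwargs
    return {}
```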
@ColmTalbot since dynesty 3.0 has now been released, does this need any additional changes or should we just merge it?
I would say get it in; hopefully the fact that the release exists means the API is stable.
* ENH: add support for new dynesty api
* FMT: fix precommits
* FMT: try to fix precommits
* TST: add tests with dynesty master and related fixes
* MAINT: changes to checkpoint file for new dynesty API
* FMT: pre commit fixes
* CI: only run dynesty tests on one python version
In preparation for joshspeagle/dynesty#495 this would allow us to (temporarily) support both versions of dynesty. There's some duplication, but this means that when we want to stop supporting the older version of dynesty, we can just delete the old dynesty_utils.py and some of the other methods. This will be a draft for a while until the implementation in dynesty is finalized.
dynesty_utils.pyand some of the other methods.This will be a draft for a while until the implementation in dynesty is finalized.