Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bilby/core/prior/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def condition_func(reference_params, y):
self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__)
self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__)

# no conditional prior (including DeltaFunction) is fixed a priori
self._is_fixed = False

def sample(self, size=None, **required_variables):
"""Draw a sample from the prior

Expand Down
125 changes: 76 additions & 49 deletions bilby/core/prior/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,40 +734,37 @@ def _check_conditions_resolved(self, key, sampled_keys):
conditions_resolved = False
return conditions_resolved

def _resolve_subset_conditions(self, keys):
try:
# if subset is already resolved, nothing to do
self._check_subset_resolved(keys)
except IllegalConditionsException:
# updates to the dict might require re-resolving conditions
self._resolve_conditions()
self._check_subset_resolved(keys)

def sample_subset(self, keys=iter([]), size=None):
self.convert_floats_to_delta_functions()
add_delta_keys = [
key
for key in self.keys()
if key not in keys and isinstance(self[key], DeltaFunction)
]
use_keys = add_delta_keys + list(keys)
subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys})
if not subset_dict._resolved:
raise IllegalConditionsException(
"The current set of priors contains unresolvable conditions."
)

use_keys = self._prepare_evaluation(keys)

samples = dict()
for key in subset_dict.sorted_keys:
if key not in keys or isinstance(self[key], Constraint):
use_keys_sorted = [key for key in self.sorted_keys if key in use_keys]
for key in use_keys_sorted:
if isinstance(self[key], Constraint):
continue
if isinstance(self[key], Prior):
try:
samples[key] = subset_dict[key].sample(
size=size, **subset_dict.get_required_variables(key)
)
except ValueError:
# Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw)
# If that is the case, we sample each sample individually.
required_variables = subset_dict.get_required_variables(key)
samples[key] = np.zeros(size)
for i in range(size):
rvars = {
key: value[i] for key, value in required_variables.items()
}
samples[key][i] = subset_dict[key].sample(**rvars)
else:
logger.debug("{} not a known prior.".format(key))
try:
samples[key] = self[key].sample(
size=size, **self.get_required_variables(key)
)
except ValueError:
# Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw)
# If that is the case, we sample each sample individually.
required_variables = self.get_required_variables(key)
samples[key] = np.zeros(size)
for i in range(size):
rvars = {key: value[i] for key, value in required_variables.items()}
samples[key][i] = self[key].sample(**rvars)

return samples

def get_required_variables(self, key):
Expand Down Expand Up @@ -861,15 +858,14 @@ def rescale(self, keys, theta):
"""
keys = list(keys)
theta = list(theta)
self._check_resolved()
self._update_rescale_keys(keys)
theta = {key: value for key, value in zip(keys, theta)}
self._prepare_evaluation(keys)
sorted_keys = [key for key in self.sorted_keys if key in keys]
result = dict()
joint = dict()
for key, index in zip(
self.sorted_keys_without_fixed_parameters, self._rescale_indexes
):
for key in sorted_keys:
result[key] = self[key].rescale(
theta[index], **self.get_required_variables(key)
theta[key], **self.get_required_variables(key)
)
self[key].least_recently_sampled = result[key]
if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint:
Expand Down Expand Up @@ -902,25 +898,56 @@ def safe_flatten(value):

return [safe_flatten(result[key]) for key in keys]

def _update_rescale_keys(self, keys):
if not keys == self._least_recently_rescaled_keys:
self._rescale_indexes = [
keys.index(element)
for element in self.sorted_keys_without_fixed_parameters
]
self._least_recently_rescaled_keys = keys

def _prepare_evaluation(self, keys, theta):
self._check_resolved()
for key, value in zip(keys, theta):
self[key].least_recently_sampled = value
def _prepare_evaluation(self, keys, theta=None):
self.convert_floats_to_delta_functions()
# Add all fixed priors, (unconditional delta functions)
# that are not already in keys as they may be required to resolve conditions
fixed_keys = [key for key in self.fixed_keys if key not in keys]
use_keys = fixed_keys + list(keys)
self._resolve_subset_conditions(use_keys)
if theta is not None:
for key, value in zip(keys, theta):
self[key].least_recently_sampled = value
return use_keys

def _check_resolved(self):
if not self._resolved:
raise IllegalConditionsException(
"The current set of priors contains unresolveable conditions."
)

def _check_subset_resolved(self, keys):
"""Checks if a subset of keys can be sampled given the current conditions"""
resolved = True
subset_keys = list(keys)
subset_keys_sorted = [key for key in self.sorted_keys if key in subset_keys]
if len(subset_keys_sorted) != len(subset_keys):
resolved = False
logger.debug(
"The requested subset {} of priors contains {} keys ({}) that are not in the prior dict.".format(
keys,
len(subset_keys) - len(subset_keys_sorted),
list(set(subset_keys) - set(subset_keys_sorted)),
)
)
for key in subset_keys_sorted:
# if one key is not resolved, break early
if not resolved:
break
if isinstance(self[key], JointPrior):
if not set(self[key].dist.names).issubset(subset_keys_sorted):
resolved = False
if key in self._conditional_keys:
# we can check against the sorted keys as those are already resolved in order
resolved = self._check_conditions_resolved(
key, subset_keys_sorted
)

if not resolved:
raise IllegalConditionsException(
f"The requested subset {keys} of priors contains unresolveable conditions."
)

@property
def conditional_keys(self):
return self._conditional_keys
Expand Down
11 changes: 11 additions & 0 deletions test/core/prior/conditional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,17 @@ def test_rescale(self):
expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
self.assertListEqual(expected, res)

def test_rescale_subset(self):
test_sample_subset = {
key: self.test_sample[key] for key in ["var_0", "var_1", "var_2"]
}

res = self.conditional_priors.rescale(
keys=list(test_sample_subset.keys()),
theta=list(test_sample_subset.values()),
)
assert len(res) == 3

def test_rescale_with_joint_prior(self):
"""
Add a joint prior into the conditional prior dictionary and check that
Expand Down
Loading