Skip to content

Commit 2dd05b8

Browse files
committed
lnprior_ext
Adds a `lnprior_ext` argument that allows users to pass additional (Gaussian) priors over model labels on an object-by-object basis. Ideally designed to be used when additional constraints (from, e.g., spectroscopy or photometry) are available. Resolves #45, although I haven't tested it very much.
1 parent ab22380 commit 2dd05b8

File tree

1 file changed

+47
-2
lines changed

1 file changed

+47
-2
lines changed

brutus/fitting.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,7 +1376,8 @@ def fit(self, data, data_err, data_mask, data_labels, save_file,
13761376
phot_offsets=None, parallax=None, parallax_err=None,
13771377
Nmc_prior=50, avlim=(0., 20.), av_gauss=None,
13781378
rvlim=(1., 8.), rv_gauss=(3.32, 0.18),
1379-
lnprior=None, wt_thresh=5e-3, cdf_thresh=2e-3, Ndraws=250,
1379+
lnprior=None, lnprior_ext=None,
1380+
wt_thresh=5e-3, cdf_thresh=2e-3, Ndraws=250,
13801381
apply_agewt=True, apply_grad=True,
13811382
lngalprior=None, lndustprior=None, dustfile=None,
13821383
apply_dlabels=True, data_coords=None, logl_dim_prior=True,
@@ -1448,6 +1449,15 @@ def fit(self, data, data_err, data_mask, data_labels, save_file,
14481449
being drawn from a uniform prior from `smf=[0., 1.]`.
14491450
**Be sure to check this behavior you are using custom models.**
14501451
1452+
lnprior_ext : dict with entries of shape `(Nobj, 2)`, optional
1453+
External prior constraints over input model labels. These must
1454+
be passed as a dictionary with entries that **exactly** correspond
1455+
to the input model labels (e.g., `feh`). Each entry must be an
1456+
iterable with shape `(Nobj, 2)`, which are taken to correspond
1457+
to the mean and standard deviation, respectively, of a Gaussian
1458+
prior. Unlike `lnprior`, which is applied to all objects,
1459+
`lnprior_ext` can vary on a per-object basis.
1460+
14511461
wt_thresh : float, optional
14521462
The threshold `wt_thresh * max(y_wt)` used to ignore models
14531463
with (relatively) negligible weights when resampling.
@@ -1623,6 +1633,7 @@ def fit(self, data, data_err, data_mask, data_labels, save_file,
16231633
rv_gauss=rv_gauss,
16241634
Nmc_prior=Nmc_prior,
16251635
lnprior=lnprior,
1636+
lnprior_ext=lnprior_ext,
16261637
wt_thresh=wt_thresh,
16271638
cdf_thresh=cdf_thresh,
16281639
Ndraws=Ndraws, rstate=rstate,
@@ -1736,7 +1747,8 @@ def _fit(self, data, data_err, data_mask,
17361747
parallax=None, parallax_err=None, Nmc_prior=100,
17371748
avlim=(0., 20.), av_gauss=None,
17381749
rvlim=(1., 8.), rv_gauss=(3.32, 0.18),
1739-
lnprior=None, wt_thresh=5e-3, cdf_thresh=2e-3, Ndraws=250,
1750+
lnprior=None, lnprior_ext=None,
1751+
wt_thresh=5e-3, cdf_thresh=2e-3, Ndraws=250,
17401752
lngalprior=None, lndustprior=None, dustfile=None,
17411753
apply_dlabels=True, data_coords=None,
17421754
return_distreds=True, logl_dim_prior=True, ltol=3e-2,
@@ -1788,6 +1800,15 @@ def _fit(self, data, data_err, data_mask,
17881800
Log-prior grid to be used. If not provided, will default
17891801
to `0.`.
17901802
1803+
lnprior_ext : dict with entries of shape `(Nobj, 2)`, optional
1804+
External prior constraints over input model labels. These must
1805+
be passed as a dictionary with entries that **exactly** correspond
1806+
to the input model labels (e.g., `feh`). Each entry must be an
1807+
iterable with shape `(Nobj, 2)`, which are taken to correspond
1808+
to the mean and standard deviation, respectively, of a Gaussian
1809+
prior. Unlike `lnprior`, which is applied to all objects,
1810+
`lnprior_ext` can vary on a per-object basis.
1811+
17911812
wt_thresh : float, optional
17921813
The threshold `wt_thresh * max(y_wt)` used to ignore models
17931814
with (relatively) negligible weights.
@@ -1882,6 +1903,13 @@ def _fit(self, data, data_err, data_mask,
18821903
dlabels = self.models_labels
18831904
else:
18841905
dlabels = None
1906+
if lnprior_ext is not None:
1907+
ext_keys = lnprior_ext.keys() # external keys
1908+
for k in ext_keys:
1909+
if k not in self.models_labels.dtype.names:
1910+
raise ValueError("Provided `lnprior_ext` has keys which "
1911+
"do not match the underlying model "
1912+
"labels.")
18851913

18861914
# Main generator for fitting data.
18871915
Ndata, Nfilt = data.shape
@@ -1899,6 +1927,23 @@ def _fit(self, data, data_err, data_mask,
18991927
return_vals=True)
19001928
lnlike, Ndim, chi2, scales, avs, rvs, icovs_sar = results
19011929

1930+
# Apply external prior constraints.
1931+
if lnprior_ext is not None:
1932+
for k in ext_keys:
1933+
# Grab Gaussian parameters for each external constraint.
1934+
ext_mean, ext_std = lnprior_ext[k][i]
1935+
if np.isfinite(ext_mean) and ext_std > 0:
1936+
# Calculate chi2 and constant.
1937+
ext_ivar = 1. / ext_std**2
1938+
ext_chi2 = (self.models_labels[k] - ext_mean)**2
1939+
ext_chi2 *= ext_ivar
1940+
ext_const = np.log(2. * np.pi * ext_std**2)
1941+
# Calculate external prior constraint.
1942+
ext_lnp = -0.5 * (ext_chi2 + ext_const)
1943+
# Add to directly log-likelihood.
1944+
lnlike += ext_lnp
1945+
results = (lnlike, Ndim, chi2, scales, avs, rvs, icovs_sar)
1946+
19021947
# Compute posteriors using Monte Carlo sampling and integration.
19031948
presults = lnpost(results, parallax=parallax[i],
19041949
parallax_err=parallax_err[i],

0 commit comments

Comments
 (0)