Skip to content

Commit 9450a9d

Browse files
authored
improving posterior clipping
Attempts to improve how the initial posteriors are clipped by applying priors to just the MLE solution over a subset of the models. Lowers the initial `wt_thresh` defaults.
1 parent 1cf02b7 commit 9450a9d

File tree

1 file changed

+38
-16
lines changed

1 file changed

+38
-16
lines changed

brutus/fitting.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,7 @@ def loglike(data, data_err, data_mask, mag_coeffs,
821821

822822

823823
def lnpost(results, parallax=None, parallax_err=None, coord=None,
824-
Nmc_prior=100, lnprior=None, wt_thresh=5e-3, cdf_thresh=2e-3,
824+
Nmc_prior=100, lnprior=None, wt_thresh=1e-3, cdf_thresh=2e-3,
825825
lngalprior=None, lndustprior=None, dustfile=None,
826826
dlabels=None, avlim=(0., 20.), rvlim=(1., 8.),
827827
rstate=None, apply_av_prior=True, *args, **kwargs):
@@ -859,7 +859,7 @@ def lnpost(results, parallax=None, parallax_err=None, coord=None,
859859
wt_thresh : float, optional
860860
The threshold `wt_thresh * max(y_wt)` used to ignore models
861861
with (relatively) negligible weights.
862-
Default is `5e-3`.
862+
Default is `1e-3`.
863863
864864
cdf_thresh : float, optional
865865
The `1 - cdf_thresh` threshold of the (sorted) CDF used to ignore
@@ -963,22 +963,19 @@ def lnpost(results, parallax=None, parallax_err=None, coord=None,
963963
# Grab results.
964964
lnlike, Ndim, chi2, scales, avs, rvs, icovs_sar = results
965965

966-
# Compute initial log-posteriors.
967-
lnp = lnlike + lnprior
968-
969966
# Apply rough parallax prior for clipping.
970967
if parallax is not None and parallax_err is not None:
971968
ds2 = icovs_sar[:, 0, 0]
972969
scales_err = 1. / np.sqrt(np.abs(ds2)) # approximate scale errors
973-
lnprob = lnp + scale_parallax_lnprior(scales, scales_err,
974-
parallax, parallax_err)
970+
lnprob = lnlike + scale_parallax_lnprior(scales, scales_err,
971+
parallax, parallax_err)
975972
else:
976-
lnprob = lnp
973+
lnprob = lnlike
977974
lnprob_mask = np.where(~np.isfinite(lnprob))[0] # safety check
978975
if len(lnprob_mask) > 0:
979976
lnprob[lnprob_mask] = -1e300
980977

981-
# Apply thresholding.
978+
# Apply likelihood thresholding.
982979
if wt_thresh is not None:
983980
# Use relative amplitude to threshold.
984981
lwt_min = np.log(wt_thresh) + np.max(lnprob)
@@ -989,7 +986,32 @@ def lnpost(results, parallax=None, parallax_err=None, coord=None,
989986
prob = np.exp(lnprob - logsumexp(lnprob))
990987
cdf = np.cumsum(prob[idx_sort])
991988
sel = idx_sort[cdf <= (1. - cdf_thresh)]
992-
lnp = lnp[sel]
989+
990+
# Apply priors based on MLE solution.
991+
lnp = lnlike[sel]
992+
with warnings.catch_warnings():
993+
warnings.simplefilter("ignore") # ignore bad values
994+
# Add static prior.
995+
lnp += lnprior[sel]
996+
# Evaluate Galactic prior.
997+
dist = 1. / np.sqrt(scales[sel])
998+
lnp += lngalprior(dist, coord, labels=dlabels[sel])
999+
# Evaluate dust prior.
1000+
if apply_av_prior:
1001+
lnp += lndustprior(dist, coord, avs[sel], dustfile=dustfile)
1002+
1003+
# Apply posterior thresholding.
1004+
if wt_thresh is not None:
1005+
# Use relative amplitude to threshold.
1006+
lwt_min = np.log(wt_thresh) + np.max(lnp)
1007+
sel = sel[np.where(lnp > lwt_min)[0]]
1008+
else:
1009+
# Use CDF to threshold.
1010+
idx_sort = np.argsort(lnp)
1011+
prob = np.exp(lnp - logsumexp(lnp))
1012+
cdf = np.cumsum(prob[idx_sort])
1013+
sel = sel[idx_sort[cdf <= (1. - cdf_thresh)]]
1014+
lnp = lnlike[sel] + lnprior[sel]
9931015
scale, av, rv = scales[sel], avs[sel], rvs[sel]
9941016
icov_sar = icovs_sar[sel]
9951017
Nsel = len(sel)
@@ -1093,7 +1115,7 @@ def __init__(self, models, models_labels, labels_mask):
10931115
def _setup(self, data, data_err, data_mask, data_labels=None,
10941116
phot_offsets=None, parallax=None, parallax_err=None,
10951117
av_gauss=None, lnprior=None,
1096-
wt_thresh=5e-3, cdf_thresh=2e-3,
1118+
wt_thresh=1e-3, cdf_thresh=2e-3,
10971119
apply_agewt=True, apply_grad=True,
10981120
lngalprior=None, lndustprior=None, dustfile=None,
10991121
data_coords=None, ltol_subthresh=1e-2,
@@ -1146,7 +1168,7 @@ def _setup(self, data, data_err, data_mask, data_labels=None,
11461168
wt_thresh : float, optional
11471169
The threshold `wt_thresh * max(y_wt)` used to ignore models
11481170
with (relatively) negligible weights when resampling.
1149-
Default is `5e-3`.
1171+
Default is `1e-3`.
11501172
11511173
cdf_thresh : float, optional
11521174
The `1 - cdf_thresh` threshold of the (sorted) CDF used to ignore
@@ -1377,7 +1399,7 @@ def fit(self, data, data_err, data_mask, data_labels, save_file,
13771399
Nmc_prior=50, avlim=(0., 20.), av_gauss=None,
13781400
rvlim=(1., 8.), rv_gauss=(3.32, 0.18),
13791401
lnprior=None, lnprior_ext=None,
1380-
wt_thresh=5e-3, cdf_thresh=2e-3, Ndraws=250,
1402+
wt_thresh=1e-3, cdf_thresh=2e-3, Ndraws=250,
13811403
apply_agewt=True, apply_grad=True,
13821404
lngalprior=None, lndustprior=None, dustfile=None,
13831405
apply_dlabels=True, data_coords=None, logl_dim_prior=True,
@@ -1461,7 +1483,7 @@ def fit(self, data, data_err, data_mask, data_labels, save_file,
14611483
wt_thresh : float, optional
14621484
The threshold `wt_thresh * max(y_wt)` used to ignore models
14631485
with (relatively) negligible weights when resampling.
1464-
Default is `5e-3`.
1486+
Default is `1e-3`.
14651487
14661488
cdf_thresh : float, optional
14671489
The `1 - cdf_thresh` threshold of the (sorted) CDF used to ignore
@@ -1748,7 +1770,7 @@ def _fit(self, data, data_err, data_mask,
17481770
avlim=(0., 20.), av_gauss=None,
17491771
rvlim=(1., 8.), rv_gauss=(3.32, 0.18),
17501772
lnprior=None, lnprior_ext=None,
1751-
wt_thresh=5e-3, cdf_thresh=2e-3, Ndraws=250,
1773+
wt_thresh=1e-3, cdf_thresh=2e-3, Ndraws=250,
17521774
lngalprior=None, lndustprior=None, dustfile=None,
17531775
apply_dlabels=True, data_coords=None,
17541776
return_distreds=True, logl_dim_prior=True, ltol=3e-2,
@@ -1812,7 +1834,7 @@ def _fit(self, data, data_err, data_mask,
18121834
wt_thresh : float, optional
18131835
The threshold `wt_thresh * max(y_wt)` used to ignore models
18141836
with (relatively) negligible weights.
1815-
Default is `5e-3`.
1837+
Default is `1e-3`.
18161838
18171839
cdf_thresh : float, optional
18181840
The `1 - cdf_thresh` threshold of the (sorted) CDF used to ignore

0 commit comments

Comments
 (0)