Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update example notebooks on main branch #7

Merged
merged 7 commits into from
Jul 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions jaxtronomy/Inference/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,39 @@ class Loss(object):

def __init__(self, data, image_class, param_class,
likelihood_type='gaussian',
regularization_terms=['starlets_l1']):
# regularization_terms=['starlets_l1'], # TODO
prior_terms=['gaussian']):
self._data = data
self._image = image_class
self._param = param_class
if likelihood_type == 'gaussian':
self._log_likelihood = self._gaussian_log_likelihood
else:
raise NotImplementedError("Likelihood term '{}' ")

raise NotImplementedError(f"Likelihood term '{likelihood_type}' is not supported")
if prior_terms is None or 'none' in prior_terms:
self._log_prior = lambda args: 0.
elif prior_terms == ['uniform']:
self._log_prior = self._param.log_prior_uniform
elif prior_terms == ['gaussian']:
self._log_prior = self._param.log_prior_gaussian
elif 'gaussian' in prior_terms and 'uniform' in prior_terms:
self._log_prior = self._param.log_prior
else:
raise NotImplementedError(f"Prior terms {prior_terms} is not supported")

@partial(jit, static_argnums=(0,))
def __call__(self, args):
return self.loss(args)

@partial(jit, static_argnums=(0,))
def loss(self, args):
model = self._image.model(**self._param.args2kwargs(args))
log_L = self.log_likelihood(model)
log_L = self.log_likelihood(self._image.model(**self._param.args2kwargs(args)))
log_P = self.log_prior(args)
return - log_L - log_P

@partial(jit, static_argnums=(0,))
def loss_kwargs(self, kwargs):
model = self._image.model(**kwargs)
log_L = self.log_likelihood(model)
log_L = self.log_likelihood(self._image.model(**kwargs))
log_P = self.log_prior(self._param.kwargs2args(kwargs))
return - log_L - log_P

Expand All @@ -41,7 +50,7 @@ def log_likelihood(self, model):

@partial(jit, static_argnums=(0,))
def log_prior(self, args):
return self._param.log_prior_no_uniform(args)
return self._log_prior(args)

@partial(jit, static_argnums=(0,))
def _gaussian_log_likelihood(self, model):
Expand Down
2 changes: 1 addition & 1 deletion jaxtronomy/LensImage/lens_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, data_class, psf_class,
lens_light_model_class = LightModel(light_model_list=[])
self.LensLightModel = lens_light_model_class
self._kwargs_numerics = kwargs_numerics
self.source_mapping = Image2SourceMapping(lensModel=lens_model_class, sourceModel=source_model_class)
self.source_mapping = Image2SourceMapping(lens_model_class, source_model_class)

def update_psf(self, psf_class):
"""
Expand Down
2 changes: 2 additions & 0 deletions jaxtronomy/LensModel/Profiles/pixelated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@


class PixelatedPotential(LensProfileBase):
param_names = ['x_coords', 'y_coords', 'psi_grid']

def __init__(self):
"""Lensing potential on a fixed coordinate grid."""
# self.x_coords = x_coords
Expand Down
72 changes: 58 additions & 14 deletions jaxtronomy/Parameters/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Parameters(object):
- nice LaTeX format for parameter names
"""

_bound_penalty = 1e10
_unif_prior_penalty = 1e10

def __init__(self, kwargs_model, kwargs_init, kwargs_prior, kwargs_fixed):
self._kwargs_model = kwargs_model
Expand Down Expand Up @@ -123,25 +123,34 @@ def log_prior(self, args):
gaussian_prior = self._prior_types[i] == 'gaussian'
uniform_prior = self._prior_types[i] == 'uniform'
logP += lax.cond(gaussian_prior, lambda _: - 0.5 * ((args[i] - self._means[i]) / self._widths[i]) ** 2, lambda _: 0., operand=None)
logP += lax.cond(uniform_prior, lambda _: lax.cond(args[i] < self._lowers[i], lambda _: - self._bound_penalty, lambda _: 0., operand=None), lambda _: 0., operand=None)
logP += lax.cond(uniform_prior, lambda _: lax.cond(args[i] > self._uppers[i], lambda _: - self._bound_penalty, lambda _: 0., operand=None), lambda _: 0., operand=None)
logP += lax.cond(uniform_prior, lambda _: lax.cond(args[i] < self._lowers[i], lambda _: - self._unif_prior_penalty, lambda _: 0., operand=None), lambda _: 0., operand=None)
logP += lax.cond(uniform_prior, lambda _: lax.cond(args[i] > self._uppers[i], lambda _: - self._unif_prior_penalty, lambda _: 0., operand=None), lambda _: 0., operand=None)
return logP

@partial(jit, static_argnums=(0,))
def log_prior_no_uniform(self, args):
def log_prior_gaussian(self, args):
logP = 0
for i in range(self.num_parameters):
gaussian_prior = self._prior_types[i] == 'gaussian'
logP += lax.cond(gaussian_prior, lambda _: - 0.5 * ((args[i] - self._means[i]) / self._widths[i]) ** 2, lambda _: 0., operand=None)
return logP

@partial(jit, static_argnums=(0,))
def log_prior_uniform(self, args):
logP = 0
for i in range(self.num_parameters):
uniform_prior = self._prior_types[i] == 'uniform'
logP += lax.cond(uniform_prior, lambda _: lax.cond(args[i] < self._lowers[i], lambda _: - self._unif_prior_penalty, lambda _: 0., operand=None), lambda _: 0., operand=None)
logP += lax.cond(uniform_prior, lambda _: lax.cond(args[i] > self._uppers[i], lambda _: - self._unif_prior_penalty, lambda _: 0., operand=None), lambda _: 0., operand=None)
return logP

def log_prior_nojit(self, args):
logP = 0
for i in range(self.num_parameters):
if self._prior_types[i] == 'gaussian':
logP += - 0.5 * ((args[i] - self._means[i]) / self._widths[i]) ** 2
elif self._prior_types[i] == 'uniform' and not (self._lowers[i] <= args[i] <= self._uppers[i]):
logP += - self._bound_penalty
logP += - self._unif_prior_penalty
return logP

def _get_params(self, args, i, kwargs_model_key, kwargs_key):
Expand All @@ -158,7 +167,11 @@ def _get_params(self, args, i, kwargs_model_key, kwargs_key):
n_pix_x = len(kwargs_fixed['x_coords'])
n_pix_y = len(kwargs_fixed['y_coords'])
num_param = int(n_pix_x * n_pix_y)
kwargs['image'] = args[i:i + num_param].reshape(n_pix_x, n_pix_y)
if kwargs_key in ['kwargs_source', 'kwargs_lens_light']:
pixels = 'image'
elif kwargs_key == 'kwargs_lens':
pixels = 'psi_grid'
kwargs[pixels] = args[i:i + num_param].reshape(n_pix_x, n_pix_y)
else:
num_param = 1
kwargs[name] = args[i]
Expand All @@ -180,7 +193,11 @@ def _set_params(self, kwargs, kwargs_model_key, kwargs_key):
if name in ['x_coords', 'y_coords']:
raise ValueError(f"'{name}' must be a fixed keyword argument for 'PIXELATED' models")
else:
args += kwargs_profile['image'].flatten().tolist()
if kwargs_key in ['kwargs_source', 'kwargs_lens_light']:
pixels = 'image'
elif kwargs_key == 'kwargs_lens':
pixels = 'psi_grid'
args += kwargs_profile[pixels].flatten().tolist()
else:
args.append(kwargs_profile[name])
return args
Expand All @@ -199,24 +216,36 @@ def _set_params_prior(self, kwargs, kwargs_model_key, kwargs_key):
uppers.append(+np.inf)
means.append(np.nan)
widths.append(np.nan)

else:
prior_type = kwargs_profile[name][0]
if prior_type == 'uniform':
if model in ['PIXELATED']:
if name in ['x_coords', 'y_coords']:
raise ValueError(f"'{name}' must be a fixed keyword argument for 'PIXELATED' models")
if kwargs_key in ['kwargs_source', 'kwargs_lens_light']:
pixels = 'image'
elif kwargs_key == 'kwargs_lens':
pixels = 'psi_grid'
n_pix_x = len(kwargs_fixed['x_coords'])
n_pix_y = len(kwargs_fixed['y_coords'])
num_param = int(n_pix_x * n_pix_y)
types += [prior_type]*num_param
lowers += kwargs_profile['image'][1].flatten().tolist()
uppers += kwargs_profile['image'][2].flatten().tolist()
lowers_tmp, uppers_tmp = kwargs_profile[pixels][1], kwargs_profile[pixels][2]
# those bounds can either be whole array (values per pixel)
if isinstance(lowers_tmp, (np.ndarray, jnp.ndarray)):
lowers += lowers_tmp.flatten().tolist()
uppers += uppers_tmp.flatten().tolist()
# or they can be single numbers, in which case they are considered the same for pixel
elif isinstance(lowers_tmp, (int, float)):
lowers += [float(lowers_tmp)]*num_param
uppers += [float(uppers_tmp)]*num_param
means += [np.nan]*num_param
widths += [np.nan]*num_param
else:
types.append(prior_type)
lowers.append(kwargs_profile[name][1])
uppers.append(kwargs_profile[name][2])
lowers.append(float(kwargs_profile[name][1]))
uppers.append(float(kwargs_profile[name][2]))
means.append(np.nan)
widths.append(np.nan)

Expand All @@ -229,6 +258,7 @@ def _set_params_prior(self, kwargs, kwargs_model_key, kwargs_key):
uppers.append(+np.inf)
means.append(kwargs_profile[name][1])
widths.append(kwargs_profile[name][2])

else:
raise ValueError(f"Prior type '{prior_type}' is not supported")
return types, lowers, uppers, means, widths
Expand Down Expand Up @@ -258,6 +288,9 @@ def _get_param_names_for_model(kwargs_key, model):
elif model == 'SHEAR_GAMMA_PSI':
from jaxtronomy.LensModel.Profiles.shear import ShearGammaPsi
profile_class = Shear
elif model == 'PIXELATED':
from jaxtronomy.LensModel.Profiles.pixelated import PixelatedPotential
profile_class = PixelatedPotential
return profile_class.param_names

def _set_names(self, kwargs_model_key, kwargs_key):
Expand All @@ -271,16 +304,27 @@ def _set_names(self, kwargs_model_key, kwargs_key):
n_pix_x = len(kwargs_fixed['x_coords'])
n_pix_y = len(kwargs_fixed['y_coords'])
num_param = int(n_pix_x * n_pix_y)
names += [f'a_{i}' for i in range(num_param)]
if kwargs_key == 'kwargs_lens_light':
names += [f"d_{i}" for i in range(num_param)] # 'd' for deflector
elif kwargs_key == 'kwargs_source':
names += [f"s_{i}" for i in range(num_param)] # 's' for source
elif kwargs_key == 'kwargs_lens':
names += [f"dpsi_{i}" for i in range(num_param)] # 'dpsi' for potential corrections
else:
names.append(name)
return names

def _name2latex(self, names):
latexs = []
for name in names:
if name[:2] == 'a_': # for 'PIXELATED' models
latex = r"$a_{}$".format(int(name[2:]))
# pixelated models
if name[:2] == 'd_':
latex = r"$d_{" + r"{}".format(int(name[2:])) + r"}$"
elif name[:2] == 's_':
latex = r"$s_{" + r"{}".format(int(name[2:])) + r"}$"
elif name[:5] == 'dpsi_':
latex = r"$\delta\psi_{" + r"{}".format(int(name[5:])) + r"}$"
# other parametric models
elif name == 'theta_E':
latex = r"$\theta_{\rm E}$"
elif name == 'gamma':
Expand Down
Loading