Skip to content

Commit

Permalink
Merge pull request #24 from ahuang314/main
Browse files Browse the repository at this point in the history
Implement ImageNoise and ImageData classes
  • Loading branch information
ahuang314 authored Oct 10, 2024
2 parents 33e36c3 + fbb1f2b commit cc5b303
Show file tree
Hide file tree
Showing 7 changed files with 544 additions and 0 deletions.
Empty file added jaxtronomy/Data/__init__.py
Empty file.
138 changes: 138 additions & 0 deletions jaxtronomy/Data/image_noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import numpy as np
from jax import jit, numpy as jnp

from lenstronomy.Util.package_util import exporter
from functools import partial

export, __all__ = exporter()


@export
class ImageNoise(object):
    """Class that deals with noise properties of imaging data."""

    # NOTE: JIT-compiled functions need to be recompiled each time a new instance of the class is created.

    def __init__(
        self,
        image_data,
        exposure_time=None,
        background_rms=None,
        noise_map=None,
        flux_scaling=1,
        gradient_boost_factor=None,
        verbose=True,
    ):
        """
        :param image_data: numpy array, pixel data values
        :param exposure_time: int or array of size the data; exposure time
            (common for all pixels or individually for each individual pixel)
            Units of data and exposure map should result in:
            number of flux counts = data * exposure_map
            Values <= 1e-10 are replaced by 1e-10 to avoid dividing by zero.
            May only be None if noise_map is given.
        :param background_rms: root-mean-square value of Gaussian background noise.
            If None, the median of noise_map is used instead (noise_map must then be given).
        :param noise_map: array with the same shape as image_data; joint noise sqrt(variance) of each
            individual pixel. Overwrites meaning of background_rms and exposure_time.
        :param flux_scaling: scales the model amplitudes to match the imaging data units. This can be used, for example,
            when modeling multiple exposures that have different magnitude zero points (or flux normalizations) but demand
            the same model normalization
        :type flux_scaling: float or int (default=1)
        :param gradient_boost_factor: None or float, variance terms added in quadrature scaling with
            gradient^2 * gradient_boost_factor. NOTE: NOT supported in Jaxtronomy; any value other
            than None raises a ValueError
        :param verbose: bool, if True, prints a warning when background_rms * exposure_time drops below
            one count in any pixel (only checked when no noise_map is given)
        :raises ValueError: if gradient_boost_factor is not None, or if noise_map is None and either
            exposure_time or background_rms is missing
        """

        # Reject the unsupported option up front, before any arrays are built or warnings printed
        if gradient_boost_factor is not None:
            raise ValueError(
                "gradient_boost_factor not supported in JAXtronomy. Please use lenstronomy instead"
            )

        # Set exposure time
        if exposure_time is None:
            if noise_map is None:
                raise ValueError(
                    "Exposure map has not been specified in Noise() class!"
                )
            # The exposure map is not used when an explicit noise map is provided, but the
            # attribute is still defined so that accessing it does not raise AttributeError
            self.exp_map = None
        else:
            # make sure no negative exposure values are present no dividing by zero
            self.exp_map = jnp.where(
                exposure_time <= 10 ** (-10), 10 ** (-10), exposure_time
            )

        # Set background rms
        if background_rms is None:
            if noise_map is None:
                raise ValueError(
                    "rms background value as 'background_rms' not specified!"
                )
            self.background_rms = np.median(noise_map)
        else:
            self.background_rms = background_rms

        self.data = jnp.array(image_data)
        self.flux_scaling = flux_scaling

        if noise_map is not None:
            assert np.shape(noise_map) == np.shape(
                image_data
            ), "noise_map must have the same shape as image_data"
            self._noise_map = jnp.array(noise_map)
        else:
            self._noise_map = None
            # Without a noise map, both background_rms and exposure_time are guaranteed to be
            # set at this point (otherwise a ValueError was raised above)
            if verbose is True:
                counts = background_rms * exposure_time
                if np.any(counts < 1):
                    # Report the smallest value, i.e. the one that actually violates the bound
                    print(
                        "WARNING! sigma_b*f %s < 1 count may introduce unstable error estimates with a Gaussian"
                        " error function for a Poisson distribution with mean < 1."
                        % np.min(counts)
                    )

        # Covariance matrix of all pixel values in 2d numpy array (only diagonal component)
        # The covariance matrix is estimated from the data.
        # WARNING: For low count statistics, the noise in the data may lead to biased estimates of the covariance matrix.
        if self._noise_map is not None:
            self.C_D = self._noise_map**2
        else:
            self.C_D = covariance_matrix(
                self.data,
                self.background_rms,
                self.exp_map,
            )

    @partial(jit, static_argnums=0)
    def C_D_model(self, model):
        """
        :param model: model (same as data but without noise)
        :return: estimate of the noise per pixel based on the model flux; if a noise map was
            provided at initialization, its square is returned regardless of the model
        """

        if self._noise_map is not None:
            return self._noise_map**2
        else:
            return covariance_matrix(model, self.background_rms, self.exp_map)


@export
@jit
def covariance_matrix(data, background_rms, exposure_map):
    """Returns the per-pixel variance estimate (the diagonal of the covariance matrix)
    made of a Poisson term and a Gaussian background term.

    Notes:

    - the exposure map must be strictly positive, since the data is divided by it.
    - the data must be non-negative for the Poisson noise estimate.
      Values < 0 (possible after mean subtraction) will not have a Poisson component in their noise estimate.

    :param data: data array, e.g. in units of photons/second
    :param background_rms: background noise rms, e.g. in units of photons/second
    :param exposure_map: exposure time per pixel, e.g. in units of seconds
    :return: element-wise variance ``max(data, 0) / exposure_map + background_rms**2``,
        e.g. in units of (photons/second)^2
    """
    # Negative pixel values carry no Poisson contribution
    non_negative_flux = jnp.where(data >= 0, data, 0)
    poisson_variance = non_negative_flux / exposure_map
    background_variance = background_rms**2
    return poisson_variance + background_variance
205 changes: 205 additions & 0 deletions jaxtronomy/Data/imaging_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import numpy as np
from jax import jit, numpy as jnp

from lenstronomy.Data.pixel_grid import PixelGrid
from jaxtronomy.Data.image_noise import ImageNoise
from functools import partial

__all__ = ["ImageData"]


class ImageData(PixelGrid, ImageNoise):
    """Class to handle the imaging data together with its pixel coordinate system
    (PixelGrid) and its noise properties (ImageNoise).

    The Data() class is initialized with keyword arguments:

    - 'image_data': 2d numpy array of the image data
    - 'transform_pix2angle' 2x2 transformation matrix (linear) to transform a pixel shift into a coordinate shift (x, y) -> (ra, dec)
    - 'ra_at_xy_0' RA coordinate of pixel (0,0)
    - 'dec_at_xy_0' DEC coordinate of pixel (0,0)

    optional keywords for shifts in the coordinate system:

    - 'ra_shift': shifts the coordinate system with respect to 'ra_at_xy_0'
    - 'dec_shift': shifts the coordinate system with respect to 'dec_at_xy_0'

    optional keywords for noise properties:

    - 'background_rms': rms value of the background noise
    - 'exposure_time': float or 2d numpy array, exposure time (common for all pixels or per pixel)
      used to compute the Poisson noise contribution
    - 'noise_map': Gaussian noise (1-sigma) for each individual pixel.
      If this keyword is set, the other noise properties will be ignored.

    optional keywords for interferometric quantities:

    - 'likelihood_method': need to be specified to 'interferometry_natwt' if one needs to use the interferometric likelihood function.
      The default of 'likelihood_method' is 'diagonal', which is used for non-correlated noises (usually for the CCD images.)
    - 'log_likelihood_constant': a constant that adds to logL.
    - 'antenna_primary_beam': primary beam pattern of antennae (now treat each antenna dish with the same primary beam).

    ** notes **
    the likelihood for the data given model P(data|model) is defined in the function below. Please make sure that
    your definitions and units of 'exposure_time', 'background_rms' and 'image_data' are in accordance with the
    likelihood function. In particular, make sure that the Poisson noise contribution is defined in the count rate.
    """

    # NOTE: JIT-compiled functions need to be recompiled each time a new instance of the class is created.

    def __init__(
        self,
        image_data,
        exposure_time=None,
        background_rms=None,
        noise_map=None,
        gradient_boost_factor=None,
        ra_at_xy_0=0,
        dec_at_xy_0=0,
        transform_pix2angle=None,
        ra_shift=0,
        dec_shift=0,
        phi_rot=0,
        log_likelihood_constant=0,
        antenna_primary_beam=None,
        likelihood_method="diagonal",
        flux_scaling=1,
    ):
        """
        :param image_data: 2d numpy array of the image data
        :param exposure_time: int or array of size the data; exposure time
            (common for all pixels or individually for each individual pixel)
        :param background_rms: root-mean-square value of Gaussian background noise in units counts per second
        :param noise_map: array of size the data; joint noise sqrt(variance) of each individual pixel.
        :param gradient_boost_factor: None or float, variance terms added in quadrature scaling with
            gradient^2 * gradient_boost_factor. NOTE: NOT supported in Jaxtronomy; any value other
            than None raises a ValueError
        :param transform_pix2angle: 2x2 matrix, mapping of pixel to coordinate; defaults to the identity matrix
        :param ra_at_xy_0: ra coordinate at pixel (0,0)
        :param dec_at_xy_0: dec coordinate at pixel (0,0)
        :param ra_shift: RA shift of pixel grid
        :param dec_shift: DEC shift of pixel grid
        :param phi_rot: rotation angle in regard to pixel coordinate transform_pix2angle
        :param log_likelihood_constant: float, allows user to input a constant that will be added to the log likelihood. Note that, as for now, this variable is ONLY used for interferometric mode.
        :param antenna_primary_beam: 2d numpy array with the same size of image_data;
            more descriptions of the primary beam can be found in the AngularSensitivity class
        :param likelihood_method: string, type of method of log_likelihood computation: options are 'diagonal', 'interferometry_natwt'.
            The default option 'diagonal' uses a diagonal covariance matrix, which is the case for CCD images.
            The 'interferometry_natwt' option uses our special interferometric likelihood function based on natural weighting images.
        :param flux_scaling: scales the model amplitudes to match the imaging data units. This can be used, for example,
            when modeling multiple exposures that have different magnitude zero points (or flux normalizations) but demand
            the same model normalization
        :raises ValueError: if likelihood_method is not one of the supported options
        """
        # Validate the cheap string option first so that an invalid configuration fails
        # before the pixel grid and noise model are built
        if likelihood_method not in ("diagonal", "interferometry_natwt"):
            raise ValueError(
                "likelihood_method %s not supported! likelihood_method can only be 'diagonal' or 'interferometry_natwt'!"
                % likelihood_method
            )

        nx, ny = np.shape(image_data)
        if transform_pix2angle is None:
            transform_pix2angle = np.array([[1, 0], [0, 1]])
        # Apply the rotation phi_rot on top of the pixel-to-angle transformation
        cos_phi, sin_phi = np.cos(phi_rot), np.sin(phi_rot)
        rot_matrix = np.array([[cos_phi, -sin_phi], [sin_phi, cos_phi]])
        transform_pix2angle_rot = np.dot(transform_pix2angle, rot_matrix)
        PixelGrid.__init__(
            self,
            nx,
            ny,
            transform_pix2angle_rot,
            ra_at_xy_0 + ra_shift,
            dec_at_xy_0 + dec_shift,
            antenna_primary_beam,
        )
        ImageNoise.__init__(
            self,
            image_data,
            exposure_time=exposure_time,
            background_rms=background_rms,
            noise_map=noise_map,
            gradient_boost_factor=gradient_boost_factor,
            verbose=False,
            flux_scaling=flux_scaling,
        )

        self._logL_constant = log_likelihood_constant
        self._logL_method = likelihood_method

    def update_data(self, image_data):
        """Not supported: the data of an existing instance cannot be replaced.

        :param image_data: unused
        :raises NotImplementedError: always; create a new ImageData instance instead
        """
        # NotImplementedError is a subclass of Exception, so callers that catch the
        # previously raised generic Exception keep working
        raise NotImplementedError(
            "Cannot update data when using JAX. Instead, a new instance of ImageData must be created."
        )

    @partial(
        jit,
        static_argnums=0,
    )
    def log_likelihood(self, model, mask, additional_error_map=0):
        """Computes the likelihood of the data given the model p(data|model) The
        Gaussian errors are estimated with the covariance matrix, based on the model
        image. The errors include the background rms value and the exposure time to
        compute the Poisson noise level (in Gaussian approximation).

        :param model: the model (same dimensions and units as data)
        :param mask: bool (1, 0) values per pixel. If =0, the pixel is ignored in the
            likelihood. Not used when the likelihood method is 'interferometry_natwt'.
        :param additional_error_map: additional error term (in same units as covariance
            matrix). This can e.g. come from model errors in the PSF estimation.
            Not used when the likelihood method is 'interferometry_natwt'.
        :return: the natural logarithm of the likelihood p(data|model)
        """
        # if the likelihood method is assigned to be 'interferometry_natwt', it will return logL computed using the interfermetric likelihood function
        if self._logL_method == "interferometry_natwt":
            return self.log_likelihood_interferometry(model)

        c_d = self.C_D_model(model)
        # Masked chi^2 per pixel; the absolute value keeps the added error term non-negative
        chi2 = (model - self.data) ** 2 / (c_d + jnp.abs(additional_error_map)) * mask
        return -jnp.sum(chi2) / 2

    @partial(jit, static_argnums=0)
    def log_likelihood_interferometry(self, model):
        """log_likelihood function for natural weighting interferometric images, based
        on (placeholder for Nan Zhang's paper).

        For the interferometry case, the model should be in the form [array1, array2],
        where array1 and array2 are unconvolved and convolved model images respectively.
        They are both 2d array with the same shape of the data.

        The chi^2 of interferometry is computed by

        .. math::
            \\chi^2 = (d-Ax)^TC^{-1}(d-Ax) = \\frac{1}{\\sigma^2}(d^TA^{-1}d - 2x^Td + x^TAx)

        where :math:`d` and :math:`x` are the data vector and the unconvolved model image vector respectively.
        :math:`A` is the convolution operation matrix, where we normalize the PSF by setting its central pixel to 1.
        :math:`C` is the noise covariance matrix, its diagonal entries are rms^2 of noises, :math:`\\sigma^2`.

        For natural weighting interferometric images, we used the relation
        (see Section 3.2 of https://doi.org/10.1093/mnras/staa2740 for the relation of natural weighting covariance matrix and PSF convolution)

        .. math::
            C = \\sigma^2 A

        to simplify the likelihood function above.

        :param model: pair [unconvolved model image, convolved model image]
        :return: log likelihood; the model-dependent terms -(x^TAx - 2x^Td) / (2 * background_rms^2)
            plus the user-supplied log likelihood constant
        """

        xd = jnp.sum(model[0] * self.data)
        xAx = jnp.sum(model[0] * model[1])
        logL = -(xAx - 2 * xd) / (2 * self.background_rms**2) + self._logL_constant
        return logL

    def likelihood_method(self):
        """Pass the likelihood_method to the ImageModel and will be used to identify the
        method of likelihood computation in ImageLinearFit.

        :return: string, likelihood method
        """
        return self._logL_method
Empty file.
Empty file added test/test_Data/__init__.py
Empty file.
Loading

0 comments on commit cc5b303

Please sign in to comment.