Skip to content

Commit

Permalink
Update to miri scattering effect by adding astropy convolution functi…
Browse files Browse the repository at this point in the history
…on. Added NotImplementedError for MIRI LRS to calc_psf and noted in tutorial notebook
  • Loading branch information
shanosborne committed May 23, 2018
1 parent b16807f commit 1b9e1f5
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 133 deletions.
44 changes: 18 additions & 26 deletions notebooks/Distortion_examples.ipynb

Large diffs are not rendered by default.

157 changes: 50 additions & 107 deletions webbpsf/distortion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import matplotlib.pyplot as plt
import astropy.io.fits as fits
import astropy.convolution

from astropy.table import Table
from scipy.interpolate import griddata
Expand Down Expand Up @@ -223,110 +224,63 @@ def apply_rotation(HDUlist_or_filename=None, rotate_value=None, crop=True):
# #####################################################################################################################


def _get_default_miri(filter):
def _make_kernel(image, amplitude, nsamples):
"""
Store the default values for the MIRI scattering cross artifact distortion transformation. Values come from
MIRI-TN-00076-ATC_Imager_PSF_Issue_4.pdf
"""

radius = 200 # radius of kernel profile (in MIRI detector pixels)

aper = pysiaf.Siaf("MIRI").apertures["MIRIM_FULL"]
rotate_value = getattr(aper, "V3IdlYAngle") # = 4.4497 # rotation value pulled from most updated SIAF file

filter_list = ['F560W', 'F770W', 'F1000W', 'F1130W', 'F1280W', 'F1500W', 'F1800W', 'F2100W', 'F2550W',
'FND', 'F1065C', 'F1140C', 'F1550C', 'F2300C']

kernel_amp_list = [0.00220, 0.00139, 0.00034, 0.00007, 0.00011, 0.0, 0.0, 0.0, 0.0,
0.00087, 0.00010, 0.00007, 0.0, 0.0] # detector scattering kernel amplitude

# Set PSF values
i_filter = filter_list.index(filter)
kernel_amp = kernel_amp_list[i_filter]

miri_scattering_default = {
"radius": radius,
"kernel_amp": kernel_amp,
"rotate_value": rotate_value
}

return miri_scattering_default


def _make_kernel(amplitude, radius, nsamples):
Creates a detector scatter kernel function. For simplicity, we assume a simple exponential dependence. Code is
adapted from MIRI-TN-00076-ATC_Imager_PSF_Issue_4.pdf (originally in IDL).
"""
Creates a detector scatter kernel function k(x-a) which defines the fraction of the signal in a pixel at pixel
location 'a' which is scattered into a pixel at position 'x-a' in the along row or along column direction. For
simplicity, we assume a simple exponential dependence.

Code is from MIRI-TN-00076-ATC_Imager_PSF_Issue_4.pdf (originally in IDL).
"""
# Compute 1d indices
x = np.arange(image.shape[1], dtype=float)
x -= (image.shape[1]-1)/2
x /= nsamples

# Update values based on oversampling of PSF
fold = 25.0 * nsamples # e-folding length is 25 MIRI detector pixels
amplitude /= nsamples # the signal is shared between n samples
radius *= nsamples
# Create 1d kernel
kernel_x = amplitude * np.exp(-np.abs(x) / 25)

# Generate kernel functions for the halo v pixel distance
distances_list = np.arange(2 * radius + 1, dtype=float)
distances_list = np.abs(distances_list - distances_list[radius])
kernel = amplitude * np.exp(-distances_list / fold)
# Reshape kernel to 2D image for use in convolution
kernel_x.shape = (1, image.shape[1])

return kernel, amplitude, fold
return kernel_x


def _apply_kernel(kernel, image, radius):
def _apply_kernel(in_psf, kernel_x, oversample):
"""
Applies the detector scattering kernel function created in _make_kernel function to an input image. Code is
from MIRI-TN-00076-ATC_Imager_PSF_Issue_4.pdf (originally in IDL).
adapted from MIRI-TN-00076-ATC_Imager_PSF_Issue_4.pdf (originally in IDL).
While the current code form isn't as elegant as it could be, it will stay like this for the time being until
we can find a convolution method that takes less time to execute than this for loop method.
"""
# Apply the kernel via convolution in both the X and Y direction
# Convolve the input PSF with the kernel for scattering in the X direction
im_conv_x = astropy.convolution.convolve_fft(in_psf, kernel_x, boundary='fill', fill_value=0.0,
normalize_kernel=False, nan_treatment='fill')

shape = image.shape
y_num = shape[0]
x_num = shape[1]
out_image = np.zeros([y_num, x_num])

# Stepping through each index of the original image
for y in np.arange(0, y_num):
for x in np.arange(0, x_num):

x_i1 = x - radius
if x_i1 < 0:
x_i1 = 0
n_left = x - x_i1
x_k1 = radius - n_left
# Transpose to make a kernel for Y and convolve with that too
im_conv_y = astropy.convolution.convolve_fft(in_psf, kernel_x.T, boundary='fill', fill_value=0.0,
normalize_kernel=False, nan_treatment='fill')

x_i2 = x + radius
if x_i2 > (x_num - 1):
x_i2 = x_num - 1
n_right = x_i2 - x
x_k2 = radius + n_right
# Sum together both the X and Y scattering.
# Note, it appears we do need to correct the amplitude for the sampling factor. Might as well do that here.
im_conv_both = (im_conv_x + im_conv_y)/(oversample**2)

out_image[y, x_i1:x_i2] += image[y, x] * kernel[x_k1: x_k2]
return im_conv_both

return out_image


def apply_miri_scattering(HDUlist_or_filename=None, radius=None, kernel_amp=None, rotate_value=None):
def apply_miri_scattering(HDUlist_or_filename=None, kernel_amp=None):
"""
Apply a distortion caused by the MIRI scattering cross artifact effect. Description of distortion and code is
adapted from MIRI-TN-00076-ATC_Imager_PSF_Issue_4.pdf (originally in IDL).
Apply a distortion caused by the MIRI scattering cross artifact effect. In short we convolve a 2D
exponentially decaying cross to the PSF where the amplitude of the exponential function is determined
by the filter of the PSF. A full description of the distortion and the original code can
be found in MIRI-TN-00076-ATC_Imager_PSF_Issue_4.pdf
Parameters
----------
HDUlist_or_filename :
A PSF from WebbPSF, either as an HDUlist object or as a filename
radius: float
Radius of kernel profile (MIRI pixels). If set to None, value will be set at 200 pixels. Default = None.
kernel_amp: float
Detector scattering kernel amplitude. If set to None, function will pull the value based on best fit analysis
based on the input PSF's filter. Default = None.
rotate_value: float
The rotation of the MIRI detector in degrees
using the input PSF's filter. Default = None.
"""

Expand All @@ -341,63 +295,52 @@ def apply_miri_scattering(HDUlist_or_filename=None, radius=None, kernel_amp=None
# Create a copy of the PSF
psf = copy.deepcopy(hdu_list)

# Log instrument and detector names
# Log instrument name and filter
instrument = hdu_list[0].header["INSTRUME"].upper()
filter = hdu_list[0].header["FILTER"].upper()
filt = hdu_list[0].header["FILTER"].upper()

if instrument != "MIRI":
raise ValueError("MIRI's Scattering Effect should only be applied to MIRI PSFs")

# Pull default values
miri_scattering_default = _get_default_miri(filter)
# Default kernel amplitude values from modeling in MIRI-TN-00076-ATC_Imager_PSF_Issue_4.pdf
kernel_amp_dict = {'F560W': 0.00220, 'F770W': 0.00139, 'F1000W': 0.00034, 'F1130W': 0.00007, 'F1280W': 0.00011,
'F1500W': 0.0, 'F1800W': 0.0, 'F2100W': 0.0, 'F2550W': 0.0, 'FND': 0.00087, 'F1065C': 0.00010,
'F1140C': 0.00007, 'F1550C': 0.0, 'F2300C': 0.0}

# Set values if not already set by a keyword argument
if radius is None:
radius = miri_scattering_default["radius"]
if kernel_amp is None:
kernel_amp = miri_scattering_default["kernel_amp"]
if rotate_value is None:
rotate_value = miri_scattering_default["rotate_value"]
kernel_amp = kernel_amp_dict[filt]

ext = 2

# Set over-sample value
cdp_samp = psf[ext].header["OVERSAMP"] # the over-sample value for this ext. If det, it'll = 1 so no effect
oversample = psf[ext].header["DET_SAMP"]

# Read in PSF
in_psf = psf[ext].data

# Make the kernel
kernel, amplitude, fold = _make_kernel(kernel_amp, radius, cdp_samp)

# Create scattering images by applying the kernel vertically/horizontally via transposing
x_scattered_image = _apply_kernel(kernel, in_psf, radius)
kernel[radius] = 0.0 # set this value to 0 for the 2nd application so you don't apply this value 2x to 1 point
kernel_x = _make_kernel(in_psf, kernel_amp, oversample)

in_psf_tr = in_psf.T # apply the kernel to the y-direction of the images
y_scattered_image_tr = _apply_kernel(kernel, in_psf_tr, radius) # but your output will still be 1D in x-dir
y_scattered_image = y_scattered_image_tr.T # so then make it vertical to be applied later
# Apply the kernel via convolution in both the X and Y direction to produce a 2D output
im_conv_both = _apply_kernel(in_psf, kernel_x, oversample)

# Rotate the scattering images (but keep same size) so they match the PSF
x_scattered_image_rot = rotate(x_scattered_image, rotate_value, reshape=False)
y_scattered_image_rot = rotate(y_scattered_image, rotate_value, reshape=False)
# Add this 2D scattered light output to the PSF
psf_new = in_psf + im_conv_both

# Add the vertical/horizontal scattering images to the PSF
psf_new = in_psf + x_scattered_image_rot + y_scattered_image_rot
# To ensure conservation of intensity, normalize the psf
psf_new *= in_psf.sum() / psf_new.sum()

# Apply data to correct extensions
psf[ext].data = psf_new

# Now bin down over-sampled PSF to be detector-sampled and re-write ext=3
detector_oversample = psf[ext].header["DET_SAMP"]
psf[3].data = poppy.utils.rebin_array(psf_new, rc=(detector_oversample, detector_oversample))
psf[3].data = poppy.utils.rebin_array(psf_new, rc=(oversample, oversample))

for ext in [2, 3]:

# Set new header keywords
psf[ext].header["MIR_DIST"] = ("True", "MIRI detector scattering applied")
psf[ext].header["KERN_AMP"] = (amplitude, "Kernel Amplitude used in kernel exponential")
psf[ext].header["KERNFOLD"] = (fold, "e-folding length used in kernel exponential")
psf[ext].header["KERN_RAD"] = (radius, "Radius of kernel profile (MIRI pixels)")
psf[ext].header["KERN_AMP"] = (kernel_amp, "Amplitude (A) in kernel function A*exp(-x/B)")
psf[ext].header["KERNFOLD"] = (25, "e-folding length (B) in kernel func A*exp(-x/B)")

return psf
2 changes: 2 additions & 0 deletions webbpsf/webbpsf_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,8 @@ def calc_psf(self, outfile=None, source=None, nlambda=None, monochromatic=None,

# If chosen to add distortion
if add_distortion:
if self.image_mask == "LRS slit" and self.pupil_mask == "P750L LRS grating":
raise NotImplementedError("Distortion is not implemented yet for MIRI LRS mode.")

# Set up new extensions to add distortion to
for ext in [0, 1]:
Expand Down

0 comments on commit 1b9e1f5

Please sign in to comment.