Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "Temporal" option to pca fullfr #583

Merged
merged 8 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/test_pipeline_adi.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def algo_nmf_annular(ds):
def algo_pca(ds):
return vip.psfsub.pca(ds.cube, ds.angles, svd_mode='arpack')

def algo_pca_left_eigv(ds):
return vip.psfsub.pca(ds.cube, ds.angles, left_eigv=True)

def algo_pca_linalg(ds):
return vip.psfsub.pca(ds.cube, ds.angles, svd_mode='eigen')

Expand Down Expand Up @@ -163,6 +166,7 @@ def verify_expcoord(vectory, vectorx, exp_yx):
(algo_frdiff, snrmap_fast),
(algo_frdiff4, snrmap_fast),
(algo_pca, snrmap_fast),
(algo_pca_left_eigv, snrmap_fast),
(algo_pca_linalg, snrmap_fast),
(algo_pca_drot, snrmap_fast),
(algo_pca_cevr, snrmap_fast),
Expand Down
8 changes: 8 additions & 0 deletions tests/test_pipeline_sdi.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,17 @@ def algo_pca_single(ds, sc):
return vip.psfsub.pca(ds.cube, ds.angles, scale_list=sc,
adimsdi='single', ncomp=10)

def algo_pca_single_left_eigv(ds, sc):
return vip.psfsub.pca(ds.cube, ds.angles, scale_list=sc,
adimsdi='single', ncomp=10, left_eigv=True)

def algo_pca_double(ds, sc):
return vip.psfsub.pca(ds.cube, ds.angles, scale_list=sc,
adimsdi='double', ncomp=(1, 2))

def algo_pca_double_left_eigv(ds, sc):
return vip.psfsub.pca(ds.cube, ds.angles, scale_list=sc,
adimsdi='double', ncomp=(1, 2), left_eigv=True)

def algo_pca_annular(ds, sc):
return vip.psfsub.pca_annular(ds.cube, ds.angles, scale_list=sc,
Expand Down Expand Up @@ -156,7 +162,9 @@ def verify_expcoord(vectory, vectorx, exp_yx):
(algo_xloci, snrmap_fast),
(algo_xloci_double, snrmap_fast),
(algo_pca_single, snrmap_fast),
(algo_pca_single_left_eigv, snrmap_fast),
(algo_pca_double, snrmap_fast),
(algo_pca_double_left_eigv, snrmap_fast),
(algo_pca_annular, None),
],
ids=lambda x: (x.__name__.replace("algo_", "") if callable(x) else x))
Expand Down
66 changes: 54 additions & 12 deletions vip_hci/psfsub/pca_fullfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def pca(
verbose=True,
weights=None,
conv=False,
left_eigv=False,
cube_sig=None,
**rot_options
):
Expand Down Expand Up @@ -312,6 +313,9 @@ def pca(
weights: 1d numpy array or list, optional
Weights to be applied for a weighted mean. Need to be provided if
collapse mode is 'wmean'.
left_eigv : bool, optional
Whether to use rather left or right singularvectors
This mode is not compatible with 'mask_rdi' and 'batch'
cube_sig: numpy ndarray, opt
Cube with estimate of significant authentic signals. If provided, this
will subtracted before projecting cube onto reference cube.
Expand All @@ -320,7 +324,7 @@ def pca(
"edge_blend", "interp_zeros", "ker" (see documentation of
``vip_hci.preproc.frame_rotate``)

Returns
Return
-------
frame : numpy ndarray
2D array, median combination of the de-rotated/re-scaled residuals cube.
Expand Down Expand Up @@ -371,6 +375,12 @@ def pca(
"with the full path on disk"
)

if left_eigv :
if (batch is not None or mask_rdi is not None or cube_ref is not None):
raise NotImplementedError( "left_eigv is not compatible"
"with 'mask_rdi' nor 'batch'"
)

# checking memory (if in-memory numpy array is provided)
if not isinstance(cube, str):
input_bytes = cube_ref.nbytes if cube_ref is not None else cube.nbytes
Expand Down Expand Up @@ -416,6 +426,7 @@ def pca(
conv,
mask_rdi,
cube_sig,
left_eigv,
**rot_options
)
residuals_cube_channels, residuals_cube_channels_, frame = res_pca
Expand Down Expand Up @@ -443,6 +454,7 @@ def pca(
batch,
full_output=True,
weights=weights,
left_eigv=left_eigv,
**rot_options
)
if isinstance(ncomp, (int, float)):
Expand Down Expand Up @@ -534,6 +546,7 @@ def pca(
True,
weights,
cube_sig,
left_eigv,
**rot_options
)
if batch is None:
Expand Down Expand Up @@ -630,6 +643,7 @@ def pca(
True,
weights,
cube_sig,
left_eigv,
**rot_options
)

Expand Down Expand Up @@ -746,6 +760,7 @@ def _adi_pca(
full_output,
weights=None,
cube_sig=None,
left_eigv=False,
**rot_options
):
"""Handle the ADI PCA post-processing."""
Expand Down Expand Up @@ -803,14 +818,15 @@ def _adi_pca(
verbose,
full_output,
cube_sig=cube_sig,
left_eigv=left_eigv,
)
if verbose:
timing(start_time)
if full_output:
residuals_cube = residuals_result[0]
reconstructed = residuals_result[1]
V = residuals_result[2]
pcs = reshape_matrix(V, y, x)
pcs = reshape_matrix(V, y, x) if not left_eigv else V.T
recon = reshape_matrix(reconstructed, y, x)
else:
residuals_cube = residuals_result
Expand Down Expand Up @@ -857,6 +873,7 @@ def _adi_pca(
ind,
frame,
cube_sig=cube_sig,
left_eigv=left_eigv
)
if full_output:
nfrslib.append(res_result[0])
Expand Down Expand Up @@ -944,6 +961,7 @@ def _adimsdi_singlepca(
batch,
full_output,
weights=None,
left_eigv=False,
**rot_options
):
"""Handle the full-frame ADI+mSDI single PCA post-processing."""
Expand Down Expand Up @@ -1004,8 +1022,8 @@ def _adimsdi_singlepca(
# When ncomp is a int/float and batch is None, standard ADI-PCA is run
else:
res_cube = _project_subtract(
big_cube, None, ncomp, scaling, mask_center_px, svd_mode, verbose, False
)
big_cube, None, ncomp, scaling, mask_center_px, svd_mode,
verbose, False, left_eigv=left_eigv)

if verbose:
timing(start_time)
Expand Down Expand Up @@ -1112,6 +1130,7 @@ def _adimsdi_doublepca(
conv=False,
mask_rdi=None,
cube_sig=None,
left_eigv=False,
**rot_options
):
"""Handle the full-frame ADI+mSDI double PCA post-processing."""
Expand Down Expand Up @@ -1170,6 +1189,7 @@ def _adimsdi_doublepca(
fwhm,
conv,
mask_rdi,
left_eigv,
)
residuals_cube_channels = np.array(res)

Expand Down Expand Up @@ -1212,6 +1232,7 @@ def _adimsdi_doublepca(
verbose=False,
full_output=False,
cube_sig=cube_sig,
left_eigv=left_eigv,
)
if verbose:
print("De-rotating and combining residuals")
Expand Down Expand Up @@ -1244,6 +1265,8 @@ def _adimsdi_doublepca_ifs(
fwhm,
conv,
mask_rdi=None,
left_eigv=False,

):
"""Call by _adimsdi_doublepca with pool_map."""
global ARRAY
Expand Down Expand Up @@ -1280,6 +1303,7 @@ def _adimsdi_doublepca_ifs(
svd_mode,
verbose=False,
full_output=False,
left_eigv=left_eigv,
)
else:
residuals = np.zeros_like(cube_resc)
Expand Down Expand Up @@ -1402,6 +1426,7 @@ def _project_subtract(
indices=None,
frame=None,
cube_sig=None,
left_eigv=False,
):
"""
PCA projection and model PSF subtraction.
Expand All @@ -1426,6 +1451,8 @@ def _project_subtract(
Verbosity.
full_output : bool
Whether to return intermediate arrays or not.
left_eigv : bool, optional
Whether to use rather left or right singularvectors
indices : list
Indices to be used to discard frames (a rotation threshold is used).
frame : int
Expand All @@ -1446,8 +1473,9 @@ def _project_subtract(
reconstructed : numpy ndarray
[full_output=True] The reconstructed array.
V : numpy ndarray
[full_output=True, indices is None, frame is None] The right singular
vectors of the input matrix, as returned by ``svd/svd_wrapper()``
[full_output=True, indices is None, frame is None]
The right singular vectors of the input matrix, as returned by
``svd/svd_wrapper()``
"""
_, y, x = cube.shape
if isinstance(ncomp, (int, np.int_)):
Expand Down Expand Up @@ -1487,22 +1515,36 @@ def _project_subtract(
)
curr_frame = matrix[frame] # current frame
curr_frame_emp = matrix_emp[frame]
V = svd_wrapper(ref_lib, svd_mode, ncomp, False)
transformed = np.dot(curr_frame_emp, V.T)
reconstructed = np.dot(transformed.T, V)
if left_eigv :
V = svd_wrapper(ref_lib, svd_mode, ncomp, False, left_eigv=left_eigv)
transformed = np.dot(curr_frame_emp.T, V)
reconstructed = np.dot(V, transformed.T)
else :
V = svd_wrapper(ref_lib, svd_mode, ncomp, False)
transformed = np.dot(curr_frame_emp, V.T)
reconstructed = np.dot(transformed.T, V)

residuals = curr_frame - reconstructed

if full_output:
return ref_lib.shape[0], residuals, reconstructed
else:
return ref_lib.shape[0], residuals

# the whole matrix is processed at once
else:
V = svd_wrapper(ref_lib, svd_mode, ncomp, verbose)
transformed = np.dot(V, matrix_emp.T)
reconstructed = np.dot(transformed.T, V)
if left_eigv :
V = svd_wrapper(ref_lib, svd_mode, ncomp, verbose, left_eigv=left_eigv)
transformed = np.dot(matrix_emp.T, V)
reconstructed = np.dot(V, transformed.T)
else :
V = svd_wrapper(ref_lib, svd_mode, ncomp, verbose)
transformed = np.dot(V, matrix_emp.T)
reconstructed = np.dot(transformed.T, V)

residuals = matrix - reconstructed
residuals_res = reshape_matrix(residuals, y, x)

if full_output:
return residuals_res, reconstructed, V
else:
Expand Down
13 changes: 12 additions & 1 deletion vip_hci/psfsub/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def cevr_to_ncomp(self, cevr=0.9):


def svd_wrapper(matrix, mode, ncomp, verbose, full_output=False,
random_state=None, to_numpy=True):
random_state=None, to_numpy=True, left_eigv=False):
""" Wrapper for different SVD libraries (CPU and GPU).

Parameters
Expand Down Expand Up @@ -392,8 +392,12 @@ def svd_wrapper(matrix, mode, ncomp, verbose, full_output=False,
If None, the random number generator is the RandomState instance used
by np.random. Used for ``randsvd`` mode.
to_numpy : bool, optional
Whether to return intermediate arrays or not.
If True (by default) the arrays computed in GPU are transferred from
VRAM and converted to numpy ndarrays.
left_eigv : bool, optional
Whether to use rather left or right singularvectors


Returns
-------
Expand Down Expand Up @@ -581,6 +585,13 @@ def svd_wrapper(matrix, mode, ncomp, verbose, full_output=False,
return S, V
else:
return U, S, V
elif left_eigv:
VChristiaens marked this conversation as resolved.
Show resolved Hide resolved
if mode == 'lapack':
return V.T
elif mode == 'pytorch':
return V
else:
return U
else:
if mode == 'lapack':
return U.T
Expand Down