[ENH] Initial cupy implementation to leverage GPU #1

Open · wants to merge 16 commits into master

32 changes: 26 additions & 6 deletions examples/example_experimental_data.py
@@ -9,14 +9,34 @@
The available denoising methods are "nordic", "mp-pca", "hybrid-pca", "opt-fro", "opt-nuc" and "opt-op".
"""

from patch_denoise.simulation.phantom import mr_shepp_logan_t2_star, g_factor_map
from patch_denoise.simulation.activations import add_frames
from patch_denoise.simulation.noise import add_temporal_gaussian_noise
import nibabel as nib
from patch_denoise.space_time.lowrank import OptimalSVDDenoiser
import timeit

# %%
# Setup the parameters for the simulation and noise

SHAPE = (64, 64, 64)
N_FRAMES = 200
# SHAPE = (64, 64, 64)
# N_FRAMES = 200

NOISE_LEVEL = 2
# NOISE_LEVEL = 2

base_path = "/data/parietal/store2/data/ibc/"
#input_path = base_path + "3mm/sub-01/ses-00/func/wrdcsub-01_ses-00_task-ArchiSocial_dir-ap_bold.nii.gz"
input_path = base_path + "sourcedata/sub-01/ses-00/func/sub-01_ses-00_task-ArchiSocial_dir-ap_bold.nii.gz"
output_path = "/scratch/ymzayek/retreat_data/output.nii"

img = nib.load(input_path)

print(f"Data shape is {img.shape} with affine \n{img.affine}")

patch_shape = (11, 11, 11)
patch_overlap = 5

# initialize denoiser
optimal_llr = OptimalSVDDenoiser(patch_shape, patch_overlap)

# denoise image
time_start = timeit.default_timer()
denoised = optimal_llr.denoise(img.get_fdata(), engine="gpu", batch_size=100)
print(timeit.default_timer() - time_start)
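
# Optionally write the result back to disk. A minimal sketch, assuming the
# first element of the returned tuple is the denoised volume (denoise also
# returns patch weights, a noise estimate and a rank map) and reusing the
# output_path defined above.
denoised_img = nib.Nifti1Image(denoised[0], img.affine, img.header)
nib.save(denoised_img, output_path)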
115 changes: 113 additions & 2 deletions src/patch_denoise/space_time/base.py
@@ -3,10 +3,11 @@
import logging
import numpy as np
from tqdm.auto import tqdm
import cupy as cp

from .._docs import fill_doc

from .utils import get_patch_locs
from .utils import get_patch_locs, get_patches_gpu


@fill_doc
@@ -33,7 +34,13 @@ def __init__(self, patch_shape, patch_overlap, recombination="weighted"):
self.input_denoising_kwargs = dict()

@fill_doc
def denoise(self, input_data, mask=None, mask_threshold=50, progbar=None):
def denoise(
self,
input_data,
mask=None,
mask_threshold=50,
progbar=None,
):
"""Denoise the input_data, according to mask.

Patches are extracted sequentially and process by the implemented
@@ -129,8 +136,111 @@ def denoise(self, input_data, mask=None, mask_threshold=50, progbar=None):
noise_std_estimate[patch_slice] += noise_var
# the top left corner of the patch is used as id for the patch.
rank_map[patch_center_img] = maxidx
if progbar:
progbar.update()

# Averaging the overlapping pixels.
# this is only required for averaging recombinations.
if self.recombination in ["average", "weighted"]:
output_data /= patchs_weight[..., None]
noise_std_estimate /= patchs_weight

output_data[~process_mask] = 0

return output_data, patchs_weight, noise_std_estimate, rank_map

def denoise_gpu(
self,
input_data,
mask=None,
mask_threshold=50,
progbar=None,
batch_size=None,
):
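        """Denoise ``input_data`` on the GPU, according to ``mask``.

        GPU counterpart of :meth:`denoise`: all patches are extracted at once
        with ``get_patches_gpu`` and processed in batches of ``batch_size``
        (all at once if ``batch_size`` is None) by ``_patch_processing_gpu``.
        Returns the same ``(output_data, patchs_weight, noise_std_estimate,
        rank_map)`` tuple as the CPU path.
        """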
data_shape = input_data.shape
output_data = np.zeros_like(input_data)
rank_map = np.zeros(data_shape[:-1], dtype=np.int32)
# Create Default mask
if mask is None:
process_mask = np.full(data_shape[:-1], True)
else:
process_mask = np.copy(mask)

patch_shape, patch_overlap = self.__get_patch_param(data_shape)
patch_size = np.prod(patch_shape)

if self.recombination == "center":
patch_center = (
*(slice(ps // 2, ps // 2 + 1) for ps in patch_shape),
slice(None, None, None),
)
patchs_weight = np.zeros(data_shape[:-1], np.float32)
noise_std_estimate = np.zeros(data_shape[:-1], dtype=np.float32)

# discard useless patches
patch_locs = get_patch_locs(patch_shape, patch_overlap, data_shape[:-1])
get_it = np.zeros(len(patch_locs), dtype=bool)

patch_slices = []
for i, patch_tl in enumerate(patch_locs):
patch_slice = tuple(
slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape)
)
if 100 * np.sum(process_mask[patch_slice]) / patch_size > mask_threshold:
get_it[i] = True
patch_slices.append(patch_slice)

logging.info(f"Denoise {100 * np.sum(get_it) / len(patch_locs):.2f}% patches")
patch_locs = np.ascontiguousarray(patch_locs[get_it])

if progbar is None:
progbar = tqdm(total=len(patch_locs))
elif progbar is not False:
progbar.reset(total=len(patch_locs))

patches = get_patches_gpu(input_data, patch_shape, patch_overlap)
        # Replace NaNs with the mean of the finite values.
        patches[np.isnan(patches)] = np.nanmean(patches)

patches_denoise, patches_maxidx, noise_var = self._patch_processing_gpu(
patches,
patch_slices=patch_slices,
batch_size=batch_size,
**self.input_denoising_kwargs,
)
patches_denoise = cp.asnumpy(patches_denoise)
patches_maxidx = cp.asnumpy(patches_maxidx)
        for patch_tl, p_denoise, maxidx in zip(
            patch_locs, patches_denoise, patches_maxidx
        ):
patch_slice = tuple(
slice(tl, tl + ps) for tl, ps in zip(patch_tl, patch_shape)
)
process_mask[patch_slice] = 1
p_denoise = np.reshape(p_denoise, (*patch_shape, -1))
patch_center_img = tuple(
ptl + ps // 2 for ptl, ps in zip(patch_tl, patch_shape)
)
if self.recombination == "center":
output_data[patch_center_img] = p_denoise[patch_center]
noise_std_estimate[patch_center_img] += noise_var
elif self.recombination == "weighted":
theta = 1 / (2 + maxidx)
output_data[patch_slice] += p_denoise * theta
patchs_weight[patch_slice] += theta
elif self.recombination == "average":
output_data[patch_slice] += p_denoise
patchs_weight[patch_slice] += 1
else:
raise ValueError(
"recombination must be one of 'weighted', 'average', "
"'center'."
)
if not np.isnan(noise_var):
noise_std_estimate[patch_slice] += noise_var
# the top left corner of the patch is used as id for the patch.
rank_map[patch_center_img] = maxidx
if progbar:
progbar.update()

# Averaging the overlapping pixels.
# this is only required for averaging recombinations.
if self.recombination in ["average", "weighted"]:
@@ -140,6 +250,7 @@ def denoise(self, input_data, mask=None, mask_threshold=50, progbar=None):
output_data[~process_mask] = 0

return output_data, patchs_weight, noise_std_estimate, rank_map


@abc.abstractmethod
def _patch_processing(self, patch, patch_slice=None, **kwargs):
71 changes: 70 additions & 1 deletion src/patch_denoise/space_time/lowrank.py
@@ -4,13 +4,15 @@
import numpy as np
from scipy.linalg import svd
from scipy.optimize import minimize
import cupy as cp

from .base import BaseSpaceTimeDenoiser
from .utils import (
eig_analysis,
eig_synthesis,
marshenko_pastur_median,
svd_analysis,
svd_analysis_gpu,
svd_synthesis,
)
from .._docs import fill_doc
@@ -320,6 +322,8 @@ def denoise(
noise_std=None,
eps_marshenko_pastur=1e-7,
progbar=None,
engine="cpu",
batch_size=None,
):
"""
        Optimal thresholding denoising method.
@@ -364,7 +368,20 @@
else:
self.input_denoising_kwargs["var_apriori"] = noise_std**2

return super().denoise(input_data, mask, mask_threshold, progbar=progbar)
if engine == "cpu":
return super().denoise(
input_data, mask, mask_threshold, progbar=progbar,
)
elif engine == "gpu":
return super().denoise_gpu(
input_data,
mask,
mask_threshold,
progbar=progbar,
batch_size=batch_size,
)
else:
raise ValueError(f"Unknown engine: {engine}. Use 'cpu' or 'gpu'.")

def _patch_processing(
self,
@@ -396,6 +413,58 @@ def _patch_processing(

return p_new, maxidx, np.NaN

def _patch_processing_gpu(
self,
patches,
patch_slices=None,
shrink_func=None,
mp_median=None,
var_apriori=None,
batch_size=None,
):
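        """Process all patches at once on the GPU.

        ``patches`` is expected in shape (patches, patch size, time), as
        produced by ``get_patches_gpu``. Singular values are thresholded with
        ``shrink_func`` and each patch is rebuilt from its retained
        components, mirroring :meth:`_patch_processing`.
        """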
if batch_size is None:
batch_size = patches.shape[0]
u_vec, s_values, v_vec, p_tmean = svd_analysis_gpu(
patches, batch_size=batch_size
)
        if var_apriori is not None:
            # One noise level per patch, mirroring the CPU path.
            sigma = cp.asarray(
                [np.mean(np.sqrt(var_apriori[ps])) for ps in patch_slices]
            )
else:
sigma = cp.median(
s_values, axis=1
) / cp.sqrt(patches.shape[-1] * mp_median)

scale_factor = (cp.sqrt(patches.shape[-1]) * sigma)[..., None]
thresh_s_values = scale_factor * shrink_func(
s_values / scale_factor,
beta=patches.shape[-1] / patches.shape[-2],
)
thresh_s_values[cp.isnan(thresh_s_values)] = 0

# Check all batches to see if they have any values above 0
check_any = cp.any(thresh_s_values, axis=1)
indices_true = cp.nonzero(check_any)[0]
indices_false = cp.nonzero(~check_any)[0]
        maxidx = cp.zeros(thresh_s_values.shape[0], dtype=cp.int64)
p_new = cp.zeros(patches.shape)

        if len(indices_true) > 0:
            for i in indices_true.tolist():
                # Index of the largest retained singular value (+1 so it can
                # be used directly as a slice bound).
                k = int(cp.max(cp.nonzero(thresh_s_values[i])[0])) + 1
                maxidx[i] = k
                p_new[i] = (
                    u_vec[i, :, :k]
                    @ (thresh_s_values[i, :k, None] * v_vec[i, :k, :])
                ) + p_tmean[i, :]
        if len(indices_false) > 0:
            for i in indices_false.tolist():
                maxidx[i] = 0
                p_new[i] = cp.zeros_like(patches[i]) + p_tmean[i, :]

        # Mirror the CPU path's return signature (patches, ranks, noise var).
        return p_new, maxidx, np.nan


def _sure_atn_cost(X, method, sing_vals, gamma, sigma=None, tau=None):
"""
73 changes: 72 additions & 1 deletion src/patch_denoise/space_time/utils.py
@@ -2,6 +2,7 @@
import numpy as np
from scipy.integrate import quad
from scipy.linalg import eigh, svd
import cupy as cp


def svd_analysis(input_data):
@@ -19,14 +20,47 @@ def svd_analysis(input_data):
-------
u_vec, s_vals, v_vec, mean
"""
# TODO benchmark svd vs svds and order of data.
mean = np.mean(input_data, axis=0)
data_centered = input_data - mean
# TODO benchmark svd vs svds and order of data.
u_vec, s_vals, v_vec = svd(data_centered, full_matrices=False)

return u_vec, s_vals, v_vec, mean


def svd_analysis_gpu(input_data, batch_size):
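    """Batched GPU version of :func:`svd_analysis`.

    ``input_data`` has shape (patches, patch size, time); the SVD of each
    centered patch matrix is computed on the GPU in batches of at most
    ``batch_size`` patches.

    Returns
    -------
    u_vec, s_vals, v_vec, mean
        Batched counterparts of the outputs of :func:`svd_analysis`.
    """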
total_samples = input_data.shape[0]
    num_batches = int(np.ceil(total_samples / batch_size))
    # Spread the patches evenly over the batches; the last batch simply
    # stops at total_samples.
    adjusted_batch_size = int(np.ceil(total_samples / num_batches))

# Initialize arrays to store the results
# input_data shape is (total patches, patch size, time)
m = input_data.shape[1]
n = input_data.shape[2]
U_batched = cp.empty((total_samples, m, n), dtype=cp.float64)
S_batched = cp.empty((total_samples, min(m, n)), dtype=cp.float64)
V_batched = cp.empty((total_samples, n, n), dtype=cp.float64)
mean_batched = cp.empty((total_samples, n), dtype=cp.float64)

# Compute SVD for each matrix in the batch
    for i in range(num_batches):
        start_idx = i * adjusted_batch_size
        end_idx = min(start_idx + adjusted_batch_size, total_samples)
        idx = slice(start_idx, end_idx)
        # Move the batch to the GPU before computing statistics.
        batch = cp.asarray(input_data[idx])
        mean = cp.mean(batch, axis=1, keepdims=True)
        data_centered = batch - mean
        u_vec, s_vals, v_vec = cp.linalg.svd(
            data_centered, full_matrices=False
        )
        U_batched[idx] = u_vec
        S_batched[idx] = s_vals
        V_batched[idx] = v_vec
        mean_batched[idx] = mean[:, 0, :]
return U_batched, S_batched, V_batched, mean_batched


def svd_synthesis(u_vec, s_vals, v_vec, mean, idx):
"""
Reconstruct ``X = (U @ (S * V)) + M`` with only the max_idx greatest component.
@@ -197,6 +231,43 @@ def get_patch_locs(p_shape, p_ovl, v_shape):
return patch_locs.reshape(-1, len(p_shape))


def get_patches_gpu(input_data, patch_shape, patch_overlap):
"""Extract all the patches from a volume.

Returns
-------
numpy.ndarray
All the patches in shape (patches, patch size, time).
"""
patch_size = np.prod(patch_shape)

    # Move the data to the GPU.
    input_data = cp.asarray(input_data)

c, h, w, t_s = input_data.shape
kc, kh, kw = patch_shape # kernel size
    # Stride between patches (assumes isotropic patch shape and overlap).
    sc, sh, sw = np.repeat(
        patch_shape[0] - patch_overlap[0], len(patch_shape)
    )
    # Pad so the sliding window tiles the padded volume exactly.
    needed_c = int((np.ceil((c - kc) / sc + 1) - ((c - kc) / sc + 1)) * sc)
    needed_h = int((np.ceil((h - kh) / sh + 1) - ((h - kh) / sh + 1)) * sh)
    needed_w = int((np.ceil((w - kw) / sw + 1) - ((w - kw) / sw + 1)) * sw)

    input_data_padded = cp.pad(
        input_data,
        ((0, needed_c), (0, needed_h), (0, needed_w), (0, 0)),
        mode="edge",
    )

step = patch_shape[0] - patch_overlap[0]
patches = cp.lib.stride_tricks.sliding_window_view(
input_data_padded, patch_shape, axis=(0, 1, 2)
)[::step, ::step, ::step]

patches = patches.transpose((0, 1, 2, 4, 5, 6, 3))
patches = patches.reshape((np.prod(patches.shape[:3]), patch_size, t_s))

return cp.asnumpy(patches)


def estimate_noise(noise_sequence, block_size=1):
"""Estimate the temporal noise standard deviation of a noise only sequence."""
volume_shape = noise_sequence.shape[:-1]