Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/open_dive/scripts/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def main():
"Diffusion glyph options (tensors and ODFs)"
)
window_group = parser.add_argument_group("Window options")
roi_mask_group = parser.add_argument_group("ROI mask options")

scalar_group.add_argument(
"-n",
Expand Down Expand Up @@ -65,6 +66,12 @@ def main():
type=Path,
help="Path to binary mask to generate a glass brain.",
)
scalar_group.add_argument(
"--glass_brain_opacity",
type=float,
default=0.33,
help="Opacity of the glass brain in range (0, 1]. Default is 0.33.",
)

tractography_group.add_argument(
## plot tractogram with slices
Expand Down Expand Up @@ -101,6 +108,58 @@ def main():
action="store_true",
help="Whether to show a tractography values colorbar. Default is False.",
)
tractography_group.add_argument(
"--tractography_is_categorical_values",
action="store_true",
help="If provided, the tractography values are treated as categorical values. Default is False.",
)
tractography_group.add_argument(
"--tractography_categorical_values",
type=str,
nargs="+",
help="List of categorical values for tractography. Must match number of tractography files.",
)
tractography_group.add_argument(
"--tractography_categorical_reference_values",
type=str,
nargs="+",
help="List of reference categorical values for tractography.",
)

roi_mask_group.add_argument(
"--roi_mask_path",
type=Path,
nargs="+", # Accept one or more arguments
help="Path to binary mask(s) to plot as ROI mask(s).",
)
roi_mask_group.add_argument(
"--roi_mask_values",
type=float,
nargs="+",
help="Values to use for coloring each ROI mask (must match number of ROI mask files)",
)
roi_mask_group.add_argument(
"--roi_mask_cmap",
help='Matplotlib or cmcrameri colormap to use for ROI mask. Default is "plasma" if --roi_mask_values is provided, otherwise "Set1".',
)
roi_mask_group.add_argument(
"--roi_mask_cmap_range",
type=float,
nargs=2,
help="Range to use for the colormap. Default is (0, 1).",
)
roi_mask_group.add_argument(
"--roi_mask_opacity",
type=float,
nargs="+",
default=[0.33],
help="Value to use for the ROI mask opacity in range (0, 1]. If a list, each value corresponds to a ROI mask in --roi_mask_path. Default is 0.33.",
)
roi_mask_group.add_argument(
"--roi_mask_colorbar",
action="store_true",
help="Whether to show a ROI mask values colorbar. Default is False.",
)

glyph_group.add_argument(
"--tensor_path",
Expand Down Expand Up @@ -177,11 +236,21 @@ def main():
tractography_cmap=args.tractography_cmap,
tractography_cmap_range=args.tractography_cmap_range,
tractography_colorbar=args.tractography_colorbar,
tractography_is_categorical_values=args.tractography_is_categorical_values,
tractography_categorical_values=args.tractography_categorical_values,
tractography_categorical_reference_values=args.tractography_categorical_reference_values,
tensor_path=args.tensor_path,
odf_path=args.odf_path,
sh_basis=args.sh_basis,
scale=args.scale,
azimuth=args.azimuth,
elevation=args.elevation,
glass_brain_path=args.glass_brain,
glass_brain_opacity=args.glass_brain_opacity,
roi_mask_path=args.roi_mask_path,
roi_mask_values=args.roi_mask_values,
roi_mask_cmap=args.roi_mask_cmap,
roi_mask_cmap_range=args.roi_mask_cmap_range,
roi_mask_opacity=args.roi_mask_opacity,
roi_mask_colorbar=args.roi_mask_colorbar
)
184 changes: 174 additions & 10 deletions src/open_dive/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,22 @@ def plot_nifti(
tractography_cmap: str | None = None,
tractography_cmap_range: tuple[int, int] | None = None,
tractography_colorbar: bool = False,
tractography_is_categorical_values: bool = False,
tractography_categorical_values: list[str] | None = None,
tractography_categorical_reference_values: list[str] | None = None,
volume_idx: int | None = None,
tensor_path: os.PathLike | None = None,
odf_path: os.PathLike | None = None,
sh_basis: str = "descoteaux07",
scale: int = 1,
glass_brain_path: os.PathLike | None = None,
glass_brain_opacity: float | None = 0.33,
roi_mask_cmap: str | None = None,
roi_mask_cmap_range: tuple[int, int] | None = None,
roi_mask_values: list[float] | None = None,
roi_mask_colorbar: bool = False,
roi_mask_path: list[os.PathLike] | None = None,
roi_mask_opacity: list[float] | None = [0.33],
**kwargs,
) -> None:
"""Create a 2D rendering of a NIFTI slice.
Expand Down Expand Up @@ -87,6 +97,12 @@ def plot_nifti(
Optional range to use for the colormap
tractography_colorbar : bool, default False
Whether to show a colorbar for the tractography
tractography_is_categorical_values : bool, default False
Whether the tractography values are categorical (discrete) or continuous
tractography_categorical_values : list of str, optional
Optional categorical values to color the tractography with
tractography_categorical_reference_values : list of str, optional
Optional reference values for categorical tractography values, used to map colors
volume_idx : int, optional
Index of the volume to display if the image is 4D
tensor_path : os.PathLike, optional
Expand All @@ -99,7 +115,20 @@ def plot_nifti(
Scale of the tensor glyphs or ODF glyphs
glass_brain_path : os.PathLike, optional
Optional glass brain mask to overlay

glass_brain_opacity : float, default 0.33
Opacity of the glass brain mask
roi_mask_cmap : str, default "Set1" or "plasma"
Optional colormap to use for the ROI mask, by default "Set1" if roi_mask_values, otherwise "plasma"
roi_mask_cmap_range : tuple of float, default (0, 1) if roi_mask_values is not None
Optional range to use for the colormap
roi_mask_values : list of float, optional
Optional values to color the ROI mask with
roi_mask_colorbar : bool, default False
Whether to show a colorbar for the ROI mask
roi_mask_path : list of os.PathLike, optional
Optional ROI mask(s) to plot with slices. Can provide multiple files
roi_mask_opacity : list of float, default [0.33]
Optional opacity value for the ROI masks between (0, 1)
**kwargs
Additional keyword arguments to pass to fury.actor.slicer
"""
Expand All @@ -122,10 +151,30 @@ def plot_nifti(
if tractography_cmap is None:
tractography_cmap = "Set1" if tractography_values is None else "plasma"
if tractography_cmap_range is None:
tractography_cmap_range = (
(0, 1) if tractography_values is None else (min(tractography_values), max(tractography_values))
)
tractography_cbar_labels = tractography_values is not None
if not tractography_is_categorical_values:
tractography_cmap_range = (
(0, 1) if tractography_values is None else (min(tractography_values), max(tractography_values))
)
tractography_cbar_labels = tractography_values is not None
else:
if tractography_categorical_reference_values is not None:
tractography_unique_values = list(dict.fromkeys(tractography_categorical_reference_values))
print(tractography_categorical_reference_values)
assert all(val in tractography_unique_values for val in tractography_categorical_values), f"All categorical values must be in the reference values: reference: {tractography_categorical_reference_values} \nvalues: {tractography_categorical_values}"
else:
tractography_unique_values = np.unique(tractography_categorical_values) if tractography_categorical_values is not None else None
tractography_cmap_range = (0, len(tractography_unique_values) - 1) if tractography_unique_values is not None else (0, 1)
tractography_cbar_labels = tractography_categorical_reference_values is not None
print(tractography_unique_values)

#same for ROI mask values
if roi_mask_cmap is None:
roi_mask_cmap = "Set1" if roi_mask_values is None else "plasma"
if roi_mask_cmap_range is None:
roi_mask_cmap_range = (
(0, 1) if roi_mask_values is None else (min(roi_mask_values), max(roi_mask_values))
)
roi_mask_cbar_labels = roi_mask_values is not None

# Set up scene and bounds
scene = window.Scene()
Expand Down Expand Up @@ -177,14 +226,58 @@ def plot_nifti(
)
scene.add(scalar_bar)

#add roi masks
if roi_mask_path is not None:
cmap = plt.get_cmap(roi_mask_cmap)

# Set to range
if roi_mask_values is not None:
norm = plt.Normalize(vmin=roi_mask_cmap_range[0], vmax=roi_mask_cmap_range[1])
colors = [cmap(norm(val)) for val in roi_mask_values]
else:
colors = [cmap(i) for i in range(len(roi_mask_path))]

# Apply colorbar
if roi_mask_colorbar:
roi_bar = _create_colorbar_actor(
value_range=roi_mask_cmap_range,
colorbar_position=(0.1, 0.1),
colorbar_height=0.5,
colorbar_width=0.1,
cmap=cmap,
labels=roi_mask_cbar_labels,
)
scene.add(roi_bar)

# Add each ROI mask with its corresponding color
roi_actors = _create_roi_mask_actor(
mask_nifti=roi_mask_path,
colors=colors,
mask_opacities=roi_mask_opacity
)
for roi_actor in roi_actors:
scene.add(roi_actor)

# Add tractography
if tractography_path is not None:
cmap = plt.get_cmap(tractography_cmap)

# Set to range
if tractography_values is not None:
norm = plt.Normalize(vmin=tractography_cmap_range[0], vmax=tractography_cmap_range[1])
colors = [cmap(norm(val)) for val in tractography_values]
if not tractography_is_categorical_values:
norm = plt.Normalize(vmin=tractography_cmap_range[0], vmax=tractography_cmap_range[1])
colors = [cmap(norm(val)) for val in tractography_values]
elif tractography_categorical_values is not None:
colors_lst = plt.cm.jet(np.linspace(0, 1, len(tractography_unique_values)))
colors_idx_map = {idx:color for idx,color in enumerate(colors_lst)}
tractography_unique_values_idx_map = {val:idx for idx, val in enumerate(tractography_unique_values)}
print(tractography_unique_values)
print(tractography_unique_values_idx_map)
colors = [colors_idx_map[tractography_unique_values_idx_map[val]] for val in tractography_categorical_values]
#print the color mapping for the values
print("Color mapping being used:")
for k,v in tractography_unique_values_idx_map.items():
print(f"{k}: {colors_idx_map[v]}")
else:
colors = [cmap(i) for i in range(len(tractography_path))]

Expand Down Expand Up @@ -231,7 +324,8 @@ def plot_nifti(
odf_actor.display_extent(*extent)

if glass_brain_path:
glass_brain_actor = _create_glass_brain_actor(glass_brain_path)
print(glass_brain_opacity)
glass_brain_actor = _create_glass_brain_actor(glass_brain_path, opacity=glass_brain_opacity)
scene.add(glass_brain_actor)

if scene_bound_data is None:
Expand Down Expand Up @@ -301,11 +395,74 @@ def _create_glass_brain_actor(
# Step 4: Dilate the thresholded mask with 2 passes
mask_dilated = binary_dilation(mask_thres, iterations=dilation_iters).astype(np.uint8)

mask_final = mask_dilated - mask_thres

# Create a surface actor
glass_brain_actor = contour_from_roi(mask_dilated, affine=new_affine, opacity=opacity, color=(0.5, 0.5, 0.5))
glass_brain_actor = contour_from_roi(mask_final, affine=new_affine, opacity=opacity, color=(0.5, 0.5, 0.5))
return glass_brain_actor


def _create_roi_mask_actor(
mask_nifti: list[os.PathLike],
colors: list[tuple[float, float, float]],
mask_opacities: list[float] = [0.33],
resample_factor: int = 2,
smooth_sigma: float = 2,
dilation_iters: int = 2,
) -> Actor:
"""Create "glass ROI" visualizations from a binary masks.

Parameters
----------
mask_nifti : os.PathLike
Path to binary mask NIFTI image
resample_factor : int, default 3
Factor to upsample the mask by
smooth_sigma : float, default 2
Standard deviation for Gaussian smoothing
dilation_iters : int, default 2
Number of iterations for binary dilation
mask_opacities : list[float], default [0.33]
Opacities of the ROI masks
colors : list[tuple[float, float, float]], default [(0.5, 0.5, 0.5)]
Colors of the ROI masks

Returns
-------
glass_brain : fury.actor.surface
ROI mask surface actor
"""

roi_actors = []
roi_opacities = mask_opacities * len(mask_nifti) if len(mask_opacities) == 1 else mask_opacities
for i,(mask_file, color) in enumerate(zip(mask_nifti, colors)):
#load the mask
mask_nifti = nib.load(mask_file)
mask = mask_nifti.get_fdata()
affine = mask_nifti.affine
zooms = mask_nifti.header.get_zooms()[:3]

# Step 1: Upsample (regrid) the mask by a factor of 5
new_zooms = tuple(z / resample_factor for z in zooms)
mask_up, new_affine = reslice(mask, affine, zooms, new_zooms)

# Step 2: Apply Gaussian smoothing with standard deviation 2
mask_smooth = gaussian_filter(mask_up, sigma=smooth_sigma)

# Step 3: Threshold the smoothed mask at 0.5
mask_thres = (mask_smooth > 0.5).astype(np.uint8)

# Step 4: Dilate the thresholded mask with 2 passes
mask_dilated = binary_dilation(mask_thres, iterations=dilation_iters).astype(np.uint8)

# Create a surface actor
roi_mask_actor = contour_from_roi(mask_dilated, affine=new_affine, opacity=roi_opacities[i], color=color)
roi_actors.append(roi_mask_actor)
return roi_actors




def _create_nifti_actor(
nifti_path: os.PathLike,
volume_idx: int | None = None,
Expand Down Expand Up @@ -347,10 +504,17 @@ def _create_colorbar_actor(
colorbar_height: float = 0.5,
colorbar_width: float = 0.1,
cmap: Colormap | None = None,
labels: bool = True,
labels: list[str] | bool = True,
) -> vtk.vtkScalarBarActor:
"""Create a colorbar actor for the scene."""

# if isinstance(labels, list):
# n_labels = len(labels)
# lut = vtk.vtkLookupTable()
# lut.SetNumberOfTableValues(n_labels)
# lut.Build()
# for i, label in enumerate(labels):

# Create a grayscale colormap (from black to white)
lut = vtk.vtkLookupTable()
lut.SetNumberOfTableValues(256) # Full grayscale (256 levels)
Expand Down