Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 243 additions & 1 deletion image_stitcher/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Annotated, Any, ClassVar, Literal, NamedTuple, Optional, Union
import numpy as np
import pandas as pd
import tifffile
from dask_image.imread import imread as dask_imread
from pydantic import AfterValidator, BaseModel, Field, computed_field, ConfigDict

Expand Down Expand Up @@ -242,6 +243,8 @@ class AcquisitionMetadata:
"""The index of the current field of view."""
t: int
"""The current timepoint."""
frame_idx: int = 0
"""The frame index within a multi-page TIFF file (0 for single-page files)."""

@property
def key(self) -> MetaKey:
Expand Down Expand Up @@ -299,7 +302,14 @@ def __post_init__(self) -> None:
self.init_timepoints()
self.init_acquisition_parameters()
self.init_pixel_size()
self.parse_acquisition_metadata()

# Choose parsing method based on detected format
if self.is_multipage_tiff_format():
logging.info("Detected multi-page TIFF format, using parse_multipage_tiff")
self.parse_multipage_tiff()
else:
logging.info("Detected individual file format, using parse_acquisition_metadata")
self.parse_acquisition_metadata()

def init_timepoints(self) -> None:
self.timepoints = [
Expand All @@ -309,6 +319,26 @@ def init_timepoints(self) -> None:
]
self.timepoints.sort()

def is_multipage_tiff_format(self) -> bool:
"""Detect if the acquisition uses multi-page TIFF format."""
if not self.timepoints:
return False

# Check the first timepoint for multi-page TIFF files
first_timepoint = self.timepoints[0]
image_folder = pathlib.Path(self.parent.input_folder) / str(first_timepoint)

if not image_folder.exists():
return False

# Look for files with "_stack" in the name
stack_files = [
f for f in image_folder.iterdir()
if f.suffix.lower() in (".tiff", ".tif") and "_stack" in f.name
]

return len(stack_files) > 0

def init_acquisition_parameters(self) -> None:
acquistion_params_path = os.path.join(
self.parent.input_folder, "acquisition parameters.json"
Expand Down Expand Up @@ -484,6 +514,218 @@ def parse_acquisition_metadata(self) -> None:
logging.info(f"{len(self.regions)} Regions: {self.regions}")
logging.info(f"Number of FOVs per region: {self.num_fovs_per_region}")

def parse_multipage_tiff(self) -> None:
"""Parse multi-page TIFF files - SAME LOGIC AS parse_acquisition_metadata."""
input_path = pathlib.Path(self.parent.input_folder)
self.acquisition_metadata = {}

max_z = 0
max_fov = 0

# Iterate over each timepoint - SAME AS ORIGINAL
for timepoint in self.timepoints:
image_folder = input_path / str(timepoint)
coordinates_path = image_folder / "coordinates.csv"

logging.info(f"Processing timepoint {timepoint}, image folder: {image_folder}")

try:
coordinates_df = pd.read_csv(coordinates_path)
except FileNotFoundError:
logging.warning(f"coordinates.csv not found for timepoint {timepoint}")
continue

# Find all multi-page TIFF files (replaces individual image files)
tiff_files = sorted([
f.resolve() for f in image_folder.iterdir()
if f.suffix.lower() in (".tiff", ".tif") and "_stack" in f.name
])

# Process each multi-page TIFF file
for tiff_file in tiff_files:
# Parse filename to get region and fov - SAME PARSING LOGIC
filename_parts = tiff_file.stem.split("_")
if len(filename_parts) >= 3 and filename_parts[-1] == "stack":
region = "_".join(filename_parts[:-2]) # Handle multi-word regions
fov = int(filename_parts[-2])

try:
with tifffile.TiffFile(tiff_file) as tif:
# Get all coordinate entries for this region/fov combination
fov_coords = coordinates_df[
(coordinates_df['region'] == region) &
(coordinates_df['fov'] == fov)
].sort_values('z_level')

if fov_coords.empty:
logging.warning(f"No coordinates for {tiff_file}")
continue

# Determine channels and z-levels from TIFF structure and coordinates
unique_z_levels = sorted(fov_coords['z_level'].unique())
num_z_levels = len(unique_z_levels)
num_pages = len(tif.pages)

# Infer number of channels
if num_z_levels > 0:
num_channels = num_pages // num_z_levels
else:
num_channels = 1

# Process each page in the TIFF - REPLACES FILE ITERATION
for page_idx, page in enumerate(tif.pages):
# Determine z_level and channel from page structure
if num_channels == 1:
# Single channel case
if page_idx < len(unique_z_levels):
z_level = unique_z_levels[page_idx]
else:
z_level = page_idx
channel = "BF" # Default channel name
else:
# Multiple channels case - cycle through z-levels for each channel
z_level = unique_z_levels[page_idx % num_z_levels]
channel_idx = page_idx // num_z_levels

# Try to extract channel name from TIFF metadata
try:
if hasattr(page, 'tags') and 'ImageDescription' in page.tags:
description = page.tags['ImageDescription'].value
if isinstance(description, str) and 'channel' in description.lower():
import re
channel_match = re.search(r'channel["\']?\s*:\s*["\']?([^"\',:}]+)', description, re.IGNORECASE)
if channel_match:
channel = channel_match.group(1).strip()
else:
channel = f"channel_{channel_idx}"
else:
channel = f"channel_{channel_idx}"
else:
channel = f"channel_{channel_idx}"
except:
channel = f"channel_{channel_idx}"

# Apply same channel name processing as original
channel = channel.replace("_", " ").replace("full ", "full_")

# Find coordinates - SAME LOGIC AS ORIGINAL
coord_rows = coordinates_df[
(coordinates_df["region"] == region)
& (coordinates_df["fov"] == fov)
& (coordinates_df["z_level"] == z_level)
]

if coord_rows.empty:
logging.warning(f"No coordinates for {tiff_file}, page {page_idx}")
continue

coord_row = coord_rows.iloc[0]

# Create metadata object - SAME AS ORIGINAL
meta = AcquisitionMetadata(
filepath=tiff_file, # Points to the multi-page TIFF
x=coord_row["x (mm)"],
y=coord_row["y (mm)"],
z=coord_row["z (um)"],
channel=channel,
z_level=z_level,
region=region,
fov_idx=fov,
t=timepoint,
frame_idx=page_idx
)

self.acquisition_metadata[meta.key] = meta

# Add region and channel names to the sets - SAME AS ORIGINAL
self.regions.append(region)
self.channel_names.append(channel)

# Update max_z and max_fov values - SAME AS ORIGINAL
max_z = max(max_z, z_level)
max_fov = max(max_fov, fov)

except Exception as e:
logging.warning(f"Error reading {tiff_file}: {e}")
continue

# After processing all timepoints, finalize the list of regions and channels - SAME AS ORIGINAL
self.regions = sorted(set(self.regions))
self.channel_names = sorted(set(self.channel_names))

# Calculate number of timepoints (t), Z levels, and FOVs per region - SAME AS ORIGINAL
self.num_t = len(self.timepoints)
self.num_z = max_z + 1
self.num_fovs_per_region = max_fov + 1

if not self.acquisition_metadata:
logging.warning("No acquisition metadata found")
return

# Set up image parameters based on the first image - SAME AS ORIGINAL
first_meta = list(self.acquisition_metadata.values())[0]

# Read the first page of the first TIFF to get image dimensions
try:
with tifffile.TiffFile(first_meta.filepath) as tif:
first_image = tif.pages[first_meta.frame_idx].asarray()
except:
logging.warning("Could not read first image for dimensions")
return

self.dtype = first_image.dtype
if len(first_image.shape) == 2:
self.input_height, self.input_width = first_image.shape
elif len(first_image.shape) == 3:
self.input_height, self.input_width = first_image.shape[:2]
else:
raise ValueError(f"Unexpected image shape: {first_image.shape}")

# Set up chunks - SAME AS ORIGINAL
self.chunks = (
1,
1,
1,
min(self.input_height, self.CHUNK_SIZE_LIMIT_PX),
min(self.input_width, self.CHUNK_SIZE_LIMIT_PX),
)

# Set up final monochrome channels - SAME AS ORIGINAL
self.monochrome_channels = []
for channel in self.channel_names:
(t, region, fov, z_level, _) = first_meta.key
channel_key = MetaKey(t, region, fov, z_level, channel) # Use MetaKey instead of tuple
if channel_key in self.acquisition_metadata:
# Read the channel image to check if it's RGB
try:
with tifffile.TiffFile(self.acquisition_metadata[channel_key].filepath) as tif:
channel_image = tif.pages[self.acquisition_metadata[channel_key].frame_idx].asarray()
except:
channel_image = first_image # Fallback

if len(channel_image.shape) == 3 and channel_image.shape[2] == 3:
channel = channel.split("_")[0]
self.monochrome_channels.extend(
[f"{channel}_R", f"{channel}_G", f"{channel}_B"]
)
else:
self.monochrome_channels.append(channel)
else:
self.monochrome_channels.append(channel)

self.num_c = len(self.monochrome_channels)
self.monochrome_colors = [
self.get_channel_color(name) for name in self.monochrome_channels
]

# Print out information about the dataset - SAME AS ORIGINAL
logging.info(f"Regions: {self.regions}, Channels: {self.channel_names}")
logging.info(f"FOV dimensions: {self.input_height}x{self.input_width}")
logging.info(f"{self.num_z} Z levels, {self.num_t} Time points")
logging.info(f"{self.num_c} Channels: {self.monochrome_channels}")
logging.info(f"{len(self.regions)} Regions: {self.regions}")
logging.info(f"Number of FOVs per region: {self.num_fovs_per_region}")

@staticmethod
def get_channel_color(channel_name: str) -> int:
"""Compute the color for display of a given channel name."""
Expand Down
15 changes: 13 additions & 2 deletions image_stitcher/stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ def compute_mip(tiles: list[np.ndarray]) -> np.ndarray:
else:
raise ValueError(f"Unexpected tile shape: {tiles[0].shape}")

def load_image(self, tile_info) -> np.ndarray:
"""Load an image from file, handling both single files and multi-page TIFF files."""
if hasattr(tile_info, 'frame_idx') and tile_info.frame_idx > 0:
# Multi-page TIFF file
import tifffile
with tifffile.TiffFile(tile_info.filepath) as tif:
return tif.pages[tile_info.frame_idx].asarray()
else:
# Single file
return skimage.io.imread(tile_info.filepath)

def create_output_array(
self, timepoint: int, region: str, num_z_layers: int
) -> AnyArray:
Expand Down Expand Up @@ -329,7 +340,7 @@ def stitch_region(self, timepoint: int, region: str) -> AnyArray:
# The order of tiles does not matter for compute_mip
tiles = []
for z_level, tile_info in z_tiles:
tile = skimage.io.imread(tile_info.filepath)
tile = self.load_image(tile_info)
tiles.append(tile)

# Compute MIP
Expand All @@ -353,7 +364,7 @@ def stitch_region(self, timepoint: int, region: str) -> AnyArray:
# Original processing logic for non-MIP case
for key, tile_info in self.metadata.items():
t, _, _, z_level, channel = key
tile = skimage.io.imread(tile_info.filepath)
tile = self.load_image(tile_info)

x_pixel = int(
(tile_info.x - x_min) * 1000 / self.computed_parameters.pixel_size_um
Expand Down
66 changes: 57 additions & 9 deletions image_stitcher/stitcher_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,20 +341,68 @@ def onInputDirectoryDropped(self, path: str) -> None:
return

self.inputDirectory = str(acquisition_dir)
# Visually mark the drop area
self.inputDirDropArea.setText(f"Loaded: {acquisition_dir.name}")
self.inputDirDropArea.setStyleSheet("""
QLabel {border: 2px solid green; border-radius: 5px; background-color: #e0ffe0;}
""")
self.probeDatasetForZLayers()

# Detect and display the acquisition format
try:
temp_params = StitchingParameters(
input_folder=self.inputDirectory,
output_format=OutputFormat.ome_zarr,
scan_pattern=ScanPattern.unidirectional,
)
temp_stitcher = Stitcher(temp_params)

# Check if it's multi-page TIFF format
if temp_stitcher.computed_parameters.is_multipage_tiff_format():
format_info = " (Multi-page TIFF)"
else:
format_info = " (Individual files)"

# Visually mark the drop area
self.inputDirDropArea.setText(f"Loaded: {acquisition_dir.name}{format_info}")
self.inputDirDropArea.setStyleSheet("""
QLabel {border: 2px solid green; border-radius: 5px; background-color: #e0ffe0;}
""")
self.probeDatasetForZLayers()

except Exception as e:
logging.warning(f"Could not detect acquisition format: {e}")
# Fallback to original behavior
self.inputDirDropArea.setText(f"Loaded: {acquisition_dir.name}")
self.inputDirDropArea.setStyleSheet("""
QLabel {border: 2px solid green; border-radius: 5px; background-color: #e0ffe0;}
""")
self.probeDatasetForZLayers()

def selectInputDirectory(self) -> None: # Kept for now, can be removed if button is fully replaced
dir = QFileDialog.getExistingDirectory(self, "Select Input Image Folder")
if dir:
self.inputDirectory = dir
self.inputDirDropArea.setText(f"Loaded: {pathlib.Path(dir).name}")
self.inputDirDropArea.setStyleSheet("""QLabel {border: 2px solid green; border-radius: 5px; background-color: #e0ffe0;}""")
self.probeDatasetForZLayers()

# Detect and display the acquisition format
try:
temp_params = StitchingParameters(
input_folder=self.inputDirectory,
output_format=OutputFormat.ome_zarr,
scan_pattern=ScanPattern.unidirectional,
)
temp_stitcher = Stitcher(temp_params)

# Check if it's multi-page TIFF format
if temp_stitcher.computed_parameters.is_multipage_tiff_format():
format_info = " (Multi-page TIFF)"
else:
format_info = " (Individual files)"

self.inputDirDropArea.setText(f"Loaded: {pathlib.Path(dir).name}{format_info}")
self.inputDirDropArea.setStyleSheet("""QLabel {border: 2px solid green; border-radius: 5px; background-color: #e0ffe0;}""")
self.probeDatasetForZLayers()

except Exception as e:
logging.warning(f"Could not detect acquisition format: {e}")
# Fallback to original behavior
self.inputDirDropArea.setText(f"Loaded: {pathlib.Path(dir).name}")
self.inputDirDropArea.setStyleSheet("""QLabel {border: 2px solid green; border-radius: 5px; background-color: #e0ffe0;}""")
self.probeDatasetForZLayers()

def probeDatasetForZLayers(self) -> None:
if not self.inputDirectory:
Expand Down