Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion image_stitcher/flatfield_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,71 @@
from .parameters import StitchingComputedParameters


def save_flatfield_correction(
flatfields: dict[int, np.ndarray],
computed_params: StitchingComputedParameters,
output_dir: Path,
) -> Path:
"""
Save computed flatfield correction data to the specified directory.

Creates a flatfield_manifest.json file and saves individual .npy files
for each channel in the same format expected by load_flatfield_correction.

Args:
flatfields: Dictionary mapping channel-index to flatfield numpy arrays.
computed_params: Computed stitching parameters containing channel information.
output_dir: Directory where flatfield files should be saved.

Returns:
Path to the created manifest file.
"""
# Ensure output directory exists
output_dir.mkdir(parents=True, exist_ok=True)

# Create the files dictionary for the manifest
files_dict = {}

for channel_idx, flatfield_array in flatfields.items():
# Get the channel name from the computed parameters
if channel_idx < len(computed_params.monochrome_channels):
channel_name = computed_params.monochrome_channels[channel_idx]
else:
logging.warning(f"Channel index {channel_idx} is out of range for monochrome_channels")
continue

# Create filename for this channel
npy_filename = f"{channel_name}_flatfield.npy"
npy_path = output_dir / npy_filename

# Save the flatfield array
try:
np.save(npy_path, flatfield_array)
logging.info(f"Saved flatfield for channel '{channel_name}' to {npy_path}")

# Add to manifest using channel name as key
files_dict[channel_name] = npy_filename

except Exception as e:
logging.error(f"Failed to save flatfield for channel '{channel_name}': {e}")
continue

# Create and save the manifest file
manifest_content = {"files": files_dict}
manifest_path = output_dir / "flatfield_manifest.json"

try:
with open(manifest_path, "w") as f:
json.dump(manifest_content, f, indent=2)
logging.info(f"Created flatfield manifest at {manifest_path}")

return manifest_path

except Exception as e:
logging.error(f"Failed to create flatfield manifest: {e}")
raise


def load_flatfield_correction(
manifest_filepath: Path, computed_params: StitchingComputedParameters
) -> dict[int, np.ndarray]:
Expand Down Expand Up @@ -89,4 +154,4 @@ def load_flatfield_correction(
logging.info(f"Loaded flatfields from {manifest_filepath!r} for channel-indices: {list(flatfields.keys())}")
else:
logging.warning(f"No flatfields were successfully loaded from manifest {manifest_filepath!r}.")
return flatfields
return flatfields
21 changes: 20 additions & 1 deletion image_stitcher/stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,25 @@ def run(self) -> None:
self.callbacks.getting_flatfields,
)
)

# Save the computed flatfields to the acquisition folder
if self.computed_parameters.flatfields:
from .flatfield_utils import save_flatfield_correction

acquisition_folder = pathlib.Path(self.params.input_folder)
flatfield_dir = acquisition_folder / "flatfields"

try:
manifest_path = save_flatfield_correction(
self.computed_parameters.flatfields,
self.computed_parameters,
flatfield_dir,
)
logging.info(f"Saved computed flatfields to {flatfield_dir}")
logging.info(f"Flatfield manifest created at {manifest_path}")
except Exception as e:
logging.error(f"Failed to save computed flatfields: {e}")
# Continue processing even if saving fails

# Validate loaded/computed flatfields
if self.computed_parameters.flatfields:
Expand Down Expand Up @@ -739,4 +758,4 @@ def stitch_all_regions_and_timepoints(self) -> None:
# Loop over timepoints and regions
for t_idx, t in enumerate(self.computed_parameters.timepoints):
for r_idx, region in enumerate(self.computed_parameters.regions):
self.stitch_region(t, region)
self.stitch_region(t, region)