Skip to content

Commit

Permalink
feat: phase correlation
Browse files Browse the repository at this point in the history
  • Loading branch information
kgabor committed May 17, 2024
2 parents 89a8933 + 596d066 commit 46912d0
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 3 deletions.
47 changes: 47 additions & 0 deletions src/aind_exaspim_pipeline_utils/imagej_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,34 @@ class ImagejMacros:
# strings can be added " " + " "
# within strings, [ ] can be used for spaces, for file names? TBC

MACRO_PHASE_CORRELATION = """
run("Memory & Threads...", "parallel={parallel:d}");
run("Calculate pairwise shifts ...",
" select={process_xml}" +
" process_angle=[All angles] process_channel=[All channels]" +
" process_illumination=[All illuminations] process_tile=[All tiles] process_timepoint=[All Timepoints]" +
" method=[Phase Correlation] show_expert_grouping_options" +
" how_to_treat_timepoints=group how_to_treat_channels=group how_to_treat_illuminations=group" +
" how_to_treat_angles=group how_to_treat_tiles=compare" +
" channels=[Average Channels]" +
" downsample_in_x={downsample} downsample_in_y={downsample} downsample_in_z={downsample}");
run("Filter pairwise shifts ...",
"select={process_xml} filter_by_link_quality min_r={min_correlation}" +
" max_r=1 max_shift_in_x={max_shift_in_x} max_shift_in_y={max_shift_in_y}" +
" max_shift_in_z={max_shift_in_z}");
// do global optimization
run("Optimize globally and apply shifts ...",
"select={process_xml} process_angle=[All angles] process_channel=[All channels] " +
"process_illumination=[All illuminations] process_tile=[All tiles]" +
" process_timepoint=[All Timepoints]" +
" relative=2.500 absolute=3.500 global_optimization_strategy=" +
"[Two-Round using Metadata to align unconnected Tiles] fix_group_0-0,");
eval("script", "System.exit(0);");
"""

MACRO_IP_DET = """
run("Memory & Threads...", "parallel={parallel:d}");
run("Detect Interest Points for Registration",
Expand Down Expand Up @@ -157,3 +185,22 @@ def get_macro_ip_reg(P: Dict[str, Any]) -> str:
"select_reference_views"
] = " select_reference_views=[ViewSetupId:{:d} Timepoint:0]".format(P["map_back_reference_view"])
return ImagejMacros.MACRO_IP_REG.format(**fparams)

@staticmethod
def get_macro_phase_correlation(P: Dict[str, Any]) -> str:
""" Get a parameter formatted phase correlation macro.
Parameters
----------
P : `dict`
Parameter dictionary for macro formatting.
Note: Will already have
process_xml: path to xml to process
downsample: pyramid level to use
min_correlation: minimum correlation to consider
max_shift_in_x: maximum shift in x
max_shift_in_y: maximum shift in y
max_shift_in_z: maximum shift in z
"""
fparams = dict(P)
return ImagejMacros.MACRO_PHASE_CORRELATION.format(**fparams)
60 changes: 59 additions & 1 deletion src/aind_exaspim_pipeline_utils/imagej_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,37 @@
# from aind_data_schema import DataProcess
# from aind_data_schema.processing import ProcessName


from . import __version__
from .imagej_macros import ImagejMacros
from .exaspim_manifest import get_capsule_manifest, write_process_metadata, ExaspimProcessingPipeline
from .qc import bigstitcher_log_edge_analysis
from .qc.create_ng_link import create_ng_link


class PhaseCorrelationSchema(argschema.ArgSchema): # pragma: no cover

"""Adjustable parameters for phase correlation."""

downsample = fld.Int(
required=True,
metadata={"description": "Downsampling factor. Use the one that is available in the dataset."},
)
min_correlation = fld.Float(
load_default=0.6, metadata={"description": "Minimum correlation value for phase correlation."},
)
max_shift_in_x = fld.Int(
load_default=0, metadata={"description": "Maximum displacement in x direction."},
)
max_shift_in_y = fld.Int(
load_default=0, metadata={"description": "Maximum displacement in y direction."},
)
max_shift_in_z = fld.Int(
load_default=0, metadata={"description": "Maximum displacement in z direction."},
)


class IPDetectionSchema(argschema.ArgSchema): # pragma: no cover

"""Adjustable parameters to detect IP."""

downsample = fld.Int(
Expand Down Expand Up @@ -152,6 +174,9 @@ class ImageJWrapperSchema(argschema.ArgSchema): # pragma: no cover
validate=mm.validate.Range(min=1, max=128),
)
dataset_xml = fld.String(required=True, metadata={"description": "Input xml dataset definition"})
phase_correlation_params = fld.Nested(
PhaseCorrelationSchema, required=False, metadata={"description": "Phase correlation parameters"}
)
do_detection = fld.Boolean(required=True, metadata={"description": "Do interest point detection?"})
ip_detection_params = fld.Nested(
IPDetectionSchema, required=False, metadata={"description": "Interest point detection parameters"}
Expand All @@ -166,6 +191,10 @@ class ImageJWrapperSchema(argschema.ArgSchema): # pragma: no cover
metadata={"description": "Registration parameters (do_registrations==True only)"},
many=True,
)
do_phase_correlation = fld.Boolean(
required=False,
metadata={"description": "Do phase correlation for affine only?"},
)


def print_output(data, f, stderr=False) -> None: # pragma: no cover
Expand Down Expand Up @@ -279,12 +308,15 @@ def get_auto_parameters(args: Dict) -> Dict: # pragma: no cover

process_xml = "../results/bigstitcher.xml"
macro_ip_det = "../results/macro_ip_det.ijm"
macro_phase_corr = "../results/macro_phase_corr.ijm"

return {
"process_xml": process_xml,
# Do not use, this is the whole VM at the moment, not what is available for the capsule
"auto_ncpu": ncpu,
"auto_memgb": mem_GB,
"macro_ip_det": macro_ip_det,
"macro_phase_corr": macro_phase_corr,
}


Expand All @@ -307,6 +339,32 @@ def main(): # pragma: no cover
logger.info("Copying input xml %s -> %s", args["dataset_xml"], args["process_xml"])
shutil.copy(args["dataset_xml"], args["process_xml"])

if args['do_phase_correlation']:
# logger.info("Creating macro for phase correlation", args["do_phase_correlation"])
det_params = dict(args["phase_correlation_params"])
det_params["process_xml"] = args["process_xml"]
det_params["parallel"] = args["parallel"]

# write phase correlation macro
with open(args["macro_phase_corr"], "w") as f:
f.write(ImagejMacros.get_macro_phase_correlation(det_params))
# run phase correlation
r = wrapper_cmd_run(
[
"ImageJ",
"-Dimagej.updater.disableAutocheck=true",
"--headless",
"--memory",
"{memgb}G".format(**args),
"--console",
"--run",
args["macro_phase_corr"]
],
logger,
)
if r != 0:
raise RuntimeError("Phase Correlation command failed.")

if args["do_detection"]:
det_params = dict(args["ip_detection_params"])
det_params["parallel"] = args["parallel"]
Expand Down
35 changes: 33 additions & 2 deletions tests/test_imagej_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Tests for ImageJ wrapper functions and macro creator class"""
import argschema
import contextlib
import io
import logging
import unittest
from unittest import mock

import argschema

from aind_exaspim_pipeline_utils.imagej_macros import ImagejMacros
from aind_exaspim_pipeline_utils.imagej_wrapper import (
ImageJWrapperSchema,
Expand Down Expand Up @@ -92,12 +91,34 @@ def setUp(self):
}
],
}
example_params_phase_correlation_default = {
"session_id": "2023-02-22",
"memgb": 55,
"parallel": 8,
"dataset_xml": "test_dataset.xml",
"do_phase_correlation": True,
"do_detection": False,
"do_registrations": False,
"phase_correlation_params": {
"downsample": 2,
"min_correlation": 0.6,
"max_shift_in_x": 10,
"max_shift_in_y": 10,
"max_shift_in_z": 10,
},
}
parser = argschema.ArgSchemaParser(
schema_type=ImageJWrapperSchema, input_data=example_params_default, args=[]
)
self.args = parser.args

phase_parser = argschema.ArgSchemaParser(
schema_type=ImageJWrapperSchema, input_data=example_params_phase_correlation_default, args=[]
)
self.phase_args = phase_parser.args

def testMacroIPDet(self):

"""Test IP Detection macro"""

det_params = dict(self.args["ip_detection_params"])
Expand Down Expand Up @@ -129,3 +150,13 @@ def testMacroIPReg(self):
reg_params["map_back_views_choice"] = "selected_translation"
m = ImagejMacros.get_macro_ip_reg(reg_params)
self.assertRegex(m, "ViewSetupId:5")

def testMacroPhaseCorrelation(self):
"""Test Phase Correlation macro"""
phase_params = dict(self.phase_args["phase_correlation_params"])
phase_params["process_xml"] = self.phase_args["dataset_xml"]
phase_params["parallel"] = self.args["parallel"]

m = ImagejMacros.get_macro_phase_correlation(phase_params)
self.assertRegex(m, "select=test_dataset.xml")
self.assertNotRegex(m, "viewsetupid_")

0 comments on commit 46912d0

Please sign in to comment.