
Commit

Add spark processing entries to the manifest. Add java postprocessing script.

 * Add SparkInterestPointDetections to the processing manifest.
 * Add solver and GeometricDescriptorMatching config sections to manifest generation.
 * Fix some of the JSON defaults, in particular the non-null default for globalOptType.
 * Add translation and affine matching and solver configs to the manifest.
 * Add java output postprocessing and automatic uploading to S3.
kgabor committed Apr 24, 2024
1 parent b66759b commit dfa406e
Showing 4 changed files with 293 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -62,6 +62,7 @@ zarr_multiscale_converter = "aind_exaspim_pipeline_utils.n5tozarr.n5tozarr_da:za
create_example_manifest = "aind_exaspim_pipeline_utils.exaspim_manifest:create_example_manifest"
bigstitcher_log_edge_analysis = "aind_exaspim_pipeline_utils.qc.bigstitcher_log_edge_analysis:main"
run_trigger_capsule = "aind_exaspim_pipeline_utils.trigger.capsule:capsule_main"
java_detreg_postprocess = "aind_exaspim_pipeline_utils.java_utils:java_detreg_postprocess_main"

[tool.setuptools.dynamic]
version = {attr = "aind_exaspim_pipeline_utils.__version__"}
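The new console script maps directly onto a module function, so the same postprocessing can be driven from Python as well (a minimal sketch, assuming the package is installed in the capsule environment):

    from aind_exaspim_pipeline_utils.java_utils import java_detreg_postprocess_main

    # Equivalent to invoking the java_detreg_postprocess console script.
    java_detreg_postprocess_main()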
187 changes: 185 additions & 2 deletions src/aind_exaspim_pipeline_utils/exaspim_manifest.py
@@ -4,7 +4,7 @@
import json
import os
from datetime import datetime
from typing import Optional, Tuple, List
from typing import Optional, Tuple, List, Union

from aind_data_schema import DataProcess
from aind_data_schema.base import AindModel
@@ -221,6 +221,177 @@ def validate_regularize_with_choice(cls, v: str) -> str:
)


class SparkInterestPointDetections(AindModel): # pragma: no cover
"""Interest point detection parameters"""

label: str = Field("beads", title="Label of the interest points")
sigma: float = Field(4.0, title="sigma for segmentation, e.g. 1.8")
threshold: float = Field(0.0015, title="threshold for segmentation, e.g. 0.008")
overlappingOnly: bool = Field(True, title="Find overlapping interest points only")
storeIntensities: bool = Field(True, title="Store intensities")
prefetch: bool = Field(True, title="Prefetch")
minIntensity: int = Field(0, title="Minimal intensity value")
maxIntensity: int = Field(2000, title="Maximal intensity value")
dsxy: int = Field(4, title="Downsampling factor for x and y")
dsz: int = Field(4, title="Downsampling factor for z")
blockSizeString: str = Field("1024,1024,1024", title="Block size string")
type: str = Field("MAX", title="the type of interest points to find, MIN, MAX or BOTH (default: MAX)")
localization: str = Field(
"QUADRATIC", title="Subpixel localization method, NONE or QUADRATIC (default: QUADRATIC)"
)


class SparkGeometricDescriptorMatching(AindModel): # pragma: no cover
"""Geometric descriptor matching parameters"""

label: str = Field("beads", title="Label")
registrationMethod: str = Field(
"PRECISE_TRANSLATION",
title="the matching method; FAST_ROTATION, FAST_TRANSLATION, PRECISE_TRANSLATION or ICP",
)
significance: float = Field(
3.0,
title="how much better the first match between two descriptors has "
"to be compareed to the second best one (default: 3.0)",
)
redundancy: int = Field(1, title="the redundancy of the local descriptor (default: 1)")
numNeighbors: int = Field(
3,
title="the number of neighboring points used to build the local descriptor,"
" only supported by PRECISE_TRANSLATION (default: 3)",
)
clearCorrespondences: bool = Field(False, title="Clear Correspondences")
interestpointsForReg: str = Field(
"OVERLAPPING_ONLY",
title="which interest points to use for pairwise registrations, "
"use OVERLAPPING_ONLY or ALL points (default: ALL)",
)
viewReg: str = Field(
"OVERLAPPING_ONLY",
title="which views to register with each other, compare OVERLAPPING_ONLY "
"or ALL_AGAINST_ALL (default: OVERLAPPING_ONLY)",
)
interestPointMergeDistance: float = Field(5.0, title="Interest Point Merge Distance")
groupIllums: bool = Field(False, title="Group Illuminations")
groupChannels: bool = Field(False, title="Group Channels")
groupTiles: bool = Field(False, title="Group Tiles")
splitTimepoints: bool = Field(False, title="Split Timepoints")
ransacIterations: Optional[int] = Field(
None, title="max number of ransac iterations (default: 10,000 for descriptors, 200 for ICP)"
)
ransacMaxError: Optional[float] = Field(
None, title="ransac max error in pixels (default: 5.0 for descriptors, 2.5 for ICP)"
)
ransacMinInlierRatio: float = Field(0.1, title="RANSAC Minimum Inlier Ratio")
ransacMinInlierFactor: float = Field(
3.0,
title="ransac min inlier factor, i.e. how many time the minimal number of matches need to found, "
"e.g. affine needs 4 matches, 3x means at least 12 matches required (default: 3.0)",
)
icpMaxError: float = Field(5.0, title="ICP max error in pixels (default: 5.0)")
icpIterations: int = Field(200, title="max number of ICP iterations (default: 200)")
icpUseRANSAC: bool = Field(False, title="ICP Use RANSAC")

registrationTP: str = Field(
"TIMEPOINTS_INDIVIDUALLY",
title="time series registration type; TIMEPOINTS_INDIVIDUALLY (i.e. no registration across time), "
"TO_REFERENCE_TIMEPOINT, ALL_TO_ALL or ALL_TO_ALL_WITH_RANGE "
"(default: TIMEPOINTS_INDIVIDUALLY)",
)
referenceTP: Optional[int] = Field(
None, title="the reference timepoint if timepointAlign == REFERENCE (default: first timepoint)"
)
rangeTP: int = Field(
5, title="the range of timepoints if timepointAlign == ALL_TO_ALL_RANGE (default: 5)"
)
transformationModel: str = Field(
"AFFINE", title="which transformation model to use; TRANSLATION, RIGID or AFFINE (default: AFFINE)"
)
regularizationModel: str = Field(
"RIGID",
title="which regularization model to use; NONE, IDENTITY, "
"TRANSLATION, RIGID or AFFINE (default: RIGID)",
)
regularizationLambda: float = Field(0.1, title="lambda to use for the regularization model (default: 0.1)")


class Solver(AindModel): # pragma: no cover
"""Solver parameters"""

sourcePoints: str = Field(
"IP", title="which source to use for the solve, IP (interest points) or STITCHING"
)
groupIllums: Optional[bool] = Field(
None,
title="group all illumination directions that belong to the same angle/channel/tile/timepoint "
"together as one view, e.g. to stitch illums as one "
"(default: false for IP, true for stitching)",
)
groupChannels: Optional[bool] = Field(
None,
title="group all channels that belong to the same angle/illumination/tile/timepoint together "
"as one view, e.g. to stitch channels as one (default: false for IP, true for stitching)",
)
groupTiles: Optional[bool] = Field(
None,
title="group all tiles that belong to the same angle/channel/illumination/timepoint together "
"as one view, e.g. to align across angles (default: false)",
)
splitTimepoints: Optional[bool] = Field(
None,
title="group all angles/channels/illums/tiles that belong to the same timepoint as one View, "
"e.g. for stabilization across time (default: false)",
)
label: Optional[str] = Field(
"beads", title="label of the interest points used for solve if using interest points (e.g. beads)"
)
globalOptType: str = Field(
"TWO_ROUND_ITERATIVE",
title="global optimization method; ONE_ROUND_SIMPLE, ONE_ROUND_ITERATIVE, TWO_ROUND_SIMPLE or "
"TWO_ROUND_ITERATIVE. Two round handles unconnected tiles, iterative handles wrong links "
"(default: ONE_ROUND_SIMPLE)",
)
relativeThreshold: float = Field(
0.0,
title="relative error threshold for iterative solvers, how many times worse than the average error "
"a link needs to be (default: 3.5)",
)
absoluteThreshold: float = Field(
0.0, title="absoluted error threshold for iterative solver to drop a link in pixels (default: 7.0)"
)
maxError: float = Field(5.0, title="max error for the solve (default: 5.0)")
maxIterations: int = Field(10000, title="max number of iterations for solve (default: 10,000)")
maxPlateauwidth: int = Field(200, title="max plateau width for solve (default: 200)")
disableFixedViews: bool = Field(False, title="disable fixing of views (see --fixedViews)")
fixedViews: Optional[List[str]] = Field(
["0,7"],
title="define a list of (or a single) fixed view ids (time point, view setup), e.g. -fv '0,0' "
"-fv '0,1' (default: first view id)",
)

registrationTP: str = Field(
"TIMEPOINTS_INDIVIDUALLY",
title="time series registration type; TIMEPOINTS_INDIVIDUALLY (i.e. no registration across time), "
"TO_REFERENCE_TIMEPOINT, ALL_TO_ALL or ALL_TO_ALL_WITH_RANGE "
"(default: TIMEPOINTS_INDIVIDUALLY)",
)
referenceTP: Optional[int] = Field(
None, title="the reference timepoint if timepointAlign == REFERENCE (default: first timepoint)"
)
rangeTP: int = Field(
5, title="the range of timepoints if timepointAlign == ALL_TO_ALL_RANGE (default: 5)"
)
transformationModel: str = Field(
"AFFINE", title="which transformation model to use; TRANSLATION, RIGID or AFFINE (default: AFFINE)"
)
regularizationModel: str = Field(
"RIGID",
title="which regularization model to use; NONE, IDENTITY, "
"TRANSLATION, RIGID or AFFINE (default: RIGID)",
)
regularizationLambda: float = Field(0.1, title="lambda to use for the regularization model (default: 0.1)")


class XMLCreationParameters(AindModel): # pragma: no cover
"""XML converter capsule parameters."""

@@ -229,7 +400,7 @@ class XMLCreationParameters(AindModel): # pragma: no cover
input_uri: Optional[str] = Field(
None,
title="Input Zarr group dataset path. This is the dataset the alignment is running on."
"Must be the aind-open-data s3:// path without the SPIM.ome.zarr suffix",
"Must be the aind-open-data s3:// path without the SPIM.ome.zarr suffix",
)


@@ -257,6 +428,17 @@ class ExaspimProcessingPipeline(AindModel): # pragma: no cover
"the directory containing all data and metadata",
title="Name",
)
spark_ip_detections: Union[
None, SparkInterestPointDetections, List[SparkInterestPointDetections]
] = Field(None, title="Spark interest point detection")
spark_geometric_descriptor_matching_tr: Union[None, SparkGeometricDescriptorMatching] = Field(
None, title="Spark geometric descriptor matching"
)
solver_tr: Union[None, Solver] = Field(None, title="Solver parameters")
spark_geometric_descriptor_matching_aff: Union[None, SparkGeometricDescriptorMatching] = Field(
None, title="Spark geometric descriptor matching"
)
solver_aff: Union[None, Solver] = Field(None, title="Solver parameters")

xml_creation: XMLCreationParameters = Field(None, title="XML creation")
ip_detection: IPDetectionParameters = Field(None, title="Interest point detection")
@@ -330,6 +512,7 @@ def get_capsule_metadata() -> dict: # pragma: no cover
# return Metadata.parse_obj(json_data)
return json_data


# TODO: We do not yet use an accumulative metadata file with multiple data_process entries.
# def append_process_entries_to_metadata(
# dataset_metadata: Metadata, processes: Iterable[DataProcess]
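The new sections are plain AindModel (pydantic) configuration classes, so they can be constructed and inspected directly. A minimal sketch, sticking to keyword construction and attribute access since serialization helper names depend on the pydantic version behind AindModel:

    from aind_exaspim_pipeline_utils.exaspim_manifest import (
        SparkInterestPointDetections,
        Solver,
    )

    # Field defaults mirror the definitions above.
    ipd = SparkInterestPointDetections(sigma=1.8, threshold=0.008)
    assert ipd.type == "MAX" and ipd.dsxy == 4

    # globalOptType now carries a non-null default, per the commit message.
    solver = Solver()
    assert solver.globalOptType == "TWO_ROUND_ITERATIVE"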
82 changes: 82 additions & 0 deletions src/aind_exaspim_pipeline_utils/java_utils.py
@@ -0,0 +1,82 @@
"""Java IpdetRegSolv capsule postprocessing actions"""

import logging
import os

from .exaspim_manifest import get_capsule_manifest
from .imagej_wrapper import (
create_edge_connectivity_report,
upload_alignment_results,
fmt_uri,
get_auto_parameters,
)
from .qc.create_ng_link import create_ng_link


def java_detreg_postprocess_main(): # pragma: no cover
"""Entry point for java capsule postprocessing."""

logging.basicConfig(format="%(asctime)s %(levelname)-7s %(message)s")

logger = logging.getLogger()

pipeline_manifest = get_capsule_manifest()

args = {
"dataset_xml": "../data/manifest/dataset.xml",
"session_id": pipeline_manifest.pipeline_suffix,
"log_level": logging.DEBUG,
"name": pipeline_manifest.name,
"subject_id": pipeline_manifest.subject_id,
}
# "input_uri" and "output_uri" are formatted to have trailing slashes
if pipeline_manifest.ip_registrations:
args["output_uri"] = fmt_uri(pipeline_manifest.ip_registrations[-1].IJwrap.output_uri)
args["input_uri"] = fmt_uri(pipeline_manifest.ip_registrations[-1].IJwrap.input_uri)
else:
args["output_uri"] = fmt_uri(pipeline_manifest.ip_detection.IJwrap.output_uri)
args["input_uri"] = fmt_uri(pipeline_manifest.ip_detection.IJwrap.input_uri)

logger.setLevel(logging.DEBUG)
logging.getLogger("botocore").setLevel(logging.INFO)
logging.getLogger("urllib3").setLevel(logging.INFO)
logging.getLogger("s3fs").setLevel(logging.INFO)
logger.info(f"This is result postprocessing for session {args['session_id']}")
args.update(get_auto_parameters(args))

# process_meta = get_imagej_wrapper_metadata(
# {
# "ip_detection": pipeline_manifest.ip_detection,
# "ip_registrations": pipeline_manifest.ip_registrations,
# },
# input_location=args["input_uri"],
# output_location=args["output_uri"],
# )
# write_process_metadata(process_meta, prefix="ipreg")

# Create ng links for all the registrations
nglinks = []
for i in range(3):
suffix = f"~{i}" if i > 0 else ""
xml_path = args["process_xml"] + suffix
if os.path.exists(xml_path):
logger.info("Creating ng link for registration %d (xml order)", i)
thelink = create_ng_link(
"{}SPIM.ome.zarr".format(args["input_uri"]),
args["output_uri"].rstrip("/"),
xml_path=xml_path,
output_json=f"../results/ng/process_output_{i}.json",
)
if thelink:
nglinks.append(thelink)
else:
logger.warning("Registration %d xml file %s does not exist. Skipping.", i, xml_path)
# if process_meta.outputs is None:
# process_meta.outputs = {}
# if nglinks:
# process_meta.outputs["ng_links"] = nglinks
logger.info("Creating edge connectivity report")
create_edge_connectivity_report(2)

logger.info("Uploading capsule results to {}".format(args["output_uri"]))
upload_alignment_results(args)
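The registration-link loop above probes the process XML plus its ~1 and ~2 backup variants before creating neuroglancer links. The same pattern, restated as a standalone helper (illustrative only; existing_xml_variants is not part of this commit):

    import os
    from typing import Iterator, Tuple

    def existing_xml_variants(base_path: str, max_variants: int = 3) -> Iterator[Tuple[int, str]]:
        """Yield (index, path) for base_path and those ~N backup variants that exist on disk."""
        for i in range(max_variants):
            path = base_path + (f"~{i}" if i > 0 else "")
            if os.path.exists(path):
                yield i, path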
25 changes: 25 additions & 0 deletions src/aind_exaspim_pipeline_utils/trigger/capsule.py
@@ -28,6 +28,9 @@
ExaspimProcessingPipeline,
N5toZarrParameters,
ZarrMultiscaleParameters,
SparkInterestPointDetections,
SparkGeometricDescriptorMatching,
Solver,
)

logging.basicConfig(format="%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M")
@@ -458,6 +461,7 @@ def create_exaspim_manifest(args, metadata): # pragma: no cover
ip_limitation_choice="brightest",
maximum_number_of_detections=150000,
)

ip_reg_translation: IPRegistrationParameters = IPRegistrationParameters(
# dataset_xml=capsule_xml_path,
IJwrap=def_ij_wrapper_parameters,
@@ -506,6 +510,27 @@ def create_exaspim_manifest(args, metadata): # pragma: no cover
name=metadata["data_description"].get("name"),
xml_creation=xml_creation,
ip_detection=def_ip_detection_parameters,
spark_ip_detections=SparkInterestPointDetections(overlappingOnly=True),
spark_geometric_descriptor_matching_tr=SparkGeometricDescriptorMatching(
clearCorrespondences=True,
transformationModel="TRANSLATION",
regularizationModel="NONE",
),
solver_tr=Solver(
transformationModel="TRANSLATION",
regularizationModel="NONE",
fixedViews=["0,7"],
),
spark_geometric_descriptor_matching_aff=SparkGeometricDescriptorMatching(
clearCorrespondences=True,
transformationModel="AFFINE",
regularizationModel="RIGID",
),
solver_aff=Solver(
transformationModel="AFFINE",
regularizationModel="RIGID",
fixedViews=["0,7"],
),
ip_registrations=[ip_reg_translation, ip_reg_affine],
n5_to_zarr=n5_to_zarr,
zarr_multiscale=zarr_multiscale,
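Taken together, the trigger capsule now seeds a two-round alignment: a translation-only matching and solver pass followed by an affine pass regularized towards rigid. The same configuration objects can be built in isolation (a sketch with values copied from the defaults above; per the Solver field docs, the "0,7" fixed view means timepoint 0, view setup 7):

    from aind_exaspim_pipeline_utils.exaspim_manifest import (
        SparkGeometricDescriptorMatching,
        Solver,
    )

    # Round 1: translation-only matching and solve, no regularization.
    gdm_tr = SparkGeometricDescriptorMatching(
        clearCorrespondences=True,
        transformationModel="TRANSLATION",
        regularizationModel="NONE",
    )
    solver_tr = Solver(
        transformationModel="TRANSLATION",
        regularizationModel="NONE",
        fixedViews=["0,7"],
    )

    # Round 2: affine matching and solve, regularized towards RIGID.
    gdm_aff = SparkGeometricDescriptorMatching(
        clearCorrespondences=True,
        transformationModel="AFFINE",
        regularizationModel="RIGID",
    )
    solver_aff = Solver(
        transformationModel="AFFINE",
        regularizationModel="RIGID",
        fixedViews=["0,7"],
    )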
