Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat][METEOR-1128] TDCT Odemis Wrapper #2999

Merged
merged 1 commit into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions src/odemis/acq/align/tdct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""
Created on 8 Jan 2025

Copyright © 2025 Delmic

This file is part of Odemis.

Odemis is free software: you can redistribute it and/or modify it under the
terms of the GNU General Public License version 2 as published by the Free
Software Foundation.

Odemis is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
Odemis. If not, see http://www.gnu.org/licenses/.
"""

import os
import logging
import sys
from typing import Dict, List, Tuple, Any

import numpy
import yaml

from odemis import model

# install from: https://github.com/patrickcleeve2/3DCT/blob/refactor
sys.path.append(f"{os.path.expanduser('~')}/development/3DCT")

from tdct.correlation_v2 import run_correlation
from tdct.util import multi_channel_get_z_guass

def _convert_das_to_numpy_stack(das: List[model.DataArray]) -> numpy.ndarray:
"""Convert a list of DataArrays to a numpy stack.
Channels are stored as list dimensions, rather than data array dimensions.
Therefore, multi-channel images are stored as list[CTZYX, CTZYX, ...] where C=1
and length of list is number of channels.
:param das: list of meteor data arrays (supports 5D CTZYX, 3D ZYX, 2D YX arrays)
:return the data arrays reshapes to a 4D numpy array (CZYX)"""
arr = []
for da in das:
if isinstance(da, model.DataArrayShadow):
da = da.getData()

# convert to 3D ZYX
if da.ndim == 5:
if da.shape[0] != 1 or da.shape[1] != 1:
logging.warning(f"Only the first channel and time dimension will be used for 5D data array: {da.shape}")
# remove the channel, time dimensions
da = da[0, 0, :, :, :]
elif da.ndim == 2:
# expand to 3D ZYX
da = numpy.expand_dims(da, axis=0)

assert da.ndim == 3, f"DataArray must be 3D ZYX, but is {da.shape}"
arr.append(da)

return numpy.stack(arr, axis=0)

def get_optimized_z_gauss(das: List[model.DataArray], x: int, y: int, z: int, show: bool = False) -> float:
"""Get the best fitting z-coordinate for the given x, y coordinates. Supports multi-channel images.
:param das: the data arrays (CTZYX, ZYX, or YX), all arrays must have the same shape
:param x: the x-coordinate
:param y: the y-coordinate
:param z: the z-coordinate (initial guess)
Copy link
Contributor

@tmoerkerken tmoerkerken Feb 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this argument is only present in case of failure, in which case it will be returned without manipulation. To me it makes more sense to not make it an arg/return but just throw the error (or return None, since no optimal z was found) and handle the failure case outside of this function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand your point but disagree, almost always the handling will be to use the initial value. If it becomes an issue it can be changed later

:param show: show the plot for debugging
:return: the z-coordinate (optimized)"""
prev_z = z
prev_x, prev_y = x, y

# fm_image must be 4D np.ndarray with shape (channels, z, y, x)
fm_image = _convert_das_to_numpy_stack(das)

try:
# getzGauss can fail, so we need to catch the exception
zval, z, _ = multi_channel_get_z_guass(image=fm_image, x=x, y=y, show=show)
logging.info(f"Using Z-Gauss optimisation: {z}, previous z: {prev_z}")

except RuntimeError as e:
logging.warning(f"Error in z-gauss optimisation: {e}, using previous z: {prev_z}")
z = prev_z
x, y = prev_x, prev_y

return z

def run_tdct_correlation(fib_coords: numpy.ndarray,
fm_coords: numpy.ndarray,
poi_coords: numpy.ndarray,
fib_image: model.DataArray,
fm_image: model.DataArray,
path: str) -> Dict[str, Any]:
"""Run 3DCT Multi-point correlation between FIB and FM images.
:param fib_coords: the FIB coordinates (n, (x, y)) (in pixels, origin at top left)
:param fm_coords: the FM coordinates (n, (x, y, z)) (in pixels, origin at top left)
:param poi_coords: the point of interest coordinates (1, (x, y, z)). Expects only one point of interest.
:param fib_image: the FIB image (YX)
:param fm_image: the FM image (CTZTX, CZYX or ZYX)
:param path: the path to save the results
:return: the correlation results
output:
error:
delta_2d: reprojection error between 3D and 2D coordinates
reprojected_3d: 3D coordinates reprojected to 2D
mean_absolute_error: mean absolute error of the transformation (x, y)
rms_error: root mean square error of the transformation
poi: list of transformed point of interest coordinates
image_px: coordinates in image pixels (0, 0 top left)
px: coordinates in microscope image pixels (0, 0 image center)
px_um: coordinates in microscope image meters (0, 0 image center)
transformation:
rotation_eulers: transformation rotation (eulers)
rotation_quaternion: transformation rotation (quaternion)
scale: transformation scale
translation_around_rotation_center: transformation translation
"""

# fib coordinates need to be x, y, z for 3DCT
if fib_coords.shape[-1] == 2:
fib_coords = numpy.column_stack((fib_coords, numpy.zeros(fib_coords.shape[0])))

# coordinates need to be float32 for 3DCT
fib_coords = fib_coords.astype(numpy.float32)
fm_coords = fm_coords.astype(numpy.float32)

# get first channel only, assume all channels are the same shape
if fm_image.ndim == 4:
fm_image = fm_image[0, :, :, :]
if fm_image.ndim == 5:
fm_image = fm_image[0, 0, :, :, :]

# get rotation center
halfmax_dim = int(max(fm_image.shape) * 0.5)
rotation_center = (halfmax_dim, halfmax_dim, halfmax_dim)

# get fib pixel size (meters)
fib_pixel_size = fib_image.metadata[model.MD_PIXEL_SIZE][0]

# fib image shape minus metadata, fib_pixelsize (microns), fm_image_shape
image_props = [fib_image.shape, fib_pixel_size * 1e6, fm_image.shape]

assert fm_coords.dtype == numpy.float32, "FM coordinates must be float32"
assert fib_coords.dtype == numpy.float32, "FIB coordinates must be float32"
assert fm_coords.shape[-1] == 3, "FM coordinates must be 3D (x, y, z)"
assert fib_coords.shape[-1] == 3, "FIB coordinates must be 3D (x, y, z)"
assert fib_coords.shape == fm_coords.shape, "FIB and FM coordinates must have the same shape"
assert fib_image.ndim == 2, "FIB Image must be 2D"
assert fm_image.ndim == 3, "FM Image must be 3D"
assert fib_pixel_size is not None, "FIB Pixel Size must be set"

logging.debug(
f"Running 3DCT correlation with FIB image shape: {fib_image.shape}, FM image shape: {fm_image.shape}"
)

# run correlation
correlation_results = run_correlation(
fib_coords=fib_coords,
fm_coords=fm_coords,
poi_coords=poi_coords,
image_props=image_props,
rotation_center=rotation_center,
path=path,
)

return correlation_results

def get_reprojected_poi_coordinate(correlation_results: dict) -> Tuple[float, float]:
"""Get the the point of interest coordinate from correlation data
and convert from micrometers to meters in the microscope image coordinate system.
The coordinate is centred at the image centre (x+ -> right, y+ -> up).
:param correlation_results: the correlation results
:return: the point of interest coordinate in meters
"""
# get the point of interest coordinate (in microscope coordinates, in metres)
poi_coord = correlation_results["output"]["poi"][0]["px_um"]
poi_coord = (poi_coord[0] * 1e-6, poi_coord[1] * 1e-6)
return poi_coord
Copy link
Contributor

@K4rishma K4rishma Feb 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is confusing that some coordinates are saved in pixels and some in m, while all of them are called as coords? I am not sure why poi_coord is in m and not in pixels

Copy link
Contributor Author

@patrickcleeve2 patrickcleeve2 Feb 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the final output for the milling patterns, which have coordinates in metres. They are all coordinates, not sure if you have a better name?

Copy link
Contributor

@K4rishma K4rishma Feb 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe if most of the coordinates are in pixels, by default the coords can represent pixels and for m it can be represented by pos , while the function name can explicitly tell i.e. in meters ? For e.g. get_poi_pos() or get_poi_coords_in_meters().

With that I would rename the function description and definitions to make the difference clear.


def parse_3dct_yaml_file(path: str) -> Tuple[float, float]:
"""Parse the 3DCT yaml file and extract the point of interest (POI)
in microscope image coordinates (um). Convert the coordinates to metres.
Note: only the first POI is extracted.
:param path: path to the 3DCT yaml file.
:return: The point of interest in microscope image coordinates (metres, centred at the image centre).
"""
with open(path, "r") as f:
data = yaml.safe_load(f)

pt = get_reprojected_poi_coordinate(data["correlation"])

return pt
113 changes: 113 additions & 0 deletions src/odemis/acq/align/test/tdct_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
"""
Created on 12 Feb 2025

Copyright © Delmic

This file is part of Odemis.

Odemis is free software: you can redistribute it and/or modify it under the terms
of the GNU General Public License version 2 as published by the Free Software
Foundation.

Odemis is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
Odemis. If not, see http://www.gnu.org/licenses/.
"""
import os
import unittest

import numpy
from odemis import model
from odemis.acq.align import tdct

RESULTS_PATH = os.path.join(os.getcwd(), "correlation_data.yaml")

class TestTDCT(unittest.TestCase):

def tearDown(self):
if os.path.exists(RESULTS_PATH):
os.remove(RESULTS_PATH)

def test_convert_das_to_numpy_stack(self):
"""Test the conversion of DataArrays to numpy stack"""
nc, nz, ny, nx = 3, 10, 1000, 1000
data_2d = numpy.random.random((ny, nx))
data_3d = numpy.random.random((nz, ny, nx))
data_5d = numpy.random.random((1, 1, nz, ny, nx))

# Test 2D input
da_2d = model.DataArray(data_2d)
result_2d = tdct._convert_das_to_numpy_stack([da_2d])
self.assertEqual(result_2d.shape, (1, 1, ny, nx))
self.assertEqual(result_2d.ndim, 4)

# Test 3D input
da_3d = model.DataArray(data_3d)
result_3d = tdct._convert_das_to_numpy_stack([da_3d])
self.assertEqual(result_3d.shape, (1, nz, ny, nx))
self.assertEqual(result_3d.ndim, 4)

# Test 5D input
da_5d = model.DataArray(data_5d)
result_5d = tdct._convert_das_to_numpy_stack([da_5d])
self.assertEqual(result_5d.shape, (1, nz, ny, nx))
self.assertEqual(result_5d.ndim, 4)

# Test multiple channels
result_multi_3d = tdct._convert_das_to_numpy_stack([da_3d, da_3d, da_3d])
self.assertEqual(result_multi_3d.shape, (nc, nz, ny, nx))
self.assertEqual(result_multi_3d.ndim, 4)

result_multi_5d = tdct._convert_das_to_numpy_stack([da_5d, da_5d, da_5d])
self.assertEqual(result_multi_5d.shape, (nc, nz, ny, nx))
self.assertEqual(result_multi_5d.ndim, 4)

# Test invalid dimensions
data_1d = numpy.random.random(nx)
da_1d = model.DataArray(data_1d)
with self.assertRaises(AssertionError):
tdct._convert_das_to_numpy_stack([da_1d])

def test_run_tdct_correlation(self):
"""Run the TDCT correlation and validate the output"""
fib_coords = numpy.array(
[[100, 100],
[900, 100],
[900, 900],
[100, 900]], dtype=numpy.float32)
fm_coords = numpy.array(
[[100, 100, 3],
[1000, 100, 4],
[900, 1000, 8],
[100, 1000, 9]], dtype=numpy.float32)
poi_coords = numpy.array([[500, 500, 5]], dtype=numpy.float32)
fib_image = model.DataArray(numpy.zeros(shape=(1024, 1536)),
metadata={
model.MD_PIXEL_SIZE: (100e-9, 100e-9)})
fm_image = numpy.zeros(shape=(10, 1024, 1024))
path = os.getcwd()

# run the correlation
ret = tdct.run_tdct_correlation(fib_coords=fib_coords,
fm_coords=fm_coords,
poi_coords=poi_coords,
fib_image=fib_image,
fm_image=fm_image,
path=path)

self.assertTrue(isinstance(ret, dict))
self.assertTrue(RESULTS_PATH)

# extract the poi coordinate from the correlation results
poi = tdct.get_poi_coordinate(ret)

# check the poi coordinate match the expected value
poi_um = ret["output"]["poi"][0]["px_um"]

self.assertEqual(len(poi), 2)
self.assertAlmostEqual(poi[0], poi_um[0] * 1e-6)
self.assertAlmostEqual(poi[1], poi_um[1] * 1e-6)
Loading