252 changes: 252 additions & 0 deletions scripts/align_ish_data.py
@@ -0,0 +1,252 @@
import logging
import pathlib
import warnings
from argparse import ArgumentParser

import nrrd
import numpy as np
import scipy.interpolate
from atldld.sync import DatasetDownloader
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu
from skimage.transform import resize
from tqdm import tqdm

from atlalign.base import DisplacementField
from atlalign.non_ml import antspy_registration
from atlalign.volume import CoronalInterpolator, GappedVolume

warnings.filterwarnings("ignore")
Contributor:
IMO we should identify where exactly this warning is raised and do one of the following

  • Filter the warning inside of a context manager
  • Fix the cause of the warning
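A minimal sketch of the first option, assuming the offending call has been located. The warning category, message pattern, and `noisy_call` are placeholders that would need to be filled in once the warning is actually identified:

import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=UserWarning, message=".*<pattern>.*")
    result = noisy_call()  # placeholder for whichever call raises the warning

Outside the `with` block the previous warning filters are restored, so nothing else gets silenced.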

logger = logging.getLogger()

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(name)s | %(message)s")




class SagittalInterpolator:
Contributor:
I would imagine this one is just a copy of the logic from CoronalInterpolator.

Possible actions

  • No action
  • Write a more general linear interpolator and make sure CoronalInterpolator and SagittalInterpolator are special cases (see the sketch after this class)
  • Outsource the interpolation to atlinter

"""Interpolator that works pixel by pixel in the coronal dimension."""

    def __init__(self, kind="linear", fill_value=0, bounds_error=False):
Contributor:
We might as well type annotate the entire script. Should be easy (see the sketch after the constructor below).

"""Construct."""
self.kind = kind
self.fill_value = fill_value
self.bounds_error = bounds_error
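What the type annotations suggested above could look like for this constructor (a sketch only, not part of the diff):

    def __init__(
        self,
        kind: str = "linear",
        fill_value: float = 0.0,
        bounds_error: bool = False,
    ) -> None:
        ...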

    def interpolate(self, gv):
        """Interpolate.

        Note that some section images might have pixels equal to np.nan. In this
        case these pixels are skipped in the interpolation.

        Parameters
        ----------
        gv : GappedVolume
            Instance of the ``GappedVolume`` to be interpolated.

        Returns
        -------
        final_volume : np.ndarray
            Array of shape (528, 320, 456) that holds the entire interpolated
            volume without gaps.

        """
        grid = np.arange(456)
        final_volume = np.empty((*gv.shape, len(grid)))

        # Interpolate each (row, column) pixel independently along the last axis.
        for r in range(gv.shape[0]):
            for c in range(gv.shape[1]):
                # Keep only the sections where this pixel is not NaN.
                x_pixel, y_pixel = zip(
                    *[
                        (s, img[r, c])
                        for s, img in zip(gv.sn, gv.imgs)
                        if not np.isnan(img[r, c])
                    ]
                )

                f = scipy.interpolate.interp1d(
                    x_pixel,
                    y_pixel,
                    kind=self.kind,
                    bounds_error=self.bounds_error,
                    fill_value=self.fill_value,
                )
                try:
                    final_volume[r, c, :] = f(grid)
                except Exception as e:
                    logger.warning(e)

        return final_volume
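To make the second option from the comment above concrete, here is a sketch of one general linear interpolator of which both classes would be special cases. The class and parameter names are invented, the try/except error handling of the original is omitted for brevity, and outsourcing to atlinter remains the third alternative:

class LinearAxisInterpolator:
    """Sketch: interpolate a GappedVolume pixel by pixel along one axis.

    CoronalInterpolator would correspond to axis=0 with 528 sections and
    SagittalInterpolator to axis=2 with 456 sections.
    """

    def __init__(self, n_sections, axis, kind="linear", fill_value=0, bounds_error=False):
        self.n_sections = n_sections
        self.axis = axis
        self.kind = kind
        self.fill_value = fill_value
        self.bounds_error = bounds_error

    def interpolate(self, gv):
        grid = np.arange(self.n_sections)
        # Build the volume with the interpolated axis last, then move it into place.
        out = np.empty((*gv.shape, len(grid)))
        for r in range(gv.shape[0]):
            for c in range(gv.shape[1]):
                known = [
                    (s, img[r, c])
                    for s, img in zip(gv.sn, gv.imgs)
                    if not np.isnan(img[r, c])
                ]
                x_pixel, y_pixel = zip(*known)
                f = scipy.interpolate.interp1d(
                    x_pixel,
                    y_pixel,
                    kind=self.kind,
                    bounds_error=self.bounds_error,
                    fill_value=self.fill_value,
                )
                out[r, c, :] = f(grid)
        return np.moveaxis(out, -1, self.axis)

For coronal data (images of shape (320, 456)) this yields (528, 320, 456) after the moveaxis; for sagittal data (images of shape (528, 320)) the moveaxis is a no-op.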


def download_and_align_marker(
    dataset_id, nvol, model_gl, header,
Contributor:
We should definitely run some linters/code formatters on it.

    all_sn=None, output_filename=None,
    include_expr=True,
    is_sagittal=False,
    resolution=25.0,
):
"""
Contributor:
This docstring does not seem to be formatted based on any widely adopted conventions (numpy, google), so we should definitely modify it (I guess we can go for numpy like in other projects); a sketch follows the docstring below.

Also, I wouldn't mind if it contained more details.

    Download and align coronal images of mouse brain expressing a genetic marker
    according to a provided Nissl volume.
    The experiment images will be downloaded from the Allen Institute website
    according to the provided dataset id.

    Parameters:
        dataset_id: Id of the Allen experiment
        nvol: 3D numpy ndarray Nissl volume
        model_gl: Merged global + local warping machine learning model
        header: Header for the nrrd file
        all_sn: Optional list of section numbers; inferred from the metadata when omitted
        output_filename: Name of the file where the dataset will be stored.
        resolution: Voxel size of the Nissl volume in um
    """
    slice_shape = nvol.shape[1:]
    downloader = DatasetDownloader(dataset_id, include_expression=include_expr, downsample_img=2)
    downloader.fetch_metadata()
    allen_gen = downloader.run()
    all_registered = []
Contributor:
To avoid confusion I would use the term synchronization to make it clear that we apply the transformation AB gives us.

    all_downsampled = []
    all_expressions = []

    # ``use_manual`` was referenced below but never defined; infer it from whether
    # section numbers were passed in, and start from an empty list otherwise.
    use_manual = all_sn is not None
    if not use_manual:
        all_sn = []

    for (image_id, p, img, img_exp, df) in tqdm(allen_gen):
        img_preprocessed = rgb2gray(255 - img)
Contributor:
AFAIK the problem with this preprocessing is that the background will not be perfectly black.

Another approach would be to define a threshold value k and set all pixels with intensities lower than k to 0.

Alternatively, we can use the expression image to perfectly separate out the background.
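A sketch of the fixed-threshold idea, assuming intensities in [0, 1] after rgb2gray; the cutoff value is a placeholder that would need tuning:

img_preprocessed = rgb2gray(255 - img)
k = 0.05  # placeholder background cutoff
img_preprocessed[img_preprocessed < k] = 0.0  # force near-black background to exactly 0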

        if include_expr:
            expr_preprocessed = rgb2gray(img_exp)
            img_binary = (expr_preprocessed > threshold_otsu(expr_preprocessed)) * 1
            expr_preprocessed = img_binary.astype("uint8")
            all_expressions.append(df.warp(expr_preprocessed))
        all_registered.append(df.warp(img_preprocessed))
        all_downsampled.append(resize(img_preprocessed, slice_shape))
Contributor:
IMO we should not be calling resize here since we might mess up the synchronization from Allen Brain. One can control the microns per pixel in the constructor of DatasetDownloader via the downsample_ref.
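A sketch of that alternative; the parameter name comes from the comment above, while the value 25 is my assumption for matching the 25 um Nissl grid and would need checking:

downloader = DatasetDownloader(
    dataset_id,
    include_expression=include_expr,
    downsample_ref=25,  # microns per pixel of the returned images (assumed value)
    downsample_img=2,
)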

        if not use_manual:
            all_sn.append(int(p // resolution))

    # Clamp section numbers that fall outside the volume onto its last slices.
    if is_sagittal:
        x_shape = nvol.shape[2]
    else:
        x_shape = nvol.shape[0]
    for i, sn in enumerate(all_sn):
        if sn >= x_shape:
            all_sn[i] = x_shape - len(all_sn) + i
    if not is_sagittal:
        # Prepare input
        inputs = np.concatenate(
            [nvol[all_sn][..., np.newaxis], np.array(all_registered)[..., np.newaxis]],
            axis=-1,
        )

        # Forward pass
        _, deltas_xy = model_gl.predict(inputs)
        # Warp the moving images
        all_dl = [
            DisplacementField(deltas_xy[i, ..., 0], deltas_xy[i, ..., 1]).warp(img_mov)
            for i, img_mov in enumerate(all_registered)
        ]
    else:
        all_dl = np.copy(all_registered)
        tot_sn = np.copy(all_sn).tolist()
        # Mirror each sagittal section across the midline (width 456, midline at
        # 456 // 2 = 228): a section at sn = 100 also covers 100 + 228 = 328, one
        # at sn = 300 also covers 456 - 300 = 156. This doubles the coverage of
        # the gapped volume before interpolation.
        for sn, img_mov in zip(tot_sn, all_registered):
            if sn < nvol.shape[2] // 2:
                if (sn + nvol.shape[2] // 2) not in all_sn:
                    all_sn.append(sn + nvol.shape[2] // 2)
                    all_dl = np.vstack((all_dl, np.copy(img_mov)[None, :, :]))
            else:
                if (nvol.shape[2] - sn) not in all_sn:
                    all_sn.append(nvol.shape[2] - sn)
                    all_dl = np.vstack((all_dl, np.copy(img_mov)[None, :, :]))
Contributor (on diff lines +151 to +159):
What does this do exactly?? From what I understand this is the logic for sagittal slices.


    all_ib = []
    for i, (img_mov, sn) in tqdm(enumerate(zip(all_dl, all_sn))):
Contributor:
Ok, so it seems like we first do deep learning registration followed by intensity-based. I guess we don't question it and just assume that is the best setup.

        if is_sagittal:
            df, _ = antspy_registration(nvol[:, :, sn], img_mov)
        else:
            df, _ = antspy_registration(nvol[sn], img_mov)
        if include_expr:
            all_ib.append(df.warp(all_expressions[i]))
        else:
            all_ib.append(df.warp(img_mov))

    gv = GappedVolume(all_sn, all_ib)

    if is_sagittal:
        ci = SagittalInterpolator(kind="linear")
    else:
        ci = CoronalInterpolator(kind="linear")
    final_volume = ci.interpolate(gv)

    return final_volume


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "nissl_path",
        type=pathlib.Path,
        help="Path to the Nissl volume.",
    )
    parser.add_argument(
        "local_model_path",
        type=pathlib.Path,
        help="Path to the local deep learning model.",
    )
    parser.add_argument(
        "global_model_path",
        type=pathlib.Path,
        help="Path to the global deep learning model.",
    )
    parser.add_argument(
        "genes",
        type=str,
        help="Comma separated list of gene ids to download and align.",
    )
    parser.add_argument(
        "output_path",
        type=pathlib.Path,
        help="Path to the folder where the results will be stored.",
    )
    parser.add_argument(
        "-e",
        "--include-expression",
        action="store_true",
        help="If set, also download and align expression images.",
    )
    args = parser.parse_args()
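For reference, a hypothetical invocation of the script (every file name and dataset ID below is made up):

python scripts/align_ish_data.py ara_nissl_25.nrrd local_model.h5 global_model.h5 986,979 results/ --include-expression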

    # imports
    from atlalign.ml_utils import load_model, merge_global_local

    logger.info("Aligning marker images to the Nissl volume.")

    nvol, header = nrrd.read(str(args.nissl_path))
    nvol = nvol / nvol.max()

    genelist = args.genes.split(",")

    local_model = load_model(args.local_model_path)
    global_model = load_model(args.global_model_path)
    model_gl = merge_global_local(global_model, local_model)

    args.output_path.mkdir(exist_ok=True, parents=True)

    for dataset_id in genelist:
        logger.info(f"Downloading and aligning {dataset_id=}")
        volume = download_and_align_marker(
            dataset_id,
            nvol,
            model_gl,
            header,
            include_expr=args.include_expression,
        )

        gene_folder = args.output_path / dataset_id
        gene_folder.mkdir(exist_ok=True, parents=True)

        volume_path = gene_folder / "volume.nrrd"

        nrrd.write(str(volume_path), volume, header=header)
4 changes: 2 additions & 2 deletions tox.ini
@@ -28,7 +28,7 @@ deps =
    black==22.3.0
commands =
    flake8 setup.py {[tox]source} tests
-    isort --honor-noqa --profile=black --check setup.py {[tox]source} tests
+    isort --honor-noqa --profile=black --check setup.py {[tox]source} scripts tests
    pydocstyle {[tox]source}
    black --check setup.py {[tox]source} tests

@@ -38,7 +38,7 @@ deps =
    isort
    black
commands =
-    isort --honor-noqa --profile=black setup.py {[tox]source} tests
+    isort --honor-noqa --profile=black setup.py {[tox]source} scripts tests
    black setup.py {[tox]source} tests
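If the linter comment above is acted on, the remaining commands would presumably be extended the same way (a suggestion, not part of this diff):

    flake8 setup.py {[tox]source} scripts tests
    black setup.py {[tox]source} scripts tests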

[testenv:docs]