Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample.region_centered_grid #60

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 78 additions & 26 deletions src/faim_wako_searchfirst/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

import numpy as np
from numpy import ndarray
from skimage.measure import block_reduce, regionprops
from skimage.filters.rank import maximum
from skimage.measure import block_reduce, label, regionprops
from skimage.morphology import rectangle


def dense_grid(
Expand All @@ -30,14 +32,14 @@ def dense_grid(
c = csv.writer(csv_file)
count = 0
it = np.nditer(downscaled, flags=["multi_index"])
for label in it:
if label > 0:
for label_value in it:
if label_value > 0:
c.writerow([count] + _grid_coordinate(it.multi_index, binning_factor))
count += 1


def _grid_coordinate(index, factor):
return [(index[1] + 0.5) * factor, (index[0] + 0.5) * factor]
return [(index[1] + 0.5) * factor, (index[0] + 0.5) * factor] # TODO add 0.5 here?


def grid_overlap(
Expand Down Expand Up @@ -97,27 +99,11 @@ def _filter_points(points, weights, y_threshold, x_threshold):
return keep_indices


def object_centered_grid(
def _sample_grid_on_regions(
labeled_img: ndarray,
path: Path,
mag_first_pass: float,
mag_second_pass: float,
overlap_ratio: float = 0.0,
tile_size_y: float,
tile_size_x: float,
):
"""Sample each labeled object with a centered grid of tiles.

If the object fits into a single field of view, record just the centroid coordinate.
Otherwise, compute how many tiles are required to fit the object, and record only
those grid coordinates that cover the object mask.

For objects where the resulting fields of view would be overlapping,
only keep the largest object and discard all others.
"""
factor = mag_first_pass / mag_second_pass
shift_percent = 1.0 - overlap_ratio
tile_size_y = labeled_img.shape[0] * factor * shift_percent
tile_size_x = labeled_img.shape[1] * factor * shift_percent

props = regionprops(label_image=labeled_img)
labels = []
coordinates = []
Expand Down Expand Up @@ -152,6 +138,35 @@ def object_centered_grid(
coordinates.extend(valid_points)
areas.extend([p.area] * len(valid_points))
labels.extend([p.label] * len(valid_points))
return coordinates, areas, labels


def object_centered_grid(
labeled_img: ndarray,
path: Path,
mag_first_pass: float,
mag_second_pass: float,
overlap_ratio: float = 0.0,
):
"""Sample each labeled object with a centered grid of tiles.

If the object fits into a single field of view, record just the centroid coordinate.
Otherwise, compute how many tiles are required to fit the object, and record only
those grid coordinates that cover the object mask.

For objects where the resulting fields of view would be overlapping,
only keep the largest object and discard all others.
"""
factor = mag_first_pass / mag_second_pass
shift_percent = 1.0 - overlap_ratio
tile_size_y = labeled_img.shape[0] * factor * shift_percent
tile_size_x = labeled_img.shape[1] * factor * shift_percent

coordinates, areas, labels = _sample_grid_on_regions(
labeled_img=labeled_img,
tile_size_y=tile_size_y,
tile_size_x=tile_size_x,
)

# filter overlapping coordinates
keep_points = _filter_points(
Expand All @@ -162,10 +177,47 @@ def object_centered_grid(
)

coordinates = np.array(coordinates)[keep_points]
areas = np.array(areas)[keep_points]
labels = np.array(labels)[keep_points]

with open(path, "w", newline="") as csv_file:
c = csv.writer(csv_file)
for label, point in zip(labels, coordinates):
c.writerow([label, point[1], point[0]])
for label_value, point in zip(labels, coordinates):
c.writerow([label_value, *reversed(point)])


def region_centered_grid(
labeled_img: ndarray,
path: Path,
mag_first_pass: float,
mag_second_pass: float,
overlap_ratio: float = 0.0,
):
"""Sample optimal grid for each region of objects that are close to each other.

The grid is computed centered on each region, with an optional specified overlap.
"""
factor = mag_first_pass / mag_second_pass
shift_percent = 1.0 - overlap_ratio
tile_size_y = labeled_img.shape[0] * factor * shift_percent
tile_size_x = labeled_img.shape[1] * factor * shift_percent
# dilate
mask = labeled_img > 0
footprint = rectangle(
np.ceil(tile_size_y).astype(int), np.ceil(tile_size_x).astype(int)
) # , decomposition="separable"
dilated = maximum(image=mask.astype(np.uint8), footprint=footprint)
# label
regions = label(dilated)
# reconstruct
reconstructed = np.where(mask, regions, 0)

coordinates, _, labels = _sample_grid_on_regions(
labeled_img=reconstructed,
tile_size_y=tile_size_y,
tile_size_x=tile_size_x,
)

with open(path, "w", newline="") as csv_file:
c = csv.writer(csv_file)
for label_value, point in zip(labels, coordinates):
c.writerow([label_value, *reversed(point)])
33 changes: 32 additions & 1 deletion tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pandas as pd
import pytest
from faim_wako_searchfirst.sample import object_centered_grid
from faim_wako_searchfirst.sample import object_centered_grid, region_centered_grid
from skimage.io import imread


Expand All @@ -32,3 +32,34 @@ def test_object_centered_grid(_label_image, tmp_path):
print(centers_table)
assert len(centers_table) == 7
assert centers_table[0].unique().tolist() == [1, 2, 3]


def test_region_centered_grid(_label_image, tmp_path):
"""Test object-centered grid sampling."""
assert _label_image.shape == (200, 200)
csv_path = tmp_path / "points_region_centered.csv"
region_centered_grid(
labeled_img=_label_image,
path=csv_path,
mag_first_pass=4,
mag_second_pass=20,
overlap_ratio=0.2,
)
assert csv_path.exists()
centers_table = pd.read_csv(csv_path, header=None)
print(centers_table)
assert len(centers_table) == 12
assert centers_table[0].unique().tolist() == [1, 2, 3]
assert centers_table[0].tolist() == [1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3]
assert centers_table.iloc[0].values.flatten().tolist() == pytest.approx([1, 18.5, 14.5])
assert centers_table.iloc[1].values.flatten().tolist() == pytest.approx([1, 50.5, 14.5])
assert centers_table.iloc[2].values.flatten().tolist() == pytest.approx([2, 28.0, 53.0])
assert centers_table.iloc[3].values.flatten().tolist() == pytest.approx([2, 60.0, 53.0])
assert centers_table.iloc[4].values.flatten().tolist() == pytest.approx([2, 92.0, 21.0])
assert centers_table.iloc[5].values.flatten().tolist() == pytest.approx([2, 92.0, 53.0])
assert centers_table.iloc[6].values.flatten().tolist() == pytest.approx([3, 136.5, 104.5])
assert centers_table.iloc[7].values.flatten().tolist() == pytest.approx([3, 136.5, 136.5])
assert centers_table.iloc[8].values.flatten().tolist() == pytest.approx([3, 136.5, 168.5])
assert centers_table.iloc[9].values.flatten().tolist() == pytest.approx([3, 168.5, 104.5])
assert centers_table.iloc[10].values.flatten().tolist() == pytest.approx([3, 168.5, 136.5])
assert centers_table.iloc[11].values.flatten().tolist() == pytest.approx([3, 168.5, 168.5])
Loading