Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"deepcopy",
"sanitize_table",
"sanitize_name",
"subset_sdata_by_table_mask",
]

from spatialdata import dataloader, datasets, models, transformations
Expand All @@ -76,6 +77,7 @@
match_element_to_table,
match_sdata_to_table,
match_table_to_element,
subset_sdata_by_table_mask,
)
from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query
from spatialdata._core.spatialdata import SpatialData
Expand Down
123 changes: 123 additions & 0 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr
from anndata import AnnData
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from numpy.typing import NDArray
from xarray import DataArray, DataTree

from spatialdata._core.spatialdata import SpatialData
Expand Down Expand Up @@ -1017,3 +1019,124 @@ def get_values(
return df

raise ValueError(f"Unknown origin {origin}")


def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
# Use apply_ufunc for efficient processing
# Create a copy to avoid modifying read-only array
result = block.copy()
result[np.isin(result, ids_to_remove)] = 0
return result


def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
processed = xr.apply_ufunc(
partial(_mask_block, ids_to_remove=ids_to_remove),
image,
input_core_dims=[["y", "x"]],
output_core_dims=[["y", "x"]],
vectorize=True,
dask="parallelized",
output_dtypes=[image.dtype],
dataset_fill_value=0,
dask_gufunc_kwargs={"allow_rechunk": True},
)

# Create a new DataArray to ensure persistence
return xr.DataArray(
data=processed.data,
coords=image.coords,
dims=image.dims,
attrs=image.attrs.copy(), # Preserve all attributes
)


def _get_scale_factors(labels_element: DataTree) -> list[tuple[float, float]]:
scales = list(labels_element.keys())

# Calculate relative scale factors between consecutive scales
scale_factors = []
for i in range(len(scales) - 1):
y_size_current = labels_element[scales[i]].image.shape[0]
x_size_current = labels_element[scales[i]].image.shape[1]
y_size_next = labels_element[scales[i + 1]].image.shape[0]
x_size_next = labels_element[scales[i + 1]].image.shape[1]
y_factor = y_size_current / y_size_next
x_factor = x_size_current / x_size_next

scale_factors.append((y_factor, x_factor))

return scale_factors


@singledispatch
def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key: str) -> Any:
raise NotImplementedError(f"Filtering by instance ids is not implemented for {element}")


@_filter_by_instance_ids.register(DataArray)
def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataArray:
del instance_key
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean why del instance_key? It is to explicitly clarify that I won't be using it for this dispatch

return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove))


@_filter_by_instance_ids.register(DataTree)
def _(element: DataTree, ids_to_remove: list[int], instance_key: str) -> DataTree:
# we extract the info to just reconstruct
# the DataTree after filtering the max scale
del instance_key
max_scale = list(element.keys())[0]
scale_factors_temp = _get_scale_factors(element)
scale_factors = [int(sf[0]) for sf in scale_factors_temp]

return Labels2DModel.parse(
data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove),
scale_factors=scale_factors,
)


def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArray[np.bool_]) -> SpatialData:
"""Subset the annotated elements of a SpatialData object by a table and a mask.

The mask is applied to the table and the annotated elements are subsetted
by the instance ids in the table.
This function returns a new SpatialData object with the subsetted elements.
Elements that are not annotated by the table are not included in the returned SpatialData object.
The element models that are
supported are :class:`spatialdata.models.Labels2DModel`,
:class:`spatialdata.models.PointsModel`, and :class:`spatialdata.models.ShapesModel`.

Parameters
----------
sdata
The SpatialData object to subset.
table_name
The name of the table to apply the mask to.
mask
Boolean mask to apply to the table which is the same length as the number of rows in the table.

Returns
-------
The subsetted SpatialData object.
"""
table = sdata.tables.get(table_name)
if table is None:
raise ValueError(f"Table {table_name} not found in SpatialData object.")

subset_table = table[mask]
sdata.tables[table_name] = subset_table
_, _, instance_key = get_table_keys(subset_table)
annotated_regions = SpatialData.get_annotated_regions(table)
removed_instance_ids = list(np.unique(table.obs[instance_key][~mask]))

filtered_elements = {}
for reg in annotated_regions:
elem = sdata.get(reg)
model = get_model(elem)
if model is Labels2DModel:
filtered_elements[reg] = _filter_by_instance_ids(elem, removed_instance_ids, instance_key)
elif model in [PointsModel, ShapesModel]:
element_dict, _ = match_element_to_table(sdata, element_name=reg, table_name=table_name)
filtered_elements[reg] = element_dict[reg]

return SpatialData.init_from_elements(filtered_elements | {table_name: subset_table})
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import numpy as np
import pytest

from spatialdata import concatenate, subset_sdata_by_table_mask
from spatialdata._core.query.relational_query import _filter_by_instance_ids
from spatialdata.datasets import blobs_annotating_element


def test_filter_labels2dmodel_by_instance_ids() -> None:
sdata = blobs_annotating_element("blobs_labels")
labels_element = sdata["blobs_labels"]
all_instance_ids = sdata.tables["table"].obs["instance_id"].unique()
filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id")

# because 0 is the background, we expect the filtered ids to be the instance ids that are not 0
filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - {
0,
}
preserved_ids = np.unique(labels_element.data.compute())
assert filtered_ids == (set(all_instance_ids) - {2, 3})
# check if there is modification of the original labels
assert set(preserved_ids) == set(all_instance_ids) | {0}

sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
sdata.tables["table"].obs.region = "blobs_multiscale_labels"
labels_element = sdata["blobs_multiscale_labels"]
filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id")

for scale in labels_element:
filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - {
0,
}
preserved_ids = np.unique(labels_element[scale].image.compute())
assert filtered_ids == (set(all_instance_ids) - {2, 3})
# check if there is modification of the original labels
assert set(preserved_ids) == set(all_instance_ids) | {0}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use np.testing instead

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but these are all simple set comparisons. would have to order then compare the matrices which seems convoluted. But if you think the same still I can do it



def test_subset_sdata_by_table_mask() -> None:
sdata = concatenate(
{
"labels": blobs_annotating_element("blobs_labels"),
"shapes": blobs_annotating_element("blobs_circles"),
"points": blobs_annotating_element("blobs_points"),
"multiscale_labels": blobs_annotating_element("blobs_multiscale_labels"),
},
concatenate_tables=True,
)
third_elems = sdata.tables["table"].obs["instance_id"] == 3
subset_sdata = subset_sdata_by_table_mask(sdata, "table", third_elems)

assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"}
assert set(subset_sdata.points.keys()) == {"blobs_points-points"}
assert set(subset_sdata.shapes.keys()) == {"blobs_circles-shapes"}

labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_labels-labels"].data.compute())) - {0}
assert labels_remaining_ids == {3}

for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]:
ms_labels_remaining_ids = set(
np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute())
) - {0}
assert ms_labels_remaining_ids == {3}

points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"].index)) - {0}
assert points_remaining_ids == {3}

shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0}
assert shapes_remaining_ids == {3}


def test_subset_sdata_by_table_mask_with_no_annotated_elements() -> None:
with pytest.raises(ValueError, match="Table table_not_found not found in SpatialData object."):
sdata = blobs_annotating_element("blobs_labels")
_ = subset_sdata_by_table_mask(sdata, "table_not_found", sdata.tables["table"].obs["instance_id"] == 3)


def test_filter_by_instance_ids_fails_for_unsupported_element_models() -> None:
with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"):
_filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id")
Loading