Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Performance of Resolving Transcript Based Segmentations #197

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions sopa/segmentation/_transcripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ def resolve(
if min_area > 0:
log.info(f"Cells whose area is less than {min_area} microns^2 will be removed")

patches_cells, adatas = _read_all_segmented_patches(patches_dirs, min_area)
geo_df, cells_indices, new_ids = _resolve_patches(patches_cells, adatas)
patch_ids, patches_cells, adatas = _read_all_segmented_patches(patches_dirs, min_area)
patch_id_to_centroid = sdata[SopaKeys.TRANSCRIPTS_PATCHES].centroid.apply(lambda x: (x.x, x.y)).to_dict()
geo_df, cells_indices, new_ids = _resolve_patches(
patch_ids=patch_ids, adatas=adatas, patches_cells=patches_cells, patch_centroids=patch_id_to_centroid
)

points_key = sdata[SopaKeys.TRANSCRIPTS_PATCHES][SopaKeys.POINTS_KEY].iloc[0]
points = sdata[points_key]
Expand Down Expand Up @@ -88,8 +91,9 @@ def resolve(

def _read_one_segmented_patch(
directory: str, min_area: float = 0, min_vertices: int = 4
) -> tuple[list[Polygon], AnnData]:
) -> tuple[int, list[Polygon], AnnData]:
directory: Path = Path(directory)
patch_id = int(directory.name)
id_as_string, polygon_file = _find_polygon_file(directory)

loom_file = directory / "segmentation_counts.loom"
Expand Down Expand Up @@ -123,7 +127,7 @@ def _keep_cell(ID: str | int):

geo_df = geo_df[geo_df.area > min_area]

return geo_df.geometry.values, adata[geo_df.index].copy()
return patch_id, geo_df.geometry.values, adata[geo_df.index].copy()


def _find_polygon_file(directory: Path) -> tuple[bool, Path]:
Expand All @@ -138,36 +142,42 @@ def _find_polygon_file(directory: Path) -> tuple[bool, Path]:
def _read_all_segmented_patches(
patches_dirs: list[str],
min_area: float = 0,
) -> tuple[list[list[Polygon]], list[AnnData]]:
) -> tuple[list[int], list[list[Polygon]], list[AnnData]]:
outs = [
_read_one_segmented_patch(path, min_area)
for path in tqdm(patches_dirs, desc="Reading transcript-segmentation outputs")
]

patches_cells, adatas = zip(*outs)
patch_ids, patches_cells, adatas = zip(*outs)

return patches_cells, adatas
return patch_ids, patches_cells, adatas


def _resolve_patches(
patches_cells: list[list[Polygon]], adatas: list[AnnData]
patch_ids: list[int],
adatas: list[AnnData],
patches_cells: list[list[Polygon]],
patch_centroids: dict[int, tuple[float, float]],
) -> tuple[gpd.GeoDataFrame, np.ndarray, np.ndarray]:
"""Resolve the segmentation conflits on the patches overlaps.
"""Resolve the segmentation conflicts on the patches overlaps.

Args:
patches_cells: List of polygons segmented on each patch
patch_ids: List of ids of the patches
adatas: List of AnnData objects corresponding to each patch
patches_cells: List of polygons segmented on each patch
patch_centroids: Centroids of the patches

Returns:
The new GeoDataFrame, the new cells indices (-1 for merged cells), and the ids of the merged cells.
"""
patch_ids = [adata.obs_names for adata in adatas]

patch_indices = np.arange(len(patches_cells)).repeat([len(cells) for cells in patches_cells])
per_patch_segment_ids = [adata.obs_names for adata in adatas]
per_cell_patch_indices = np.array(patch_ids).repeat([len(cells) for cells in patches_cells])
cells = [cell for cells in patches_cells for cell in cells]
segmentation_ids = np.array([cell_id for ids in patch_ids for cell_id in ids])
segmentation_ids = np.array([cell_id for ids in per_patch_segment_ids for cell_id in ids])

cells_resolved, cells_indices = solve_conflicts(cells, patch_indices=patch_indices, return_indices=True)
cells_resolved, cells_indices = solve_conflicts(
cells, patch_indices=per_cell_patch_indices, patch_centroids=patch_centroids, return_indices=True
)

existing_ids = segmentation_ids[cells_indices[cells_indices >= 0]]
new_ids = np.char.add("merged_cell_", np.arange((cells_indices == -1).sum()).astype(str))
Expand Down
22 changes: 21 additions & 1 deletion sopa/segmentation/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def solve_conflicts(
cells: list[Polygon] | gpd.GeoDataFrame,
threshold: float = 0.5,
patch_indices: np.ndarray | None = None,
patch_centroids: dict[int, tuple[float, float]] | None = None,
return_indices: bool = False,
) -> gpd.GeoDataFrame | tuple[gpd.GeoDataFrame, np.ndarray]:
"""Resolve segmentation conflicts (i.e. overlap) after running segmentation on patches
Expand All @@ -29,19 +30,38 @@ def solve_conflicts(
cells: List of cell polygons
threshold: When two cells are overlapping, we look at the area of intersection over the area of the smallest cell. If this value is higher than the `threshold`, the cells are merged
patch_indices: Patch from which each cell belongs.
patch_centroids: Centroids of each patch.
return_indices: If `True`, returns also the cells indices. Merged cells have an index of -1.

Returns:
Array of resolved cells polygons. If `return_indices`, it also returns an array of cell indices.
"""
cells = list(cells.geometry) if isinstance(cells, gpd.GeoDataFrame) else list(cells)
n_cells = len(cells)
resolved_indices = np.arange(n_cells)

assert n_cells > 0, "No cells was segmented, cannot continue"

if patch_centroids is not None and patch_indices is not None:
cells_gdf = gpd.GeoDataFrame({"cell_patch_index": patch_indices}, geometry=cells).reset_index(names="cell_id")
centroid_gdf = gpd.GeoDataFrame(
{"patch_index": list(patch_centroids.keys())},
geometry=[shapely.geometry.Point(coord) for coord in patch_centroids.values()],
)

joined_to_centroids = gpd.sjoin_nearest(
cells_gdf,
centroid_gdf,
)

joined_to_centroids = joined_to_centroids.drop_duplicates(subset=["cell_id"], keep="first")

cells = joined_to_centroids[
joined_to_centroids["patch_index"] == joined_to_centroids["cell_patch_index"]
].geometry.to_list()

tree = shapely.STRtree(cells)
conflicts = tree.query(cells, predicate="intersects")
resolved_indices = np.arange(len(cells))

if patch_indices is not None:
conflicts = conflicts[:, patch_indices[conflicts[0]] != patch_indices[conflicts[1]]].T
Expand Down
16 changes: 16 additions & 0 deletions tests/test_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,19 @@ def test_solve_conflict(cells: gpd.GeoDataFrame):

res = solve_conflicts(list(cells.geometry) + other_cells)
assert all(isinstance(cell, Polygon) for cell in res.geometry)


def test_solve_conflict_patches(cells: gpd.GeoDataFrame):
tile_overlap = 50
other_cells = [translate(cell, cells.total_bounds[2] - tile_overlap, 0) for cell in cells.geometry]

res = solve_conflicts(
cells=list(cells.geometry) + other_cells,
patch_indices=np.array([0] * len(cells) + [1] * len(other_cells)),
patch_centroids={
0: (cells.total_bounds[2] / 2, cells.total_bounds[3] / 2),
1: (cells.total_bounds[2] / 2 + (cells.total_bounds[2] - tile_overlap), cells.total_bounds[3] / 2),
},
)
assert all(isinstance(cell, Polygon) for cell in res.geometry)
assert len(res) < len(cells) + len(other_cells)