196 changes: 132 additions & 64 deletions src/spatialdata_io/readers/xenium.py
@@ -14,13 +14,13 @@
import ome_types
import packaging.version
import pandas as pd
import pyarrow.compute as pc
import pyarrow.parquet as pq
import tifffile
import zarr
from dask.dataframe import read_parquet
from dask_image.imread import imread
from geopandas import GeoDataFrame
from pyarrow import Table
from shapely import GeometryType, Polygon, from_ragged_array
from spatialdata import SpatialData
from spatialdata._core.query.relational_query import get_element_instances
@@ -44,6 +44,7 @@
if TYPE_CHECKING:
from collections.abc import Mapping

import pyarrow as pa
from anndata import AnnData
from spatialdata._types import ArrayLike

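Here `pyarrow` joins the type-checking block because the `pa` alias is only needed for annotations (such as the `pa.Table` return type of `_get_points` below); with postponed annotation evaluation the name is never looked up at runtime. A minimal sketch of the pattern, with an illustrative function name:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pyarrow as pa  # resolved only by the type checker

def make_table() -> pa.Table:  # annotation stays a string at runtime
    import pyarrow
    return pyarrow.table({"x": [1, 2, 3]})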
@@ -69,6 +70,7 @@ def xenium(
morphology_focus: bool = True,
aligned_images: bool = True,
cells_table: bool = True,
n_jobs: int | None = None,
gex_only: bool = True,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
image_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
@@ -121,6 +123,10 @@ def xenium(
`False` and use the `xenium_aligned_image` function directly.
cells_table
Whether to read the cell annotations in the `AnnData` table.
n_jobs
.. deprecated::
``n_jobs`` is not used anymore and will be removed in a future release. The reading time of shapes is now
greatly improved and does not require parallelization.
gex_only
Whether to load only the "Gene Expression" feature type.
imread_kwargs
@@ -153,6 +159,13 @@ def xenium(
... )
>>> sdata.write("path/to/data.zarr")
"""
if n_jobs is not None:
warnings.warn(
"The `n_jobs` parameter is deprecated and will be removed in a future release. "
"The reading time of shapes is now greatly improved and does not require parallelization.",
DeprecationWarning,
stacklevel=2,
)
image_models_kwargs, labels_models_kwargs = _initialize_raster_models_kwargs(
image_models_kwargs, labels_models_kwargs
)
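For callers the migration is simply dropping the argument; the path below is a placeholder:

# before (still works, but emits a DeprecationWarning):
# sdata = xenium("path/to/xenium_output", n_jobs=8)

# after:
sdata = xenium("path/to/xenium_output")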
@@ -188,18 +201,42 @@
else:
table = None

# open cells.zarr.zip once and reuse across all functions that need it
cells_zarr: zarr.Group | None = None
need_cells_zarr = (
nucleus_labels
or cells_labels
or (version is not None and version >= packaging.version.parse("2.0.0") and table is not None)
)
if need_cells_zarr:
cells_zarr_store = zarr.storage.ZipStore(path / XeniumKeys.CELLS_ZARR, read_only=True)
cells_zarr = zarr.open(cells_zarr_store, mode="r")

# pre-compute cell_id strings from the zarr once, to avoid redundant conversion
# in both _get_cells_metadata_table_from_zarr and _get_labels_and_indices_mapping.
cells_zarr_cell_id_str: np.ndarray | None = None
if cells_zarr is not None and version is not None and version >= packaging.version.parse("1.3.0"):
cell_id_raw = cells_zarr["cell_id"][...]
cell_id_prefix, dataset_suffix = cell_id_raw[:, 0], cell_id_raw[:, 1]
cells_zarr_cell_id_str = cell_id_str_from_prefix_suffix_uint32(cell_id_prefix, dataset_suffix)

if version is not None and version >= packaging.version.parse("2.0.0") and table is not None:
cell_summary_table = _get_cells_metadata_table_from_zarr(path, XeniumKeys.CELLS_ZARR, specs)
if not cell_summary_table[XeniumKeys.CELL_ID].equals(table.obs[XeniumKeys.CELL_ID]):
assert cells_zarr is not None
cell_summary_table = _get_cells_metadata_table_from_zarr(cells_zarr, specs, cells_zarr_cell_id_str)
try:
_assert_arrays_equal_sampled(
cell_summary_table[XeniumKeys.CELL_ID].values, table.obs[XeniumKeys.CELL_ID].values
)
except AssertionError:
warnings.warn(
'The "cell_id" column in the cells metadata table does not match the "cell_id" column in the annotation'
" table. This could be due to trying to read a new version that is not supported yet. Please "
"report this issue.",
UserWarning,
stacklevel=2,
)
table.obs[XeniumKeys.Z_LEVEL] = cell_summary_table[XeniumKeys.Z_LEVEL]
table.obs[XeniumKeys.NUCLEUS_COUNT] = cell_summary_table[XeniumKeys.NUCLEUS_COUNT]
table.obs[XeniumKeys.Z_LEVEL] = cell_summary_table[XeniumKeys.Z_LEVEL].values
table.obs[XeniumKeys.NUCLEUS_COUNT] = cell_summary_table[XeniumKeys.NUCLEUS_COUNT].values

polygons = {}
labels = {}
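Outside the reader, the single-open pattern introduced above looks roughly like this; a minimal sketch assuming a Xenium bundle at a placeholder path:

import zarr

store = zarr.storage.ZipStore("path/to/xenium_output/cells.zarr.zip", read_only=True)
cells_zarr = zarr.open(store, mode="r")

# cell_id is stored as an (n, 2) uint32 array: hex-encoded prefix + dataset suffix
cell_id_raw = cells_zarr["cell_id"][...]
prefix, suffix = cell_id_raw[:, 0], cell_id_raw[:, 1]
cell_ids = cell_id_str_from_prefix_suffix_uint32(prefix, suffix)

Opening the ZipStore once and threading `cells_zarr` (plus the precomputed ID strings) through the helpers avoids re-reading and re-decoding the same arrays in `_get_cells_metadata_table_from_zarr` and `_get_labels_and_indices_mapping`.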
@@ -220,6 +257,8 @@ def xenium(
mask_index=0,
labels_name="nucleus_labels",
labels_models_kwargs=labels_models_kwargs,
cells_zarr=cells_zarr,
cell_id_str=cells_zarr_cell_id_str,
)
if cells_labels:
labels["cell_labels"], cell_labels_indices_mapping = _get_labels_and_indices_mapping(
@@ -228,9 +267,15 @@ def xenium(
mask_index=1,
labels_name="cell_labels",
labels_models_kwargs=labels_models_kwargs,
cells_zarr=cells_zarr,
cell_id_str=cells_zarr_cell_id_str,
)
if cell_labels_indices_mapping is not None and table is not None:
if not pd.DataFrame.equals(cell_labels_indices_mapping["cell_id"], table.obs[str(XeniumKeys.CELL_ID)]):
try:
_assert_arrays_equal_sampled(
cell_labels_indices_mapping["cell_id"].values, table.obs[str(XeniumKeys.CELL_ID)].values
)
except AssertionError:
warnings.warn(
"The cell_id column in the cell_labels_table does not match the cell_id column derived from the "
"cell labels data. This could be due to trying to read a new version that is not supported yet. "
@@ -239,7 +284,7 @@ def xenium(
stacklevel=2,
)
else:
table.obs["cell_labels"] = cell_labels_indices_mapping["label_index"]
table.obs["cell_labels"] = cell_labels_indices_mapping["label_index"].values
if not cells_as_circles:
table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] = "cell_labels"

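The switch to `.values` here (and for `Z_LEVEL`/`NUCLEUS_COUNT` above) sidesteps pandas index alignment: assigning a Series from another frame aligns by label and can silently produce NaNs, while an ndarray assigns positionally. A tiny illustration with made-up frames:

import pandas as pd

obs = pd.DataFrame({"a": [1, 2, 3]}, index=["x", "y", "z"])
other = pd.DataFrame({"b": [10, 20, 30]})  # RangeIndex, disjoint from obs.index

obs["b"] = other["b"]         # aligned by label -> all NaN
obs["b"] = other["b"].values  # positional -> [10, 20, 30]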
@@ -248,7 +293,7 @@ def xenium(
path,
XeniumKeys.NUCLEUS_BOUNDARIES_FILE,
specs,
idx=table.obs[str(XeniumKeys.CELL_ID)].copy(),
idx=None,
)

if cells_boundaries:
Expand Down Expand Up @@ -389,6 +434,13 @@ def filter(self, record: logging.LogRecord) -> bool:
return _set_reader_metadata(sdata, "xenium")


def _assert_arrays_equal_sampled(a: ArrayLike, b: ArrayLike, n: int = 100) -> None:
"""Assert two arrays are equal by checking a random sample of entries."""
assert len(a) == len(b), f"Array lengths differ: {len(a)} != {len(b)}"
idx = np.random.default_rng(0).choice(len(a), size=min(n, len(a)), replace=False)
np.testing.assert_array_equal(np.asarray(a[idx]), np.asarray(b[idx]))


def _decode_cell_id_column(cell_id_column: pd.Series) -> pd.Series:
if isinstance(cell_id_column.iloc[0], bytes):
return cell_id_column.str.decode("utf-8")
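`_assert_arrays_equal_sampled` trades certainty for speed: after the length check it compares only 100 seeded-random positions instead of all n entries, which is what keeps the ID consistency checks above cheap even for Arrow-backed string columns. Illustrative behavior with hypothetical data:

import numpy as np

a = np.array([f"cell_{i}" for i in range(1_000_000)])
b = a.copy()
_assert_arrays_equal_sampled(a, b)  # passes after checking 100 random entries

b[123_456] = "oops"
_assert_arrays_equal_sampled(a, b)  # may still pass: a single corrupted entry
                                    # is caught only if it lands in the sample

The fixed seed makes the check deterministic across runs, so a failure is reproducible.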
@@ -403,28 +455,35 @@ def _get_polygons(
specs: dict[str, Any],
idx: pd.Series | None = None,
) -> GeoDataFrame:
# seems to be faster than pd.read_parquet
df = pq.read_table(path / file).to_pandas()
cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
x = df[XeniumKeys.BOUNDARIES_VERTEX_X].to_numpy()
y = df[XeniumKeys.BOUNDARIES_VERTEX_Y].to_numpy()
# Use PyArrow compute to avoid slow .to_numpy() on Arrow-backed strings in pandas >= 3.0
# The original approach was:
# df = pq.read_table(path / file).to_pandas()
# cell_ids = df[XeniumKeys.CELL_ID].to_numpy()
# which got slow with pandas >= 3.0 (Arrow-backed string .to_numpy() is ~100x slower).
# By doing change detection in Arrow, we avoid allocating Python string objects for all rows.
table = pq.read_table(path / file)
cell_id_col = table.column(str(XeniumKeys.CELL_ID))

x = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_X)).to_numpy()
y = table.column(str(XeniumKeys.BOUNDARIES_VERTEX_Y)).to_numpy()
coords = np.column_stack([x, y])

change_mask = np.concatenate([[True], cell_ids[1:] != cell_ids[:-1]])
n = len(cell_id_col)
change_mask = np.empty(n, dtype=bool)
change_mask[0] = True
change_mask[1:] = pc.not_equal(cell_id_col.slice(0, n - 1), cell_id_col.slice(1)).to_numpy(zero_copy_only=False)
group_starts = np.where(change_mask)[0]
group_ends = np.concatenate([group_starts[1:], [len(cell_ids)]])
group_ends = np.concatenate([group_starts[1:], [n]])

# sanity check
n_unique_ids = len(df[XeniumKeys.CELL_ID].drop_duplicates())
n_unique_ids = pc.count_distinct(cell_id_col).as_py()
if len(group_starts) != n_unique_ids:
raise ValueError(
f"In {file}, rows belonging to the same polygon must be contiguous. "
f"Expected {n_unique_ids} group starts, but found {len(group_starts)}. "
f"This indicates non-consecutive polygon rows."
)

unique_ids = cell_ids[group_starts]

# offsets for ragged array:
# offsets[0] (ring_offsets): which ring each vertex position belongs to
# offsets[1] (geom_offsets): which polygon each ring belongs to
Expand All @@ -433,22 +492,16 @@ def _get_polygons(

geoms = from_ragged_array(GeometryType.POLYGON, coords, offsets=(ring_offsets, geom_offsets))

index = _decode_cell_id_column(pd.Series(unique_ids))
geo_df = GeoDataFrame({"geometry": geoms}, index=index.values)

version = _parse_version_of_xenium_analyzer(specs)
if version is not None and version < packaging.version.parse("2.0.0"):
assert idx is not None
assert len(idx) == len(geo_df)
assert np.array_equal(index.values, idx.values)
# idx is not None for cells and None for nuclei (with xenium(cells_table=False), it is None for both)
if idx is not None:
# Cell IDs already available from the annotation table
assert len(idx) == len(group_starts), f"Expected {len(group_starts)} cell IDs, got {len(idx)}"
geo_df = GeoDataFrame({"geometry": geoms}, index=idx.values)
else:
if np.unique(geo_df.index).size != len(geo_df):
warnings.warn(
"Found non-unique polygon indices, this will be addressed in a future version of the reader. For the "
"time being please consider merging polygons with non-unique indices into single multi-polygons.",
UserWarning,
stacklevel=2,
)
# Fall back to extracting unique cell IDs from parquet (slow for large_string columns).
unique_ids = cell_id_col.filter(change_mask).to_pylist()
index = _decode_cell_id_column(pd.Series(unique_ids))
geo_df = GeoDataFrame({"geometry": geoms}, index=index.values)

scale = Scale([1.0 / specs["pixel_size"], 1.0 / specs["pixel_size"]], axes=("x", "y"))
return ShapesModel.parse(geo_df, transformations={"global": scale})
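A toy run of the Arrow change-mask plus the ragged-array construction, using made-up coordinates (rings closed explicitly by repeating the first vertex):

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
from shapely import GeometryType, from_ragged_array

ids = pa.array(["a", "a", "a", "a", "b", "b", "b", "b"])
coords = np.array(
    [[0, 0], [1, 0], [0, 1], [0, 0],   # polygon "a"
     [2, 2], [3, 2], [2, 3], [2, 2]],  # polygon "b"
    dtype=float,
)

n = len(ids)
change_mask = np.empty(n, dtype=bool)
change_mask[0] = True
change_mask[1:] = pc.not_equal(ids.slice(0, n - 1), ids.slice(1)).to_numpy(zero_copy_only=False)
group_starts = np.where(change_mask)[0]              # [0, 4]

ring_offsets = np.concatenate([group_starts, [n]])   # [0, 4, 8]: vertices -> rings
geom_offsets = np.arange(len(group_starts) + 1)      # [0, 1, 2]: rings -> polygons
polys = from_ragged_array(GeometryType.POLYGON, coords, offsets=(ring_offsets, geom_offsets))
# [<POLYGON ((0 0, 1 0, 0 1, 0 0))>, <POLYGON ((2 2, 3 2, 2 3, 2 2))>]

Keeping the neighbor comparison in Arrow means the per-row work stays in C; Python strings are materialized at most once per polygon, and not at all when `idx` comes from the annotation table.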
@@ -459,16 +512,15 @@ def _get_labels_and_indices_mapping(
specs: dict[str, Any],
mask_index: int,
labels_name: str,
cells_zarr: zarr.Group,
cell_id_str: ArrayLike,
labels_models_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> tuple[GeoDataFrame, pd.DataFrame | None]:
if mask_index not in [0, 1]:
raise ValueError(f"mask_index must be 0 or 1, found {mask_index}.")

zip_file = path / XeniumKeys.CELLS_ZARR
store = zarr.storage.ZipStore(zip_file, read_only=True)
z = zarr.open(store, mode="r")
# get the labels
masks = da.from_array(z["masks"][f"{mask_index}"])
masks = da.from_array(cells_zarr["masks"][f"{mask_index}"])
labels = Labels2DModel.parse(masks, dims=("y", "x"), transformations={"global": Identity()}, **labels_models_kwargs)

# build the matching table
@@ -481,11 +533,8 @@
# supported in versions < 1.3.0
return labels, None

cell_id, dataset_suffix = z["cell_id"][...].T
cell_id_str = cell_id_str_from_prefix_suffix_uint32(cell_id, dataset_suffix)

if version < packaging.version.parse("2.0.0"):
label_index = z["seg_mask_value"][...]
label_index = cells_zarr["seg_mask_value"][...]
else:
# For v >= 2.0.0, seg_mask_value is no longer available in the zarr;
# read label_id from the corresponding parquet boundary file instead
@@ -515,42 +564,29 @@ def _get_labels_and_indices_mapping(
"label_index": label_index.astype(np.int64),
}
)
# because AnnData converts the indices to str
indices_mapping.index = indices_mapping.index.astype(str)
return labels, indices_mapping


@inject_docs(xx=XeniumKeys)
def _get_cells_metadata_table_from_zarr(
path: Path,
file: str,
cells_zarr: zarr.Group,
specs: dict[str, Any],
cell_id_str: ArrayLike,
) -> AnnData:
"""Read cells metadata from ``{xx.CELLS_ZARR}``.

Read the cells summary table, which contains the z_level information and, for versions >= 2.0.0, also the nucleus_count.
"""
# for versions >= 2.0.0, this function could also parse the segmentation method used to obtain the masks
zip_file = path / XeniumKeys.CELLS_ZARR
store = zarr.storage.ZipStore(zip_file, read_only=True)

z = zarr.open(store, mode="r")
x = z["cell_summary"][...]
column_names = z["cell_summary"].attrs["column_names"]
x = cells_zarr["cell_summary"][...]
column_names = cells_zarr["cell_summary"].attrs["column_names"]
df = pd.DataFrame(x, columns=column_names)
cell_id_prefix = z["cell_id"][:, 0]
dataset_suffix = z["cell_id"][:, 1]
store.close()

cell_id_str = cell_id_str_from_prefix_suffix_uint32(cell_id_prefix, dataset_suffix)
df[XeniumKeys.CELL_ID] = cell_id_str
# because AnnData converts the indices to str
df.index = df.index.astype(str)
return df


def _get_points(path: Path, specs: dict[str, Any]) -> Table:
def _get_points(path: Path, specs: dict[str, Any]) -> pa.Table:
table = read_parquet(path / XeniumKeys.TRANSCRIPTS_FILE)

# check if we need to decode bytes
Expand Down Expand Up @@ -592,10 +628,12 @@ def _get_tables_and_circles(
) -> AnnData | tuple[AnnData, AnnData]:
adata = _read_10x_h5(path / XeniumKeys.CELL_FEATURE_MATRIX_FILE, gex_only=gex_only)
metadata = pd.read_parquet(path / XeniumKeys.CELL_METADATA_FILE)
np.testing.assert_array_equal(metadata.cell_id.astype(str), adata.obs_names.values)
_assert_arrays_equal_sampled(metadata.cell_id.astype(str), adata.obs_names.values)
circ = metadata[[XeniumKeys.CELL_X, XeniumKeys.CELL_Y]].to_numpy()
adata.obsm["spatial"] = circ
metadata.drop([XeniumKeys.CELL_X, XeniumKeys.CELL_Y], axis=1, inplace=True)
# avoids anndata's ImplicitModificationWarning
metadata.index = adata.obs_names
adata.obs = metadata
adata.obs["region"] = specs["region"]
adata.obs["region"] = adata.obs["region"].astype("category")
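The index alignment before `adata.obs = metadata` matters because AnnData coerces `obs` indices to `str` and warns when it must do so implicitly. A minimal sketch of the failure mode (assuming `anndata` is installed):

import anndata as ad
import numpy as np
import pandas as pd

adata = ad.AnnData(np.zeros((3, 1)))        # obs_names default to '0', '1', '2' (str)
metadata = pd.DataFrame({"z": [1, 2, 3]})   # integer RangeIndex

adata.obs = metadata                        # ImplicitModificationWarning: transforming to str index

metadata.index = adata.obs_names            # align first ...
adata.obs = metadata                        # ... assigns silently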
@@ -850,13 +888,18 @@ def _parse_version_of_xenium_analyzer(
return None


def cell_id_str_from_prefix_suffix_uint32(cell_id_prefix: ArrayLike, dataset_suffix: ArrayLike) -> ArrayLike:
# explained here:
# https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/analysis/xoa-output-zarr#cellID
def _cell_id_str_from_prefix_suffix_uint32_reference(cell_id_prefix: ArrayLike, dataset_suffix: ArrayLike) -> ArrayLike:
"""Reference implementation of cell_id_str_from_prefix_suffix_uint32.

Readable but slow for large arrays due to Python-level string operations.
Kept as ground truth for testing the optimized version.

See https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/analysis/xoa-output-zarr#cellID
"""
# convert to hex, remove the 0x prefix
cell_id_prefix_hex = [hex(x)[2:] for x in cell_id_prefix]

# shift the hex values
# shift the hex values: '0'->'a', ..., '9'->'j', 'a'->'k', ..., 'f'->'p'
hex_shift = {str(i): chr(ord("a") + i) for i in range(10)} | {
chr(ord("a") + i): chr(ord("a") + 10 + i) for i in range(6)
}
@@ -870,6 +913,31 @@ def cell_id_str_from_prefix_suffix_uint32(cell_id_prefix: ArrayLike, dataset_suf
return np.array(cell_id_str)


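A worked value under the documented 10x scheme (made-up prefix and suffix):

# hex(0x12345678)[2:] == "12345678"; shifting each nibble gives
# '1'->'b', '2'->'c', ..., '8'->'i', so the encoded ID is "bcdefghi-1"
_cell_id_str_from_prefix_suffix_uint32_reference(
    np.array([0x12345678], dtype=np.uint32), np.array([1])
)
# -> array(['bcdefghi-1'], dtype='<U10')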
def cell_id_str_from_prefix_suffix_uint32(cell_id_prefix: ArrayLike, dataset_suffix: ArrayLike) -> ArrayLike:
"""Convert cell ID prefix/suffix uint32 pairs to the Xenium string representation.

Each uint32 prefix is converted to 8 hex nibbles, each mapped to a character
(0->'a', 1->'b', ..., 15->'p'), then joined with "-{suffix}".

See https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/analysis/xoa-output-zarr#cellID
"""
cell_id_prefix = np.asarray(cell_id_prefix, dtype=np.uint32)
dataset_suffix = np.asarray(dataset_suffix)

# Extract 8 hex nibbles (4 bits each) from each uint32, most significant first.
# Each nibble maps to a character: 0->'a', 1->'b', ..., 9->'j', 10->'k', ..., 15->'p'.
# Leading zero nibbles become 'a', equivalent to rjust(8, 'a') padding.
shifts = np.array([28, 24, 20, 16, 12, 8, 4, 0], dtype=np.uint32)
nibbles = (cell_id_prefix[:, np.newaxis] >> shifts) & 0xF
char_codes = (nibbles + ord("a")).astype(np.uint8)

# View the (n, 8) uint8 array as n byte-strings of length 8
prefix_strs = char_codes.view("S8").ravel().astype("U8")

suffix_strs = np.char.add("-", dataset_suffix.astype("U"))
return np.char.add(prefix_strs, suffix_strs)

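Because the vectorized version replaces Python-level string handling with nibble arithmetic and a `uint8`-to-`S8` view, a cheap property check against the reference implementation is worth keeping around (illustrative):

rng = np.random.default_rng(42)
prefixes = rng.integers(0, 2**32, size=1_000, dtype=np.uint32)
suffixes = rng.integers(1, 100, size=1_000)
np.testing.assert_array_equal(
    cell_id_str_from_prefix_suffix_uint32(prefixes, suffixes),
    _cell_id_str_from_prefix_suffix_uint32_reference(prefixes, suffixes),
)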

def prefix_suffix_uint32_from_cell_id_str(
cell_id_str: ArrayLike,
) -> tuple[ArrayLike, ArrayLike]: