|
5 | 5 | import pandas as pd
|
6 | 6 | import pytest
|
7 | 7 | import scanpy as sc
|
| 8 | +from anndata import AnnData |
8 | 9 | from spatial_image import to_spatial_image
|
9 |
| -from spatialdata import SpatialData, deepcopy |
| 10 | +from spatialdata import SpatialData, deepcopy, get_element_instances |
10 | 11 | from spatialdata.models import TableModel
|
11 | 12 |
|
12 | 13 | import spatialdata_plot # noqa: F401
|
@@ -208,3 +209,27 @@ def test_plot_subset_categorical_label_maintains_order_when_palette_overwrite(se
|
208 | 209 | sdata_blobs.pl.render_labels(
|
209 | 210 | "blobs_labels", color="which_max", groups=["channel_0_sum"], palette="red"
|
210 | 211 | ).pl.show(ax=axs[1])
|
| 212 | + |
| 213 | + def test_plot_label_categorical_color(self, sdata_blobs: SpatialData): |
| 214 | + self._make_tablemodel_with_categorical_labels(sdata_blobs, labels_name="blobs_labels") |
| 215 | + sdata_blobs.pl.render_labels("blobs_labels", color="category").pl.show() |
| 216 | + |
| 217 | + def _make_tablemodel_with_categorical_labels(self, sdata_blobs, labels_name: str): |
| 218 | + instances = get_element_instances(sdata_blobs[labels_name]) |
| 219 | + n_obs = len(instances) |
| 220 | + adata = AnnData( |
| 221 | + RNG.normal(size=(n_obs, 10)), |
| 222 | + obs=pd.DataFrame(RNG.normal(size=(n_obs, 3)), columns=["a", "b", "c"]), |
| 223 | + ) |
| 224 | + adata.obs["instance_id"] = instances.values |
| 225 | + adata.obs["category"] = RNG.choice(["a", "b", "c"], size=adata.n_obs) |
| 226 | + adata.obs["category"][:3] = ["a", "b", "c"] |
| 227 | + adata.obs["region"] = labels_name |
| 228 | + table = TableModel.parse( |
| 229 | + adata=adata, |
| 230 | + region_key="region", |
| 231 | + instance_key="instance_id", |
| 232 | + region=labels_name, |
| 233 | + ) |
| 234 | + sdata_blobs["other_table"] = table |
| 235 | + sdata_blobs["other_table"].obs["category"] = sdata_blobs["other_table"].obs["category"].astype("category") |
0 commit comments