Skip to content

Commit 1b9b577

Browse files
Added error handling for non-existent elements (#305)
2 parents 91e5726 + 8ed8463 commit 1b9b577

File tree

3 files changed

+53
-10
lines changed

3 files changed

+53
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning][].
2222
- Performance bug when plotting shapes (#298)
2323
- scale parameter was ignored for single-scale images (#301)
2424
- Changes to support for dask-expr (#283)
25+
- Added error handling for non-existent elements (#305)
2526

2627
## [0.2.3] - 2024-07-03
2728

src/spatialdata_plot/pl/utils.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,12 +1472,12 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
14721472

14731473
if element_type == "images":
14741474
param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].images.keys())
1475-
if element_type == "labels":
1475+
elif element_type == "labels":
14761476
param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].labels.keys())
1477-
if element_type == "shapes":
1478-
param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].shapes.keys())
1479-
if element_type == "points":
1477+
elif element_type == "points":
14801478
param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].points.keys())
1479+
elif element_type == "shapes":
1480+
param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].shapes.keys())
14811481

14821482
if (channel := param_dict.get("channel")) is not None and not isinstance(channel, (list, str, int)):
14831483
raise TypeError("Parameter 'channel' must be a string, an integer, or a list of strings or integers.")
@@ -1493,10 +1493,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
14931493
if (contour_px := param_dict.get("contour_px")) and not isinstance(contour_px, int):
14941494
raise TypeError("Parameter 'contour_px' must be an integer.")
14951495

1496-
if (color := param_dict.get("color")) and element_type in ["shapes", "points", "labels"]:
1496+
if (color := param_dict.get("color")) and element_type in {
1497+
"shapes",
1498+
"points",
1499+
"labels",
1500+
}:
14971501
if not isinstance(color, str):
14981502
raise TypeError("Parameter 'color' must be a string.")
1499-
if element_type in ["shapes", "points"]:
1503+
if element_type in {"shapes", "points"}:
15001504
if colors.is_color_like(color):
15011505
logger.info("Value for parameter 'color' appears to be a color, using it as such.")
15021506
param_dict["col_for_color"] = None
@@ -1667,6 +1671,10 @@ def _validate_label_render_params(
16671671

16681672
element_params: dict[str, dict[str, Any]] = {}
16691673
for el in param_dict["element"]:
1674+
1675+
# ensure that the element exists in the SpatialData object
1676+
_ = param_dict["sdata"][el]
1677+
16701678
element_params[el] = {}
16711679
element_params[el]["na_color"] = param_dict["na_color"]
16721680
element_params[el]["cmap"] = param_dict["cmap"]
@@ -1721,6 +1729,10 @@ def _validate_points_render_params(
17211729

17221730
element_params: dict[str, dict[str, Any]] = {}
17231731
for el in param_dict["element"]:
1732+
1733+
# ensure that the element exists in the SpatialData object
1734+
_ = param_dict["sdata"][el]
1735+
17241736
element_params[el] = {}
17251737
element_params[el]["na_color"] = param_dict["na_color"]
17261738
element_params[el]["cmap"] = param_dict["cmap"]
@@ -1784,6 +1796,10 @@ def _validate_shape_render_params(
17841796

17851797
element_params: dict[str, dict[str, Any]] = {}
17861798
for el in param_dict["element"]:
1799+
1800+
# ensure that the element exists in the SpatialData object
1801+
_ = param_dict["sdata"][el]
1802+
17871803
element_params[el] = {}
17881804
element_params[el]["fill_alpha"] = param_dict["fill_alpha"]
17891805
element_params[el]["na_color"] = param_dict["na_color"]

tests/pl/test_render.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,35 @@ def test_render_images_can_plot_multiple_cyx_images(share_coordinate_system: str
2828

2929
if share_coordinate_system == "all":
3030
assert len(axs) == 1
31-
32-
if share_coordinate_system == "two":
31+
elif share_coordinate_system == "none":
32+
assert len(axs) == 3
33+
elif share_coordinate_system == "two":
3334
assert len(axs) == 2
3435

35-
if share_coordinate_system == "none":
36-
assert len(axs) == 3
36+
37+
def test_keyerror_when_image_element_does_not_exist(request):
38+
sdata = request.getfixturevalue("sdata_blobs")
39+
40+
with pytest.raises(KeyError):
41+
sdata.pl.render_images(element="not_found").pl.show()
42+
43+
44+
def test_keyerror_when_label_element_does_not_exist(request):
45+
sdata = request.getfixturevalue("sdata_blobs")
46+
47+
with pytest.raises(KeyError):
48+
sdata.pl.render_labels(element="not_found").pl.show()
49+
50+
51+
def test_keyerror_when_point_element_does_not_exist(request):
52+
sdata = request.getfixturevalue("sdata_blobs")
53+
54+
with pytest.raises(KeyError):
55+
sdata.pl.render_points(element="not_found").pl.show()
56+
57+
58+
def test_keyerror_when_shape_element_does_not_exist(request):
59+
sdata = request.getfixturevalue("sdata_blobs")
60+
61+
with pytest.raises(KeyError):
62+
sdata.pl.render_shapes(element="not_found").pl.show()

0 commit comments

Comments
 (0)