Skip to content

Commit 59d759a

Browse files
Minor cleanups to rois API (#112)
* improve roi repr * make pixel size a pydantic model * polish roi to slicing dict api * update change log * implement roi union * Roi simplification * add union testing * Revert "implement roi union" This reverts commit 360275f. * improve handling of sequence slicing * improve um to pixel conversion * update changelog
1 parent cf2145c commit 59d759a

File tree

7 files changed

+107
-131
lines changed

7 files changed

+107
-131
lines changed

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
# Changelog
22

3+
## [v0.4.2]
4+
5+
### API Changes
6+
7+
- Make roi.to_slicing_dict(pixel_size) always require pixel_size argument for consistency with other roi methods.
8+
- Make PixelSize object a Pydantic model to allow for serialization.
9+
10+
### Bug Fixes
11+
12+
- Improve robustness when rounding Rois to pixel coordinates.
13+
314
## [v0.4.1]
415

516
### Bug Fixes

src/ngio/common/_roi.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,22 @@
1414
from ngio.utils import NgioValueError
1515

1616

17-
def _to_raster(value: float, length: float, pixel_size: float) -> tuple[float, float]:
17+
def _world_to_raster(value: float, pixel_size: float, eps: float = 1e-6) -> float:
1818
raster_value = value / pixel_size
19-
raster_length = length / pixel_size
19+
20+
# If the value is very close to an integer, round it
21+
# This ensures that we don't have floating point precision issues
22+
# When loading ROIs that were originally defined in pixel coordinates
23+
_rounded = round(raster_value)
24+
if abs(_rounded - raster_value) < eps:
25+
return _rounded
26+
return raster_value
27+
28+
29+
def _to_raster(value: float, length: float, pixel_size: float) -> tuple[float, float]:
30+
"""Convert to raster coordinates."""
31+
raster_value = _world_to_raster(value, pixel_size)
32+
raster_length = _world_to_raster(length, pixel_size)
2033
return raster_value, raster_length
2134

2235

@@ -29,7 +42,7 @@ def _to_slice(start: float | None, length: float | None) -> slice:
2942
return slice(start, end)
3043

3144

32-
def _to_world(value: int | float, pixel_size: float) -> float:
45+
def _raster_to_world(value: int | float, pixel_size: float) -> float:
3346
"""Convert to world coordinates."""
3447
return value * pixel_size
3548

@@ -60,16 +73,28 @@ def intersection(self, other: "GenericRoi") -> "GenericRoi | None":
6073

6174
def _nice_str(self) -> str:
6275
if self.t is not None:
63-
t_str = f"t={self.t}->{self.t_length}"
76+
t_start = self.t
77+
else:
78+
t_start = None
79+
if self.t_length is not None and t_start is not None:
80+
t_end = t_start + self.t_length
6481
else:
65-
t_str = "t=None"
82+
t_end = None
83+
84+
t_str = f"t={t_start}->{t_end}"
85+
6686
if self.z is not None:
67-
z_str = f"z={self.z}->{self.z_length}"
87+
z_start = self.z
88+
else:
89+
z_start = None
90+
if self.z_length is not None and z_start is not None:
91+
z_end = z_start + self.z_length
6892
else:
69-
z_str = "z=None"
93+
z_end = None
94+
z_str = f"z={z_start}->{z_end}"
7095

71-
y_str = f"y={self.y}->{self.y_length}"
72-
x_str = f"x={self.x}->{self.x_length}"
96+
y_str = f"y={self.y}->{self.y + self.y_length}"
97+
x_str = f"x={self.x}->{self.x + self.x_length}"
7398

7499
if self.label is not None:
75100
label_str = f", label={self.label}"
@@ -90,6 +115,9 @@ def __repr__(self) -> str:
90115
def __str__(self) -> str:
91116
return self._nice_str()
92117

118+
def to_slicing_dict(self, pixel_size: PixelSize) -> dict[str, slice]:
119+
raise NotImplementedError
120+
93121

94122
def _1d_intersection(
95123
a: T | None, a_length: T | None, b: T | None, b_length: T | None
@@ -251,6 +279,11 @@ def zoom(self, zoom_factor: float = 1) -> "Roi":
251279
"""
252280
return zoom_roi(self, zoom_factor)
253281

282+
def to_slicing_dict(self, pixel_size: PixelSize) -> dict[str, slice]:
283+
"""Convert to a slicing dictionary."""
284+
roi_pixels = self.to_roi_pixels(pixel_size)
285+
return roi_pixels.to_slicing_dict(pixel_size)
286+
254287

255288
class RoiPixels(GenericRoi):
256289
"""Region of interest (ROI) in pixel coordinates."""
@@ -261,30 +294,30 @@ class RoiPixels(GenericRoi):
261294

262295
def to_roi(self, pixel_size: PixelSize) -> "Roi":
263296
"""Convert to raster coordinates."""
264-
x = _to_world(self.x, pixel_size.x)
265-
x_length = _to_world(self.x_length, pixel_size.x)
266-
y = _to_world(self.y, pixel_size.y)
267-
y_length = _to_world(self.y_length, pixel_size.y)
297+
x = _raster_to_world(self.x, pixel_size.x)
298+
x_length = _raster_to_world(self.x_length, pixel_size.x)
299+
y = _raster_to_world(self.y, pixel_size.y)
300+
y_length = _raster_to_world(self.y_length, pixel_size.y)
268301

269302
if self.z is None:
270303
z = None
271304
else:
272-
z = _to_world(self.z, pixel_size.z)
305+
z = _raster_to_world(self.z, pixel_size.z)
273306

274307
if self.z_length is None:
275308
z_length = None
276309
else:
277-
z_length = _to_world(self.z_length, pixel_size.z)
310+
z_length = _raster_to_world(self.z_length, pixel_size.z)
278311

279312
if self.t is None:
280313
t = None
281314
else:
282-
t = _to_world(self.t, pixel_size.t)
315+
t = _raster_to_world(self.t, pixel_size.t)
283316

284317
if self.t_length is None:
285318
t_length = None
286319
else:
287-
t_length = _to_world(self.t_length, pixel_size.t)
320+
t_length = _raster_to_world(self.t_length, pixel_size.t)
288321

289322
extra_dict = self.model_extra if self.model_extra else {}
290323
return Roi(
@@ -302,7 +335,7 @@ def to_roi(self, pixel_size: PixelSize) -> "Roi":
302335
**extra_dict,
303336
)
304337

305-
def to_slicing_dict(self) -> dict[str, slice]:
338+
def to_slicing_dict(self, pixel_size: PixelSize) -> dict[str, slice]:
306339
"""Convert to a slicing dictionary."""
307340
x_slice = _to_slice(self.x, self.x_length)
308341
y_slice = _to_slice(self.y, self.y_length)

src/ngio/io_pipes/_io_pipes_roi.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,18 @@
1313
from ngio.io_pipes._ops_slices import SlicingInputType
1414
from ngio.io_pipes._ops_transforms import TransformProtocol
1515
from ngio.ome_zarr_meta.ngio_specs._pixel_size import PixelSize
16-
from ngio.utils import NgioValueError
1716

1817

1918
def roi_to_slicing_dict(
2019
*,
2120
roi: Roi | RoiPixels,
22-
pixel_size: PixelSize | None = None,
21+
pixel_size: PixelSize,
2322
slicing_dict: dict[str, SlicingInputType] | None = None,
2423
) -> dict[str, SlicingInputType]:
2524
"""Convert a ROI to a slicing dictionary."""
26-
if isinstance(roi, Roi):
27-
if pixel_size is None:
28-
raise NgioValueError(
29-
"pixel_size must be provided when converting a Roi to slice_kwargs."
30-
)
31-
roi = roi.to_roi_pixels(pixel_size=pixel_size)
32-
33-
roi_slicing_dict: dict[str, SlicingInputType] = roi.to_slicing_dict() # type: ignore
25+
roi_slicing_dict: dict[str, SlicingInputType] = roi.to_slicing_dict(
26+
pixel_size=pixel_size
27+
) # type: ignore
3428
if slicing_dict is None:
3529
return roi_slicing_dict
3630

src/ngio/io_pipes/_ops_slices.py

Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ngio.utils import NgioValueError
1515

1616
SlicingInputType: TypeAlias = slice | Sequence[int] | int | None
17-
SlicingType: TypeAlias = slice | tuple[int, ...] | int
17+
SlicingType: TypeAlias = slice | list[int] | int
1818

1919
##############################################################
2020
#
@@ -60,7 +60,7 @@ def _slicing_tuple_boundary_check(
6060
elif isinstance(sl, int):
6161
_int_boundary_check(sl, shape=sh)
6262
out_slicing_tuple.append(sl)
63-
elif isinstance(sl, tuple):
63+
elif isinstance(sl, list):
6464
[_int_boundary_check(i, shape=sh) for i in sl]
6565
out_slicing_tuple.append(sl)
6666
else:
@@ -115,32 +115,31 @@ def get(self, ax_name: str, normalize: bool = False) -> SlicingType:
115115
return slicing_tuple[ax_index]
116116

117117

118-
def _check_tuple_in_slicing_tuple(
118+
def _check_list_in_slicing_tuple(
119119
slicing_tuple: tuple[SlicingType, ...],
120-
) -> tuple[None, None] | tuple[int, tuple[int, ...]]:
121-
"""Check if there are any tuple in the slicing tuple.
120+
) -> tuple[None, None] | tuple[int, list[int]]:
121+
"""Check if there are any lists in the slicing tuple.
122122
123-
The zarr python api only supports int or slices, not tuples.
124-
Ngio support a single tuple in the slicing tuple to allow non-contiguous
123+
Dask regions when setting data do not support non-contiguous
124+
selections natively.
125+
Ngio support a single list in the slicing tuple to allow non-contiguous
125126
selection (main use case: selecting multiple channels).
126127
"""
127-
# Find if the is any tuple in the slicing tuple
128+
# Find if the is any list in the slicing tuple
128129
# If there is one we need to handle it differently
129-
tuple_in_slice = [
130-
(i, s) for i, s in enumerate(slicing_tuple) if isinstance(s, tuple)
131-
]
132-
if not tuple_in_slice:
133-
# No tuple in the slicing tuple
130+
list_in_slice = [(i, s) for i, s in enumerate(slicing_tuple) if isinstance(s, list)]
131+
if not list_in_slice:
132+
# No list in the slicing tuple
134133
return None, None
135134

136-
if len(tuple_in_slice) > 1:
135+
if len(list_in_slice) > 1:
137136
raise NotImplementedError(
138137
"Slicing with multiple non-contiguous tuples/lists "
139138
"is not supported yet in Ngio. Use directly the "
140139
"zarr.Array api to get the correct array slice."
141140
)
142141
# Complex case, we have exactly one tuple in the slicing tuple
143-
ax, first_tuple = tuple_in_slice[0]
142+
ax, first_tuple = list_in_slice[0]
144143
if len(first_tuple) > 100:
145144
warn(
146145
"Performance warning: "
@@ -164,38 +163,14 @@ def get_slice_as_numpy(zarr_array: zarr.Array, slicing_ops: SlicingOps) -> np.nd
164163
slicing_tuple = slicing_ops.normalized_slicing_tuple
165164
# Find if the is any tuple in the slicing tuple
166165
# If there is one we need to handle it differently
167-
ax, first_tuple = _check_tuple_in_slicing_tuple(slicing_tuple)
168-
if ax is None:
169-
# Simple case, no tuple in the slicing tuple
170-
return zarr_array[slicing_tuple]
171-
172-
assert first_tuple is not None
173-
slices = [
174-
zarr_array[(*slicing_tuple[:ax], idx, *slicing_tuple[ax + 1 :])]
175-
for idx in first_tuple
176-
]
177-
out_array = np.stack(slices, axis=ax)
178-
return out_array
166+
return zarr_array[slicing_tuple]
179167

180168

181169
def get_slice_as_dask(zarr_array: zarr.Array, slicing_ops: SlicingOps) -> da.Array:
182170
"""Get a slice of a zarr array as a dask array."""
183171
da_array = da.from_zarr(zarr_array)
184172
slicing_tuple = slicing_ops.normalized_slicing_tuple
185-
# Find if the is any tuple in the slicing tuple
186-
# If there is one we need to handle it differently
187-
ax, first_tuple = _check_tuple_in_slicing_tuple(slicing_tuple)
188-
if ax is None:
189-
# Base case, no tuple in the slicing tuple
190-
return da_array[slicing_tuple]
191-
192-
assert first_tuple is not None
193-
slices = [
194-
da_array[(*slicing_tuple[:ax], idx, *slicing_tuple[ax + 1 :])]
195-
for idx in first_tuple
196-
]
197-
out_array = da.stack(slices, axis=ax)
198-
return out_array
173+
return da_array[slicing_tuple]
199174

200175

201176
def set_slice_as_numpy(
@@ -204,17 +179,7 @@ def set_slice_as_numpy(
204179
slicing_ops: SlicingOps,
205180
) -> None:
206181
slice_tuple = slicing_ops.normalized_slicing_tuple
207-
ax, first_tuple = _check_tuple_in_slicing_tuple(slice_tuple)
208-
if ax is None:
209-
# Base case, no tuple in the slicing tuple
210-
zarr_array[slice_tuple] = patch
211-
return
212-
213-
# Complex case, we have exactly one tuple in the slicing tuple
214-
assert first_tuple is not None
215-
for i, idx in enumerate(first_tuple):
216-
_sub_slice = (*slice_tuple[:ax], idx, *slice_tuple[ax + 1 :])
217-
zarr_array[_sub_slice] = np.take(patch, indices=i, axis=ax)
182+
zarr_array[slice_tuple] = patch
218183

219184

220185
def handle_int_set_as_dask(
@@ -237,7 +202,7 @@ def set_slice_as_dask(
237202
zarr_array: zarr.Array, patch: da.Array, slicing_ops: SlicingOps
238203
) -> None:
239204
slice_tuple = slicing_ops.normalized_slicing_tuple
240-
ax, first_tuple = _check_tuple_in_slicing_tuple(slice_tuple)
205+
ax, first_tuple = _check_list_in_slicing_tuple(slice_tuple)
241206
patch, slice_tuple = handle_int_set_as_dask(patch, slice_tuple)
242207
if ax is None:
243208
# Base case, no tuple in the slicing tuple
@@ -261,13 +226,13 @@ def set_slice_as_dask(
261226
##############################################################
262227

263228

264-
def _try_to_slice(value: Sequence[int]) -> slice | tuple[int, ...]:
229+
def _try_to_slice(value: Sequence[int]) -> slice | list[int]:
265230
"""Try to convert a list of integers into a slice if they are contiguous.
266231
267232
- If the input is empty, return an empty tuple.
268233
- If the input is sorted, and contains contiguous integers,
269234
return a slice from the minimum to the maximum integer.
270-
- Otherwise, return the input as a tuple.
235+
- Otherwise, return the input as a list of integers.
271236
272237
This is useful for optimizing array slicing operations
273238
by allowing the use of slices when possible, which can be more efficient.
@@ -293,7 +258,7 @@ def _try_to_slice(value: Sequence[int]) -> slice | tuple[int, ...]:
293258
if sorted(value) == list(range(min_input, max_input + 1)):
294259
return slice(min_input, max_input + 1)
295260

296-
return tuple(value)
261+
return list(value)
297262

298263

299264
def _remove_channel_slicing(
@@ -393,7 +358,7 @@ def _normalize_slicing_tuple(
393358
The output types are:
394359
- slice
395360
- int
396-
- tuple of int (for non-contiguous selection)
361+
- list of int (for non-contiguous selection)
397362
"""
398363
axis_name = axis.name
399364
if axis_name not in slicing_dict:
@@ -408,7 +373,7 @@ def _normalize_slicing_tuple(
408373
elif isinstance(value, Sequence):
409374
# If a contiguous sequence of integers is provided,
410375
# convert it to a slice for simplicity.
411-
# Alternatively, it will be converted to a tuple of ints
376+
# Alternatively, it will be converted to a list of ints
412377
return _try_to_slice(value)
413378

414379
raise NgioValueError(

src/ngio/io_pipes/_ops_slices_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ def _pairs_stream(iterable: Iterable[T]) -> Iterator[tuple[T, T]]:
2222
seen.append(a)
2323

2424

25-
SlicingType: TypeAlias = slice | tuple[int, ...] | int
25+
SlicingType: TypeAlias = slice | list[int] | int
2626

2727

2828
def check_elem_intersection(s1: SlicingType, s2: SlicingType) -> bool:
2929
"""Compare if two SlicingType elements intersect.
3030
3131
If they are a slice, check if they overlap.
3232
If they are integers, check if they are equal.
33-
If they are tuples, check if they have any common elements.
33+
If they are lists, check if they have any common elements.
3434
"""
3535
if not isinstance(s1, type(s2)):
3636
raise NgioValueError(
@@ -56,7 +56,7 @@ def check_elem_intersection(s1: SlicingType, s2: SlicingType) -> bool:
5656
elif isinstance(s1, int) and isinstance(s2, int):
5757
# Handle integer indices
5858
return s1 == s2
59-
elif isinstance(s1, tuple) and isinstance(s2, tuple):
59+
elif isinstance(s1, list) and isinstance(s2, list):
6060
if set(s1) & set(s2):
6161
return True
6262
return False
@@ -130,7 +130,7 @@ def _chunk_indices_for_axis(sel: SlicingType, size: int, csize: int) -> list[int
130130
raise IndexError(f"index {sel} out of bounds for axis of size {size}")
131131
return [sel // csize]
132132

133-
if isinstance(sel, tuple):
133+
if isinstance(sel, list):
134134
if not sel:
135135
return []
136136
chunks_hit = {}

0 commit comments

Comments
 (0)