
Commit 8239455

giovp and LucaMarconato authored
vectorize bounding box query (#699)
* vectorize adjust_bounding_box_to_real_axes
* update
* replace append with insert
* add comment
* vectorize
* update to handle multiple boxes
* vectorize with numba
* fix corner len
* fix validation
* refactor
* refactor
* add test for query with multiple bounding boxes
* fix typing
* vectorize bounding box query on polygons
* add test to cover no polygon overlap (None)
* vectorize bounding box query on points and tests
* fix type
* wip fixes code review
* added extra test; finished applying code review changes

---------

Co-authored-by: Luca Marconato <m.lucalmer@gmail.com>
1 parent 8879aff commit 8239455
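For orientation, here is a minimal usage sketch of the bounding-box query that this commit vectorizes. The `blobs` example dataset and the stacked-array call form for multiple boxes are assumptions (the stacking mirrors the `(n_boxes, n_dims)` shape handled in the diff below); this is not code taken from the commit itself.

```python
# Hedged sketch, not from this commit: `blobs` and the multi-box call form are assumptions.
import numpy as np
from spatialdata import bounding_box_query
from spatialdata.datasets import blobs

sdata = blobs()

# Single 2D bounding box (pre-existing behaviour).
single = bounding_box_query(
    sdata["blobs_image"],
    axes=("y", "x"),
    min_coordinate=np.array([10, 10]),
    max_coordinate=np.array([100, 100]),
    target_coordinate_system="global",
)

# Multiple boxes stacked as shape (n_boxes, n_dims); this PR adds tests for
# querying with several bounding boxes at once (assumed call form).
multi = bounding_box_query(
    sdata["blobs_image"],
    axes=("y", "x"),
    min_coordinate=np.array([[10, 10], [50, 50]]),
    max_coordinate=np.array([[100, 100], [150, 150]]),
    target_coordinate_system="global",
)
```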

File tree

4 files changed: +515 -224 lines changed

src/spatialdata/_core/query/_utils.py

Lines changed: 142 additions & 24 deletions
@@ -2,13 +2,22 @@
 
 from typing import Any
 
+import numba as nb
+import numpy as np
 from anndata import AnnData
+from datatree import DataTree
 from xarray import DataArray
 
 from spatialdata._core._elements import Tables
 from spatialdata._core.spatialdata import SpatialData
 from spatialdata._types import ArrayLike
 from spatialdata._utils import Number, _parse_list_into_array
+from spatialdata.transformations._utils import compute_coordinates
+from spatialdata.transformations.transformations import (
+    BaseTransformation,
+    Sequence,
+    Translation,
+)
 
 
 def get_bounding_box_corners(
@@ -36,37 +45,146 @@ def get_bounding_box_corners(
     min_coordinate = _parse_list_into_array(min_coordinate)
     max_coordinate = _parse_list_into_array(max_coordinate)
 
-    if len(min_coordinate) not in (2, 3):
+    if min_coordinate.ndim == 1:
+        min_coordinate = min_coordinate[np.newaxis, :]
+        max_coordinate = max_coordinate[np.newaxis, :]
+
+    if min_coordinate.shape[1] not in (2, 3):
         raise ValueError("bounding box must be 2D or 3D")
 
-    if len(min_coordinate) == 2:
+    num_boxes = min_coordinate.shape[0]
+    num_dims = min_coordinate.shape[1]
+
+    if num_dims == 2:
         # 2D bounding box
         assert len(axes) == 2
-        return DataArray(
+        corners = np.array(
             [
-                [min_coordinate[0], min_coordinate[1]],
-                [min_coordinate[0], max_coordinate[1]],
-                [max_coordinate[0], max_coordinate[1]],
-                [max_coordinate[0], min_coordinate[1]],
-            ],
-            coords={"corner": range(4), "axis": list(axes)},
+                [min_coordinate[:, 0], min_coordinate[:, 1]],
+                [min_coordinate[:, 0], max_coordinate[:, 1]],
+                [max_coordinate[:, 0], max_coordinate[:, 1]],
+                [max_coordinate[:, 0], min_coordinate[:, 1]],
+            ]
         )
-
-    # 3D bounding cube
-    assert len(axes) == 3
-    return DataArray(
-        [
-            [min_coordinate[0], min_coordinate[1], min_coordinate[2]],
-            [min_coordinate[0], min_coordinate[1], max_coordinate[2]],
-            [min_coordinate[0], max_coordinate[1], max_coordinate[2]],
-            [min_coordinate[0], max_coordinate[1], min_coordinate[2]],
-            [max_coordinate[0], min_coordinate[1], min_coordinate[2]],
-            [max_coordinate[0], min_coordinate[1], max_coordinate[2]],
-            [max_coordinate[0], max_coordinate[1], max_coordinate[2]],
-            [max_coordinate[0], max_coordinate[1], min_coordinate[2]],
-        ],
-        coords={"corner": range(8), "axis": list(axes)},
+        corners = np.transpose(corners, (2, 0, 1))
+    else:
+        # 3D bounding cube
+        assert len(axes) == 3
+        corners = np.array(
+            [
+                [min_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]],
+                [min_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]],
+                [min_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]],
+                [min_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]],
+                [max_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]],
+                [max_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]],
+                [max_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]],
+                [max_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]],
+            ]
+        )
+        corners = np.transpose(corners, (2, 0, 1))
+    output = DataArray(
+        corners,
+        coords={
+            "box": range(num_boxes),
+            "corner": range(corners.shape[1]),
+            "axis": list(axes),
+        },
     )
+    if num_boxes > 1:
+        return output
+    return output.squeeze().drop_vars("box")
+
+
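To make the vectorized `get_bounding_box_corners` above concrete, here is a small illustrative call with one box and with a stack of boxes; the values are made up, and the dimension names `box`/`corner`/`axis` come from the code in the diff.

```python
# Illustrative only; numbers are made up.
import numpy as np
from spatialdata._core.query._utils import get_bounding_box_corners

# One 2D box: 1D min/max arrays -> the "box" dimension is squeezed away (4 corners).
corners_single = get_bounding_box_corners(
    axes=("y", "x"),
    min_coordinate=np.array([0, 0]),
    max_coordinate=np.array([10, 20]),
)
print(corners_single.dims)   # ('corner', 'axis')

# Two 2D boxes stacked as shape (n_boxes, n_dims) -> dims ('box', 'corner', 'axis').
corners_multi = get_bounding_box_corners(
    axes=("y", "x"),
    min_coordinate=np.array([[0, 0], [5, 5]]),
    max_coordinate=np.array([[10, 20], [15, 25]]),
)
print(corners_multi.sizes)   # box: 2, corner: 4, axis: 2
```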
+@nb.njit(parallel=False, nopython=True)
+def _create_slices_and_translation(
+    min_values: nb.types.Array,
+    max_values: nb.types.Array,
+) -> tuple[nb.types.Array, nb.types.Array]:
+    n_boxes, n_dims = min_values.shape
+    slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64)  # (n_boxes, n_dims, [min, max])
+    translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64)  # (n_boxes, n_dims)
+
+    for i in range(n_boxes):
+        for j in range(n_dims):
+            slices[i, j, 0] = min_values[i, j]
+            slices[i, j, 1] = max_values[i, j]
+            translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0))
+
+    return slices, translation_vectors
+
+
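A quick check of what the numba helper `_create_slices_and_translation` returns (input values are made up): each slice row holds the `[min, max]` pair per dimension, and the translation vector is the minimum clamped at zero and rounded up, which is later used to shift the cropped result back into its original frame.

```python
# Illustrative call of the helper added above; inputs are made up.
import numpy as np
from spatialdata._core.query._utils import _create_slices_and_translation

min_values = np.array([[-3.2, 10.7], [0.0, 2.5]])   # (n_boxes=2, n_dims=2)
max_values = np.array([[40.0, 55.1], [12.0, 9.9]])

slices, translation_vectors = _create_slices_and_translation(min_values, max_values)
print(slices.shape)           # (2, 2, 2): (n_boxes, n_dims, [min, max])
print(translation_vectors)    # [[ 0. 11.] [ 0.  3.]] -> ceil(max(min, 0)) per entry
```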
+def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None:
+    d = {}
+    for k, data_tree in query_result.items():
+        v = data_tree.values()
+        assert len(v) == 1
+        xdata = v.__iter__().__next__()
+        if 0 in xdata.shape:
+            if k == "scale0":
+                return None
+        else:
+            d[k] = xdata
+
+    # Remove scales after finding a missing scale
+    scales_to_keep = []
+    for i, scale_name in enumerate(d.keys()):
+        if scale_name == f"scale{i}":
+            scales_to_keep.append(scale_name)
+        else:
+            break
+
+    # Case in which scale0 is not present but other scales are
+    if len(scales_to_keep) == 0:
+        return None
+
+    d = {k: d[k] for k in scales_to_keep}
+    result = DataTree.from_dict(d)
+
+    # Rechunk the data to avoid irregular chunks
+    for scale in result:
+        result[scale]["image"] = result[scale]["image"].chunk("auto")
+
+    return result
+
+
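The scale-filtering loop in `_process_data_tree_query_result` keeps only an unbroken run of scales starting at `scale0`; here is a standalone sketch of that rule, using plain dicts in place of the DataTree.

```python
# Standalone illustration of the "keep consecutive scales" rule from the function above.
def keep_consecutive_scales(d: dict) -> dict:
    scales_to_keep = []
    for i, scale_name in enumerate(d.keys()):
        if scale_name == f"scale{i}":
            scales_to_keep.append(scale_name)
        else:
            break
    return {k: d[k] for k in scales_to_keep}

print(keep_consecutive_scales({"scale0": "a", "scale1": "b", "scale3": "d"}))
# {'scale0': 'a', 'scale1': 'b'} -> 'scale3' is dropped because 'scale2' is missing
```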
+def _process_query_result(
+    result: DataArray | DataTree, translation_vector: ArrayLike, axes: tuple[str, ...]
+) -> DataArray | DataTree | None:
+    from spatialdata.transformations import get_transformation, set_transformation
+
+    if isinstance(result, DataArray):
+        if 0 in result.shape:
+            return None
+        # rechunk the data to avoid irregular chunks
+        result = result.chunk("auto")
+    elif isinstance(result, DataTree):
+        result = _process_data_tree_query_result(result)
+        if result is None:
+            return None
+
+    result = compute_coordinates(result)
+
+    if not np.allclose(np.array(translation_vector), 0):
+        translation_transform = Translation(translation=translation_vector, axes=axes)
+
+        transformations = get_transformation(result, get_all=True)
+        assert isinstance(transformations, dict)
+
+        new_transformations = {}
+        for coordinate_system, initial_transform in transformations.items():
+            new_transformation: BaseTransformation = Sequence(
+                [translation_transform, initial_transform],
+            )
+            new_transformations[coordinate_system] = new_transformation
+        set_transformation(result, new_transformations, set_all=True)
+
+    # let's make a copy of the transformations so that we don't modify the original object
+    t = get_transformation(result, get_all=True)
+    assert isinstance(t, dict)
+    set_transformation(result, t.copy(), set_all=True)
+
+    return result
 
 
 def _get_filtered_or_unfiltered_tables(
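Finally, the translation built in `_process_query_result` is prepended to each existing transformation with `Sequence`. Below is a minimal sketch using the public `spatialdata.transformations` API; the values are made up and `Identity` stands in for whatever transformation the element already has.

```python
# Minimal sketch of the transformation composition performed above; values are made up.
import numpy as np
from spatialdata.transformations import Identity, Sequence, Translation

# Offset of the cropped block relative to the original element (one value per axis).
translation_transform = Translation(translation=np.array([11.0, 0.0]), axes=("y", "x"))

# Stand-in for the element's existing transformation in some coordinate system.
initial_transform = Identity()

# The query result is first translated back to the element's frame, then the
# element's original transformation is applied on top.
composed = Sequence([translation_transform, initial_transform])
print(composed.to_affine_matrix(input_axes=("y", "x"), output_axes=("y", "x")))
```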

0 commit comments