|
2 | 2 |
|
3 | 3 | from typing import Any
|
4 | 4 |
|
| 5 | +import numba as nb |
| 6 | +import numpy as np |
5 | 7 | from anndata import AnnData
|
| 8 | +from datatree import DataTree |
6 | 9 | from xarray import DataArray
|
7 | 10 |
|
8 | 11 | from spatialdata._core._elements import Tables
|
9 | 12 | from spatialdata._core.spatialdata import SpatialData
|
10 | 13 | from spatialdata._types import ArrayLike
|
11 | 14 | from spatialdata._utils import Number, _parse_list_into_array
|
| 15 | +from spatialdata.transformations._utils import compute_coordinates |
| 16 | +from spatialdata.transformations.transformations import ( |
| 17 | + BaseTransformation, |
| 18 | + Sequence, |
| 19 | + Translation, |
| 20 | +) |
12 | 21 |
|
13 | 22 |
|
14 | 23 | def get_bounding_box_corners(
|
@@ -36,37 +45,146 @@ def get_bounding_box_corners(
|
36 | 45 | min_coordinate = _parse_list_into_array(min_coordinate)
|
37 | 46 | max_coordinate = _parse_list_into_array(max_coordinate)
|
38 | 47 |
|
39 |
| - if len(min_coordinate) not in (2, 3): |
| 48 | + if min_coordinate.ndim == 1: |
| 49 | + min_coordinate = min_coordinate[np.newaxis, :] |
| 50 | + max_coordinate = max_coordinate[np.newaxis, :] |
| 51 | + |
| 52 | + if min_coordinate.shape[1] not in (2, 3): |
40 | 53 | raise ValueError("bounding box must be 2D or 3D")
|
41 | 54 |
|
42 |
| - if len(min_coordinate) == 2: |
| 55 | + num_boxes = min_coordinate.shape[0] |
| 56 | + num_dims = min_coordinate.shape[1] |
| 57 | + |
| 58 | + if num_dims == 2: |
43 | 59 | # 2D bounding box
|
44 | 60 | assert len(axes) == 2
|
45 |
| - return DataArray( |
| 61 | + corners = np.array( |
46 | 62 | [
|
47 |
| - [min_coordinate[0], min_coordinate[1]], |
48 |
| - [min_coordinate[0], max_coordinate[1]], |
49 |
| - [max_coordinate[0], max_coordinate[1]], |
50 |
| - [max_coordinate[0], min_coordinate[1]], |
51 |
| - ], |
52 |
| - coords={"corner": range(4), "axis": list(axes)}, |
| 63 | + [min_coordinate[:, 0], min_coordinate[:, 1]], |
| 64 | + [min_coordinate[:, 0], max_coordinate[:, 1]], |
| 65 | + [max_coordinate[:, 0], max_coordinate[:, 1]], |
| 66 | + [max_coordinate[:, 0], min_coordinate[:, 1]], |
| 67 | + ] |
53 | 68 | )
|
54 |
| - |
55 |
| - # 3D bounding cube |
56 |
| - assert len(axes) == 3 |
57 |
| - return DataArray( |
58 |
| - [ |
59 |
| - [min_coordinate[0], min_coordinate[1], min_coordinate[2]], |
60 |
| - [min_coordinate[0], min_coordinate[1], max_coordinate[2]], |
61 |
| - [min_coordinate[0], max_coordinate[1], max_coordinate[2]], |
62 |
| - [min_coordinate[0], max_coordinate[1], min_coordinate[2]], |
63 |
| - [max_coordinate[0], min_coordinate[1], min_coordinate[2]], |
64 |
| - [max_coordinate[0], min_coordinate[1], max_coordinate[2]], |
65 |
| - [max_coordinate[0], max_coordinate[1], max_coordinate[2]], |
66 |
| - [max_coordinate[0], max_coordinate[1], min_coordinate[2]], |
67 |
| - ], |
68 |
| - coords={"corner": range(8), "axis": list(axes)}, |
| 69 | + corners = np.transpose(corners, (2, 0, 1)) |
| 70 | + else: |
| 71 | + # 3D bounding cube |
| 72 | + assert len(axes) == 3 |
| 73 | + corners = np.array( |
| 74 | + [ |
| 75 | + [min_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]], |
| 76 | + [min_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]], |
| 77 | + [min_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]], |
| 78 | + [min_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]], |
| 79 | + [max_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]], |
| 80 | + [max_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]], |
| 81 | + [max_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]], |
| 82 | + [max_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]], |
| 83 | + ] |
| 84 | + ) |
| 85 | + corners = np.transpose(corners, (2, 0, 1)) |
| 86 | + output = DataArray( |
| 87 | + corners, |
| 88 | + coords={ |
| 89 | + "box": range(num_boxes), |
| 90 | + "corner": range(corners.shape[1]), |
| 91 | + "axis": list(axes), |
| 92 | + }, |
69 | 93 | )
|
| 94 | + if num_boxes > 1: |
| 95 | + return output |
| 96 | + return output.squeeze().drop_vars("box") |
| 97 | + |
| 98 | + |
| 99 | +@nb.njit(parallel=False, nopython=True) |
| 100 | +def _create_slices_and_translation( |
| 101 | + min_values: nb.types.Array, |
| 102 | + max_values: nb.types.Array, |
| 103 | +) -> tuple[nb.types.Array, nb.types.Array]: |
| 104 | + n_boxes, n_dims = min_values.shape |
| 105 | + slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) |
| 106 | + translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) |
| 107 | + |
| 108 | + for i in range(n_boxes): |
| 109 | + for j in range(n_dims): |
| 110 | + slices[i, j, 0] = min_values[i, j] |
| 111 | + slices[i, j, 1] = max_values[i, j] |
| 112 | + translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0)) |
| 113 | + |
| 114 | + return slices, translation_vectors |
| 115 | + |
| 116 | + |
| 117 | +def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: |
| 118 | + d = {} |
| 119 | + for k, data_tree in query_result.items(): |
| 120 | + v = data_tree.values() |
| 121 | + assert len(v) == 1 |
| 122 | + xdata = v.__iter__().__next__() |
| 123 | + if 0 in xdata.shape: |
| 124 | + if k == "scale0": |
| 125 | + return None |
| 126 | + else: |
| 127 | + d[k] = xdata |
| 128 | + |
| 129 | + # Remove scales after finding a missing scale |
| 130 | + scales_to_keep = [] |
| 131 | + for i, scale_name in enumerate(d.keys()): |
| 132 | + if scale_name == f"scale{i}": |
| 133 | + scales_to_keep.append(scale_name) |
| 134 | + else: |
| 135 | + break |
| 136 | + |
| 137 | + # Case in which scale0 is not present but other scales are |
| 138 | + if len(scales_to_keep) == 0: |
| 139 | + return None |
| 140 | + |
| 141 | + d = {k: d[k] for k in scales_to_keep} |
| 142 | + result = DataTree.from_dict(d) |
| 143 | + |
| 144 | + # Rechunk the data to avoid irregular chunks |
| 145 | + for scale in result: |
| 146 | + result[scale]["image"] = result[scale]["image"].chunk("auto") |
| 147 | + |
| 148 | + return result |
| 149 | + |
| 150 | + |
| 151 | +def _process_query_result( |
| 152 | + result: DataArray | DataTree, translation_vector: ArrayLike, axes: tuple[str, ...] |
| 153 | +) -> DataArray | DataTree | None: |
| 154 | + from spatialdata.transformations import get_transformation, set_transformation |
| 155 | + |
| 156 | + if isinstance(result, DataArray): |
| 157 | + if 0 in result.shape: |
| 158 | + return None |
| 159 | + # rechunk the data to avoid irregular chunks |
| 160 | + result = result.chunk("auto") |
| 161 | + elif isinstance(result, DataTree): |
| 162 | + result = _process_data_tree_query_result(result) |
| 163 | + if result is None: |
| 164 | + return None |
| 165 | + |
| 166 | + result = compute_coordinates(result) |
| 167 | + |
| 168 | + if not np.allclose(np.array(translation_vector), 0): |
| 169 | + translation_transform = Translation(translation=translation_vector, axes=axes) |
| 170 | + |
| 171 | + transformations = get_transformation(result, get_all=True) |
| 172 | + assert isinstance(transformations, dict) |
| 173 | + |
| 174 | + new_transformations = {} |
| 175 | + for coordinate_system, initial_transform in transformations.items(): |
| 176 | + new_transformation: BaseTransformation = Sequence( |
| 177 | + [translation_transform, initial_transform], |
| 178 | + ) |
| 179 | + new_transformations[coordinate_system] = new_transformation |
| 180 | + set_transformation(result, new_transformations, set_all=True) |
| 181 | + |
| 182 | + # let's make a copy of the transformations so that we don't modify the original object |
| 183 | + t = get_transformation(result, get_all=True) |
| 184 | + assert isinstance(t, dict) |
| 185 | + set_transformation(result, t.copy(), set_all=True) |
| 186 | + |
| 187 | + return result |
70 | 188 |
|
71 | 189 |
|
72 | 190 | def _get_filtered_or_unfiltered_tables(
|
|
0 commit comments