Skip to content

Commit 7a41b3f

Browse files
committed
Adding in additional placeholders and functions required for functional, array and dictionary transforms to operate while waiting for PR #5436
Signed-off-by: Ben Murray <ben.murray@gmail.com>
1 parent b512281 commit 7a41b3f

File tree

2 files changed

+237
-1
lines changed

2 files changed

+237
-1
lines changed

monai/transforms/meta_matrix.py

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
import torch
1616

1717
# placeholder to be replaced by MetaMatrix in Apply And Resample PR #5436
18-
from monai.transforms.utils import _create_rotate, _create_shear, _create_scale, _create_translate
18+
from monai.config import NdarrayOrTensor
19+
20+
from monai.transforms.utils import _create_rotate, _create_shear, _create_scale, _create_translate, _create_rotate_90, \
21+
_create_flip
1922

2023
from monai.utils import TransformBackends
2124

@@ -98,6 +101,179 @@ def translate(self, offsets: Union[Sequence[float], float], **extra_args):
98101
return MetaMatrix(matrix, extra_args)
99102

100103

104+
class Matrix:
105+
def __init__(self, matrix: NdarrayOrTensor):
106+
self.data = ensure_tensor(matrix)
107+
108+
# def __matmul__(self, other):
109+
# if isinstance(other, Matrix):
110+
# other_matrix = other.data
111+
# else:
112+
# other_matrix = other
113+
# return self.data @ other_matrix
114+
#
115+
# def __rmatmul__(self, other):
116+
# return other.__matmul__(self.data)
117+
118+
119+
class Grid:
120+
def __init__(self, grid):
121+
self.data = ensure_tensor(grid)
122+
123+
# def __matmul__(self, other):
124+
# raise NotImplementedError()
125+
126+
127+
class MetaMatrix:
128+
def __init__(self, matrix: Union[NdarrayOrTensor, Matrix, Grid], metadata: Optional[dict] = None):
129+
130+
if not isinstance(matrix, (Matrix, Grid)):
131+
if matrix.shape == 2:
132+
if matrix.shape[0] != matrix.shape[1] or matrix.shape[0] not in (3, 4):
133+
raise ValueError(
134+
"If 'matrix' is passed a numpy ndarray/torch Tensor, it must"
135+
f" be 3x3 or 4x4 ('matrix' has has shape {matrix.shape})"
136+
)
137+
matrix_ = Matrix(matrix)
138+
else:
139+
matrix_ = matrix
140+
self.matrix = matrix_
141+
142+
self.metadata = metadata or {}
143+
144+
def __matmul__(self, other):
145+
if isinstance(other, MetaMatrix):
146+
other_ = other.matrix
147+
else:
148+
other_ = other
149+
return MetaMatrix(self.matrix @ other_)
150+
151+
def __rmatmul__(self, other):
152+
if isinstance(other, MetaMatrix):
153+
other_ = other.matrix
154+
else:
155+
other_ = other
156+
return MetaMatrix(other_ @ self.matrix)
157+
158+
159+
def matmul(
160+
left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor]
161+
):
162+
matrix_types = (MetaMatrix, Grid, Matrix, torch.Tensor, np.ndarray)
163+
164+
if not isinstance(left, matrix_types):
165+
raise TypeError(f"'left' must be one of {matrix_types} but is {type(left)}")
166+
if not isinstance(right, matrix_types):
167+
raise TypeError(f"'second' must be one of {matrix_types} but is {type(right)}")
168+
169+
left_ = left.matrix if isinstance(left, MetaMatrix) else left
170+
right_ = right.matrix if isinstance(right, MetaMatrix) else right
171+
172+
# TODO: it might be better to not return a metamatrix, unless we pass in the resulting
173+
# metadata also
174+
put_in_metamatrix = isinstance(left, MetaMatrix) or isinstance(right, MetaMatrix)
175+
176+
put_in_grid = isinstance(left, Grid) or isinstance(right, Grid)
177+
178+
put_in_matrix = isinstance(left, Matrix) or isinstance(right, Matrix)
179+
put_in_matrix = False if put_in_grid is True else put_in_matrix
180+
181+
promote_to_tensor = not (isinstance(left_, np.ndarray) and isinstance(right_, np.ndarray))
182+
183+
left_raw = left_.data if isinstance(left_, (Matrix, Grid)) else left_
184+
right_raw = right_.data if isinstance(right_, (Matrix, Grid)) else right_
185+
186+
if promote_to_tensor:
187+
left_raw = torch.as_tensor(left_raw)
188+
right_raw = torch.as_tensor(right_raw)
189+
190+
if isinstance(left_, Grid):
191+
if isinstance(right_, Grid):
192+
raise RuntimeError("Unable to matrix multiply two Grids")
193+
else:
194+
result = matmul_grid_matrix(left_raw, right_raw)
195+
else:
196+
if isinstance(right_, Grid):
197+
result = matmul_matrix_grid(left_raw, right_raw)
198+
else:
199+
result = matmul_matrix_matrix(left_raw, right_raw)
200+
201+
if put_in_grid:
202+
result = Grid(result)
203+
elif put_in_matrix:
204+
result = Matrix(result)
205+
206+
if put_in_metamatrix:
207+
result = MetaMatrix(result)
208+
209+
return result
210+
211+
212+
def matmul_matrix_grid(left: NdarrayOrTensor, right: NdarrayOrTensor):
213+
if not is_matrix_shaped(left):
214+
raise ValueError(f"'left' should be a 2D or 3D homogenous matrix but has shape {left.shape}")
215+
216+
if not is_grid_shaped(right):
217+
raise ValueError(
218+
"'right' should be a 3D array with shape[0] == 2 or a "
219+
f"4D array with shape[0] == 3 but has shape {right.shape}"
220+
)
221+
222+
# flatten the grid to take advantage of torch batch matrix multiply
223+
right_flat = right.reshape(right.shape[0], -1)
224+
result_flat = left @ right_flat
225+
# restore the grid shape
226+
result = result_flat.reshape((-1,) + result_flat.shape[1:])
227+
return result
228+
229+
230+
def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
231+
if not is_grid_shaped(left):
232+
raise ValueError(
233+
"'left' should be a 3D array with shape[0] == 2 or a "
234+
f"4D array with shape[0] == 3 but has shape {left.shape}"
235+
)
236+
237+
if not is_matrix_shaped(right):
238+
raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}")
239+
240+
try:
241+
inv_matrix = torch.inverse(right)
242+
except RuntimeError:
243+
# the matrix is not invertible, so we will have to perform a slow grid to matrix operation
244+
return matmul_grid_matrix_slow(left, right)
245+
246+
# invert the matrix and swap the arguments, taking advantage of
247+
# matrix @ vector == vector_transposed @ matrix_inverse
248+
return matmul_matrix_grid(inv_matrix, left)
249+
250+
251+
def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor):
252+
if not is_grid_shaped(left):
253+
raise ValueError(
254+
"'left' should be a 3D array with shape[0] == 2 or a "
255+
f"4D array with shape[0] == 3 but has shape {left.shape}"
256+
)
257+
258+
if not is_matrix_shaped(right):
259+
raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}")
260+
261+
flat_left = left.reshape(left.shape[0], -1)
262+
result_flat = torch.zeros_like(flat_left)
263+
for i in range(flat_left.shape[1]):
264+
vector = flat_left[:, i][None, :]
265+
result_vector = vector @ right
266+
result_flat[:, i] = result_vector[0, :]
267+
268+
# restore the grid shape
269+
result = result_flat.reshape((-1,) + result_flat.shape[1:])
270+
return result
271+
272+
273+
def matmul_matrix_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
274+
return left @ right
275+
276+
101277
# this will conflict with PR Replacement Apply and Resample #5436
102278
def apply_align_corners(matrix, spatial_size, factory):
103279
inflated_spatial_size = tuple(s + 1 for s in spatial_size)

monai/transforms/utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,66 @@ def _create_rotate(
763763
raise ValueError(f"Unsupported spatial_dims: {spatial_dims}, available options are [2, 3].")
764764

765765

766+
def _create_rotate_90(
767+
spatial_dims: int,
768+
axis: Tuple[int, int],
769+
steps: Optional[int] = 1,
770+
eye_func: Callable = np.eye
771+
) -> NdarrayOrTensor:
772+
773+
values = [(1, 0, 0, 1),
774+
(0, -1, 1, 0),
775+
(-1, 0, 0, -1),
776+
(0, 1, -1, 0)]
777+
778+
if spatial_dims == 2:
779+
if axis != (0, 1):
780+
raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axis}")
781+
elif spatial_dims == 3:
782+
if axis not in ((0, 1), (0, 2), (1, 2)):
783+
raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) "
784+
f"but is {axis}")
785+
else:
786+
raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}")
787+
788+
steps_ = steps % 4
789+
790+
affine = eye_func(spatial_dims + 1)
791+
792+
if spatial_dims == 2:
793+
a, b = 0, 1
794+
else:
795+
a, b = axis
796+
797+
affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps]
798+
return affine
799+
800+
801+
def _create_flip(
802+
spatial_dims: int,
803+
spatial_axis: Union[Sequence[int], int],
804+
eye_func: Callable = np.eye
805+
):
806+
affine = eye_func(spatial_dims + 1)
807+
if isinstance(spatial_axis, int):
808+
if spatial_axis < -spatial_dims or spatial_axis >= spatial_dims:
809+
raise ValueError("'spatial_axis' values must be between "
810+
f"{-spatial_dims} and {spatial_dims-1} inclusive "
811+
f"('spatial_axis' is {spatial_axis})")
812+
affine[spatial_axis, spatial_axis] = -1
813+
else:
814+
if any((s < -spatial_dims or s >= spatial_dims) for s in spatial_axis):
815+
raise ValueError("'spatial_axis' values must be between "
816+
f"{-spatial_dims} and {spatial_dims-1} inclusive "
817+
f"('spatial_axis' is {spatial_axis})")
818+
819+
for i in range(spatial_dims):
820+
if i in spatial_axis:
821+
affine[i, i] = -1
822+
823+
return affine
824+
825+
766826
def create_shear(
767827
spatial_dims: int,
768828
coefs: Union[Sequence[float], float],

0 commit comments

Comments
 (0)