Commit c0c5558

Adding placeholders and functions required for functional, array, and dictionary transforms to operate while waiting for PR #5436
Signed-off-by: Ben Murray <ben.murray@gmail.com>
1 parent 6e4ec6e · commit c0c5558

File tree: 2 files changed, +128 −1 lines


monai/transforms/lazy/functional.py

Lines changed: 39 additions & 1 deletion
@@ -9,6 +9,44 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# placeholder to be replaced by apply in Apply And Resample PR #5436
+from typing import Sequence, Union
+
+import itertools as it
+
+import numpy as np
+
+import torch
+
+
+# placeholder that will conflict with PR Replacement Apply and Resample #5436
 def apply(*args, **kwargs):
     raise NotImplementedError()
+
+
+# this will conflict with PR Replacement Apply and Resample #5436
+def extents_from_shape(shape, dtype=np.float32):
+    extents = [[0, shape[i]] for i in range(1, len(shape))]
+
+    extents = it.product(*extents)
+    return [np.asarray(e + (1,), dtype=dtype) for e in extents]
+
+
+def shape_from_extents(
+    src_shape: Sequence, extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor]
+):
+    if isinstance(extents, (list, tuple)):
+        if isinstance(extents[0], np.ndarray):
+            aextents = np.asarray(extents)
+        else:
+            aextents = torch.stack(extents)
+            aextents = aextents.numpy()
+    else:
+        if isinstance(extents, np.ndarray):
+            aextents = extents
+        else:
+            aextents = extents.numpy()
+
+    mins = aextents.min(axis=0)
+    maxes = aextents.max(axis=0)
+    values = np.round(maxes - mins).astype(int)[:-1].tolist()
+    return (src_shape[0],) + tuple(values)
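
For orientation, a short usage sketch (not part of the commit) showing how the two new helpers round-trip a shape: extents_from_shape enumerates the corner points of the spatial dimensions in homogeneous coordinates, and shape_from_extents recovers a shape from (possibly transformed) extents. The import path follows the file above, and the identity matrix stands in for a real transform:

import numpy as np

from monai.transforms.lazy.functional import extents_from_shape, shape_from_extents

# a CHW image shape: one channel, 16 x 16 spatial
shape = (1, 16, 16)

# the four spatial corners in homogeneous coordinates:
# (0, 0, 1), (0, 16, 1), (16, 0, 1), (16, 16, 1)
corners = extents_from_shape(shape)

# transform the corners (identity here, for illustration), then recover
# the resulting shape; the channel dimension is carried over unchanged
identity = np.eye(3, dtype=np.float32)
moved = [identity @ c for c in corners]
print(shape_from_extents(shape, moved))  # (1, 16, 16)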

monai/transforms/meta_matrix.py

Lines changed: 89 additions & 0 deletions
@@ -8,10 +8,99 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional, Sequence, Union
 
+import numpy as np
+
+import torch
 
 # placeholder to be replaced by MetaMatrix in Apply And Resample PR #5436
+from monai.transforms.utils import _create_rotate, _create_shear, _create_scale, _create_translate
+
+from monai.utils import TransformBackends
+
+
+# this will conflict with PR Replacement Apply and Resample #5436
 class MetaMatrix:
 
     def __init__(self):
         raise NotImplementedError()
+
+
+# this will conflict with PR Replacement Apply and Resample #5436
+class MatrixFactory:
+
+    def __init__(self,
+                 dims: int,
+                 backend: TransformBackends,
+                 device: Optional[torch.device] = None):
+
+        if backend == TransformBackends.NUMPY:
+            if device is not None:
+                raise ValueError("'device' must be None with TransformBackends.NUMPY")
+            self._device = None
+            self._sin = lambda th: np.sin(th, dtype=np.float32)
+            self._cos = lambda th: np.cos(th, dtype=np.float32)
+            self._eye = lambda th: np.eye(th, dtype=np.float32)
+            self._diag = lambda th: np.diag(th).astype(np.float32)
+        else:
+            if device is None:
+                raise ValueError("'device' must be set with TransformBackends.TORCH")
+            self._device = device
+            self._sin = lambda th: torch.sin(torch.as_tensor(th,
+                                                             dtype=torch.float32,
+                                                             device=self._device))
+            self._cos = lambda th: torch.cos(torch.as_tensor(th,
+                                                             dtype=torch.float32,
+                                                             device=self._device))
+            self._eye = lambda rank: torch.eye(rank,
+                                               device=self._device,
+                                               dtype=torch.float32)
+            self._diag = lambda size: torch.diag(torch.as_tensor(size,
+                                                                 device=self._device,
+                                                                 dtype=torch.float32))
+
+        self._backend = backend
+        self._dims = dims
+
+    @staticmethod
+    def from_tensor(data):
+        return MatrixFactory(len(data.shape) - 1,
+                             get_backend_from_tensor_like(data),
+                             get_device_from_tensor_like(data))
+
+    def identity(self):
+        matrix = self._eye(self._dims + 1)
+        return MetaMatrix(matrix, {})
+
+    def rotate_euler(self, radians: Union[Sequence[float], float], **extra_args):
+        matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+    def rotate_90(self, rotations, axis, **extra_args):
+        matrix = _create_rotate_90(self._dims, rotations, axis)
+        return MetaMatrix(matrix, extra_args)
+
+    def flip(self, axis, **extra_args):
+        matrix = _create_flip(self._dims, axis, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+    def shear(self, coefs: Union[Sequence[float], float], **extra_args):
+        matrix = _create_shear(self._dims, coefs, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+    def scale(self, factors: Union[Sequence[float], float], **extra_args):
+        matrix = _create_scale(self._dims, factors, self._diag)
+        return MetaMatrix(matrix, extra_args)
+
+    def translate(self, offsets: Union[Sequence[float], float], **extra_args):
+        matrix = _create_translate(self._dims, offsets, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+
+# this will conflict with PR Replacement Apply and Resample #5436
+def apply_align_corners(matrix, spatial_size, factory):
+    inflated_spatial_size = tuple(s + 1 for s in spatial_size)
+    scale_factors = tuple(s / i for s, i in zip(spatial_size, inflated_spatial_size))
+    scale_mat = factory.scale(scale_factors)
+    return matmul(scale_mat, matrix)
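
Note that MetaMatrix is still a placeholder at this commit (its constructor raises NotImplementedError), and helpers such as get_backend_from_tensor_like, _create_rotate_90, _create_flip and matmul are expected to arrive with PR #5436, so the factory cannot be exercised end to end yet. Below is a minimal numpy-only sketch (not part of the commit) of the align-corners correction that apply_align_corners performs; scale_matrix is a hypothetical stand-in for factory.scale:

import numpy as np

def scale_matrix(factors):
    # homogeneous scale matrix over the spatial dims, analogous to what
    # _create_scale would build through the factory's _diag function
    return np.diag(np.asarray(tuple(factors) + (1.0,), dtype=np.float32))

spatial_size = (10, 10)
inflated = tuple(s + 1 for s in spatial_size)                    # (11, 11)
factors = tuple(s / i for s, i in zip(spatial_size, inflated))   # (10/11, 10/11)

# pre-multiply an existing transform by the align-corners scale, as
# apply_align_corners does via matmul(scale_mat, matrix)
matrix = np.eye(3, dtype=np.float32)
corrected = scale_matrix(factors) @ matrix
print(corrected)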
