8
8
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
9
# See the License for the specific language governing permissions and
10
10
# limitations under the License.
11
+ from typing import Optional , Sequence , Union
11
12
13
+ import numpy as np
14
+
15
+ import torch
12
16
13
17
# placeholder to be replaced by MetaMatrix in Apply And Resample PR #5436
18
+ from monai .transforms .utils import _create_rotate , _create_shear , _create_scale , _create_translate
19
+
20
+ from monai .utils import TransformBackends
21
+
22
+
23
+ # this will conflict with PR Replacement Apply and Resample #5436
14
24
class MetaMatrix :
15
25
16
26
def __init__ (self ):
17
27
raise NotImplementedError ()
28
+
29
+
30
+ # this will conflict with PR Replacement Apply and Resample #5436
31
+ class MatrixFactory :
32
+
33
+ def __init__ (self ,
34
+ dims : int ,
35
+ backend : TransformBackends ,
36
+ device : Optional [torch .device ] = None ):
37
+
38
+ if backend == TransformBackends .NUMPY :
39
+ if device is not None :
40
+ raise ValueError ("'device' must be None with TransformBackends.NUMPY" )
41
+ self ._device = None
42
+ self ._sin = lambda th : np .sin (th , dtype = np .float32 )
43
+ self ._cos = lambda th : np .cos (th , dtype = np .float32 )
44
+ self ._eye = lambda th : np .eye (th , dtype = np .float32 )
45
+ self ._diag = lambda th : np .diag (th ).astype (np .float32 )
46
+ else :
47
+ if device is None :
48
+ raise ValueError ("'device' must be set with TransformBackends.TORCH" )
49
+ self ._device = device
50
+ self ._sin = lambda th : torch .sin (torch .as_tensor (th ,
51
+ dtype = torch .float32 ,
52
+ device = self ._device ))
53
+ self ._cos = lambda th : torch .cos (torch .as_tensor (th ,
54
+ dtype = torch .float32 ,
55
+ device = self ._device ))
56
+ self ._eye = lambda rank : torch .eye (rank ,
57
+ device = self ._device ,
58
+ dtype = torch .float32 );
59
+ self ._diag = lambda size : torch .diag (torch .as_tensor (size ,
60
+ device = self ._device ,
61
+ dtype = torch .float32 ))
62
+
63
+ self ._backend = backend
64
+ self ._dims = dims
65
+
66
+ @staticmethod
67
+ def from_tensor (data ):
68
+ return MatrixFactory (len (data .shape )- 1 ,
69
+ get_backend_from_tensor_like (data ),
70
+ get_device_from_tensor_like (data ))
71
+
72
+ def identity (self ):
73
+ matrix = self ._eye (self ._dims + 1 )
74
+ return MetaMatrix (matrix , {})
75
+
76
+ def rotate_euler (self , radians : Union [Sequence [float ], float ], ** extra_args ):
77
+ matrix = _create_rotate (self ._dims , radians , self ._sin , self ._cos , self ._eye )
78
+ return MetaMatrix (matrix , extra_args )
79
+
80
+ def rotate_90 (self , rotations , axis , ** extra_args ):
81
+ matrix = _create_rotate_90 (self ._dims , rotations , axis )
82
+ return MetaMatrix (matrix , extra_args )
83
+
84
+ def flip (self , axis , ** extra_args ):
85
+ matrix = _create_flip (self ._dims , axis , self ._eye )
86
+ return MetaMatrix (matrix , extra_args )
87
+
88
+ def shear (self , coefs : Union [Sequence [float ], float ], ** extra_args ):
89
+ matrix = _create_shear (self ._dims , coefs , self ._eye )
90
+ return MetaMatrix (matrix , extra_args )
91
+
92
+ def scale (self , factors : Union [Sequence [float ], float ], ** extra_args ):
93
+ matrix = _create_scale (self ._dims , factors , self ._diag )
94
+ return MetaMatrix (matrix , extra_args )
95
+
96
+ def translate (self , offsets : Union [Sequence [float ], float ], ** extra_args ):
97
+ matrix = _create_translate (self ._dims , offsets , self ._eye )
98
+ return MetaMatrix (matrix , extra_args )
99
+
100
+
101
+ # this will conflict with PR Replacement Apply and Resample #5436
102
+ def apply_align_corners (matrix , spatial_size , factory ):
103
+ inflated_spatial_size = tuple (s + 1 for s in spatial_size )
104
+ scale_factors = tuple (s / i for s , i in zip (spatial_size , inflated_spatial_size ))
105
+ scale_mat = factory .scale (scale_factors )
106
+ return matmul (scale_mat , matrix )
0 commit comments