14
14
# limitations under the License.
15
15
16
16
import numpy as np
17
- import imgaug .augmenters as iaa
18
- import six .moves as sm
19
- import h5py
20
-
21
17
from deepposekit .io .BaseGenerator import BaseGenerator
18
+ from imgaug .augmenters import meta
19
+ from imgaug import parameters as iap
22
20
23
21
__all__ = ["FlipAxis" ]
24
22
25
23
26
- class FlipAxis (iaa . Flipud ):
24
+ class FlipAxis (meta . Augmenter ):
27
25
""" Flips the input image and keypoints across an axis.
28
26
29
27
A generalized class for flipping images and keypoints
@@ -39,103 +37,82 @@ class FlipAxis(iaa.Flipud):
39
37
This can be a deepposekit.io.BaseGenerator for annotations
40
38
or an array of integers specifying which keypoint indices
41
39
to swap.
40
+
41
+ p: int, default 0.5
42
+ The probability that an image is flipped
42
43
43
44
axis: int, default 0
44
45
Axis over which images are flipped
45
46
axis=0 flips up-down (np.flipud)
46
47
axis=1 flips left-right (np.fliplr)
47
-
48
+
49
+ seed: None or int or np.random.RandomState, default None
50
+ The random state for the augmenter.
51
+
48
52
name: None or str, default None
49
53
Name given to the Augmenter object. The name is used in print().
50
54
If left as None, will print 'UnnamedX'
51
-
55
+
52
56
deterministic: bool, default False
53
57
If set to true, each batch will be augmented the same way.
54
58
55
- random_state: None or int or np.random.RandomState, default None
56
- The random state for the augmenter.
57
59
58
60
Attributes
59
61
----------
62
+ p: int
63
+ The probability that an image is flipped
64
+
60
65
axis: int
61
66
The axis to reflect the image.
62
67
63
68
swap_index: array
64
69
The keypoint indices to swap when the image is flipped
70
+
65
71
66
72
"""
67
73
68
- def __init__ (
69
- self ,
70
- swap_index ,
71
- p = 0.5 ,
72
- axis = 0 ,
73
- name = None ,
74
- deterministic = False ,
75
- random_state = None ,
76
- ):
77
-
78
- super (FlipAxis , self ).__init__ (
79
- p = p , name = name , deterministic = deterministic , random_state = random_state
80
- )
81
-
74
+ def __init__ (self , swap_index , p = 0.5 , axis = 0 , seed = None , name = None , deterministic = False ):
75
+ super (FlipAxis , self ).__init__ (seed = seed , name = name , random_state = "deprecated" , deterministic = deterministic )
76
+ self .p = iap .handle_probability_param (p , "p" )
82
77
self .axis = axis
83
78
if isinstance (swap_index , BaseGenerator ):
84
79
if hasattr (swap_index , "swap_index" ):
85
80
self .swap_index = swap_index .swap_index
86
81
elif isinstance (swap_index , np .ndarray ):
87
82
self .swap_index = swap_index
83
+
84
+
85
+ def _augment_batch_ (self , batch , random_state , parents , hooks ):
86
+ samples = self .p .draw_samples ((batch .nb_rows ,),
87
+ random_state = random_state )
88
+ for i , sample in enumerate (samples ):
89
+ if sample >= 0.5 :
90
+
91
+ if batch .images is not None :
92
+ if self .axis == 0 :
93
+ batch .images [i ] = np .flipud (batch .images [i ])
94
+ if self .axis == 1 :
95
+ batch .images [i ] = np .fliplr (batch .images [i ])
96
+
88
97
89
- def _augment_images (self , images , random_state , parents , hooks ):
90
- """ Augments the images
91
-
92
- Handles the augmentation over a specified axis
93
-
94
- Returns
95
- -------
96
- images: array
97
- Array of augmented images.
98
-
99
- """
100
- nb_images = len (images )
101
- samples = self .p .draw_samples ((nb_images ,), random_state = random_state )
102
- for i in sm .xrange (nb_images ):
103
- if samples [i ] == 1 :
104
- if self .axis == 1 :
105
- images [i ] = np .fliplr (images [i ])
106
- elif self .axis == 0 :
107
- images [i ] = np .flipud (images [i ])
108
- self .samples = samples
109
- return images
110
-
111
- def _augment_keypoints (self , keypoints_on_images , random_state , parents , hooks ):
112
- """ Augments the keypoints
113
-
114
- Handles the augmentation over a specified axis
115
- and swaps the keypoint labels using swap_index.
116
- For example, the left leg will be swapped with the right leg
117
- This is accomplished by reordering the keypoints.
118
-
119
- Returns
120
- -------
121
- keypoints_on_images: array
122
- Array of new coordinates of the keypoints.
123
-
124
- """
125
- nb_images = len (keypoints_on_images )
126
- samples = self .p .draw_samples ((nb_images ,), random_state = random_state )
127
- for i , keypoints_on_image in enumerate (keypoints_on_images ):
128
- if samples [i ] == 1 :
129
- for keypoint in keypoints_on_image .keypoints :
98
+ if batch .keypoints is not None :
99
+ kpsoi = batch .keypoints [i ]
100
+ if self .axis == 0 :
101
+ height = kpsoi .shape [0 ]
102
+ for kp in kpsoi .keypoints :
103
+ kp .y = (height - 1 ) - kp .y
130
104
if self .axis == 1 :
131
- width = keypoints_on_image .shape [1 ]
132
- keypoint .x = (width - 1 ) - keypoint .x
133
- elif self .axis == 0 :
134
- height = keypoints_on_image .shape [0 ]
135
- keypoint .y = (height - 1 ) - keypoint .y
136
- swapped = keypoints_on_image .keypoints .copy ()
137
- for r in range (len (keypoints_on_image .keypoints )):
138
- idx = self .swap_index [r ]
139
- if idx >= 0 :
140
- keypoints_on_image .keypoints [r ] = swapped [idx ]
141
- return keypoints_on_images
105
+ width = kpsoi .shape [1 ]
106
+ for kp in kpsoi .keypoints :
107
+ kp .x = (width - 1 ) - kp .x
108
+ swapped = kpsoi .keypoints .copy ()
109
+ for r in range (len (kpsoi .keypoints )):
110
+ idx = self .swap_index [r ]
111
+ if idx >= 0 :
112
+ kpsoi .keypoints [r ] = swapped [idx ]
113
+
114
+ return batch
115
+
116
+ def get_parameters (self ):
117
+ """See :func:`~imgaug.augmenters.meta.Augmenter.get_parameters`."""
118
+ return [self .p , self .axis , self .swap_index ]
0 commit comments