Skip to content

Commit cecdb0c

Browse files
iulsigheorghita
andauthored
FlipAxis fix (jgraving#57)
Co-authored-by: iulia <iuliaisaia@gmail.com>
1 parent c4f054b commit cecdb0c

File tree

1 file changed

+52
-75
lines changed

1 file changed

+52
-75
lines changed

deepposekit/augment/FlipAxis.py

Lines changed: 52 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,14 @@
1414
# limitations under the License.
1515

1616
import numpy as np
17-
import imgaug.augmenters as iaa
18-
import six.moves as sm
19-
import h5py
20-
2117
from deepposekit.io.BaseGenerator import BaseGenerator
18+
from imgaug.augmenters import meta
19+
from imgaug import parameters as iap
2220

2321
__all__ = ["FlipAxis"]
2422

2523

26-
class FlipAxis(iaa.Flipud):
24+
class FlipAxis(meta.Augmenter):
2725
""" Flips the input image and keypoints across an axis.
2826
2927
A generalized class for flipping images and keypoints
@@ -39,103 +37,82 @@ class FlipAxis(iaa.Flipud):
3937
This can be a deepposekit.io.BaseGenerator for annotations
4038
or an array of integers specifying which keypoint indices
4139
to swap.
40+
41+
p: int, default 0.5
42+
The probability that an image is flipped
4243
4344
axis: int, default 0
4445
Axis over which images are flipped
4546
axis=0 flips up-down (np.flipud)
4647
axis=1 flips left-right (np.fliplr)
47-
48+
49+
seed: None or int or np.random.RandomState, default None
50+
The random state for the augmenter.
51+
4852
name: None or str, default None
4953
Name given to the Augmenter object. The name is used in print().
5054
If left as None, will print 'UnnamedX'
51-
55+
5256
deterministic: bool, default False
5357
If set to true, each batch will be augmented the same way.
5458
55-
random_state: None or int or np.random.RandomState, default None
56-
The random state for the augmenter.
5759
5860
Attributes
5961
----------
62+
p: int
63+
The probability that an image is flipped
64+
6065
axis: int
6166
The axis to reflect the image.
6267
6368
swap_index: array
6469
The keypoint indices to swap when the image is flipped
70+
6571
6672
"""
6773

68-
def __init__(
69-
self,
70-
swap_index,
71-
p=0.5,
72-
axis=0,
73-
name=None,
74-
deterministic=False,
75-
random_state=None,
76-
):
77-
78-
super(FlipAxis, self).__init__(
79-
p=p, name=name, deterministic=deterministic, random_state=random_state
80-
)
81-
74+
def __init__(self, swap_index, p=0.5, axis=0, seed=None, name=None, deterministic=False):
75+
super(FlipAxis, self).__init__(seed=seed, name=name, random_state="deprecated", deterministic=deterministic)
76+
self.p = iap.handle_probability_param(p, "p")
8277
self.axis = axis
8378
if isinstance(swap_index, BaseGenerator):
8479
if hasattr(swap_index, "swap_index"):
8580
self.swap_index = swap_index.swap_index
8681
elif isinstance(swap_index, np.ndarray):
8782
self.swap_index = swap_index
83+
84+
85+
def _augment_batch_(self, batch, random_state, parents, hooks):
86+
samples = self.p.draw_samples((batch.nb_rows,),
87+
random_state=random_state)
88+
for i, sample in enumerate(samples):
89+
if sample >= 0.5:
90+
91+
if batch.images is not None:
92+
if self.axis == 0:
93+
batch.images[i] = np.flipud(batch.images[i])
94+
if self.axis == 1:
95+
batch.images[i] = np.fliplr(batch.images[i])
96+
8897

89-
def _augment_images(self, images, random_state, parents, hooks):
90-
""" Augments the images
91-
92-
Handles the augmentation over a specified axis
93-
94-
Returns
95-
-------
96-
images: array
97-
Array of augmented images.
98-
99-
"""
100-
nb_images = len(images)
101-
samples = self.p.draw_samples((nb_images,), random_state=random_state)
102-
for i in sm.xrange(nb_images):
103-
if samples[i] == 1:
104-
if self.axis == 1:
105-
images[i] = np.fliplr(images[i])
106-
elif self.axis == 0:
107-
images[i] = np.flipud(images[i])
108-
self.samples = samples
109-
return images
110-
111-
def _augment_keypoints(self, keypoints_on_images, random_state, parents, hooks):
112-
""" Augments the keypoints
113-
114-
Handles the augmentation over a specified axis
115-
and swaps the keypoint labels using swap_index.
116-
For example, the left leg will be swapped with the right leg
117-
This is accomplished by reordering the keypoints.
118-
119-
Returns
120-
-------
121-
keypoints_on_images: array
122-
Array of new coordinates of the keypoints.
123-
124-
"""
125-
nb_images = len(keypoints_on_images)
126-
samples = self.p.draw_samples((nb_images,), random_state=random_state)
127-
for i, keypoints_on_image in enumerate(keypoints_on_images):
128-
if samples[i] == 1:
129-
for keypoint in keypoints_on_image.keypoints:
98+
if batch.keypoints is not None:
99+
kpsoi = batch.keypoints[i]
100+
if self.axis == 0:
101+
height = kpsoi.shape[0]
102+
for kp in kpsoi.keypoints:
103+
kp.y = (height-1) - kp.y
130104
if self.axis == 1:
131-
width = keypoints_on_image.shape[1]
132-
keypoint.x = (width - 1) - keypoint.x
133-
elif self.axis == 0:
134-
height = keypoints_on_image.shape[0]
135-
keypoint.y = (height - 1) - keypoint.y
136-
swapped = keypoints_on_image.keypoints.copy()
137-
for r in range(len(keypoints_on_image.keypoints)):
138-
idx = self.swap_index[r]
139-
if idx >= 0:
140-
keypoints_on_image.keypoints[r] = swapped[idx]
141-
return keypoints_on_images
105+
width = kpsoi.shape[1]
106+
for kp in kpsoi.keypoints:
107+
kp.x = (width-1) - kp.x
108+
swapped = kpsoi.keypoints.copy()
109+
for r in range(len(kpsoi.keypoints)):
110+
idx = self.swap_index[r]
111+
if idx >= 0:
112+
kpsoi.keypoints[r] = swapped[idx]
113+
114+
return batch
115+
116+
def get_parameters(self):
117+
"""See :func:`~imgaug.augmenters.meta.Augmenter.get_parameters`."""
118+
return [self.p, self.axis, self.swap_index]

0 commit comments

Comments
 (0)