@@ -1955,8 +1955,9 @@ def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
 
 
+# TODO: move this to references before merging and delete the tests
 class RandomMixupCutmix(torch.nn.Module):
-    """Randomly apply Mixum or Cutmix to the provided batch and targets.
+    """Randomly apply Mixup or Cutmix to the provided batch and targets.
     The class implements the data augmentations as described in the papers
     `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
@@ -2014,8 +2015,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
             return batch, target
 
         # It's faster to roll the batch by one instead of shuffling it to create image pairs
-        batch_flipped = batch.roll(1)
-        target_flipped = target.roll(1)
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1)
 
         if self.mixup_alpha <= 0.0:
             use_mixup = False
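Worth spelling out what the roll(1) to roll(1, 0) change does: Tensor.roll(shifts) with no dims flattens the tensor, shifts every element, and restores the shape, so pixel values leak across image boundaries; batch.roll(1, 0) shifts along the batch dimension only, pairing each image with its neighbour. A minimal standalone check (shapes are made up for illustration):

    import torch

    batch = torch.arange(4 * 3 * 2 * 2, dtype=torch.float32).reshape(4, 3, 2, 2)

    # No dims: flatten, shift all 48 values by one, reshape back.
    # Image contents get mixed across the batch -- not a valid pairing.
    wrong = batch.roll(1)

    # dims=0: shift whole images along the batch dimension, so image i
    # is paired with image i - 1 (and image 0 with the last one).
    rolled = batch.roll(1, 0)
    assert torch.equal(rolled[1], batch[0])
    assert not torch.equal(wrong[1], batch[0])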
@@ -2025,8 +2026,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
         if use_mixup:
             # Implemented as on mixup paper, page 3.
             lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
-            batch_flipped.mul_(1.0 - lambda_param)
-            batch.mul_(lambda_param).add_(batch_flipped)
+            batch_rolled.mul_(1.0 - lambda_param)
+            batch.mul_(lambda_param).add_(batch_rolled)
         else:
             # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
             lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
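Both branches draw lambda_param through the private torch._sample_dirichlet; the first component of a Dirichlet([alpha, alpha]) sample is Beta(alpha, alpha)-distributed, which is the distribution both papers specify. A rough equivalent of the mixup blend using only the public API, with made-up values (illustrative, not what the patch uses):

    import torch

    alpha = 0.2                        # stand-in for self.mixup_alpha
    batch = torch.rand(8, 3, 32, 32)   # hypothetical input batch

    # Same marginal as torch._sample_dirichlet(torch.tensor([alpha, alpha]))[0]
    lambda_param = float(torch.distributions.Beta(alpha, alpha).sample())

    # In-place mixup blend as in the patch: x = lam * x + (1 - lam) * x_rolled
    batch_rolled = batch.roll(1, 0)
    batch_rolled.mul_(1.0 - lambda_param)
    batch.mul_(lambda_param).add_(batch_rolled)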
@@ -2044,11 +2045,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
             x2 = int(torch.clamp(r_x + r_w_half, max=W))
             y2 = int(torch.clamp(r_y + r_h_half, max=H))
 
-            batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2]
+            batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
             lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
 
-        target_flipped.mul_(1.0 - lambda_param)
-        target.mul_(lambda_param).add_(target_flipped)
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
 
         return batch, target
 
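One subtlety on the cutmix branch: the box is clamped to the image borders, so the pasted area can be smaller than the (1 - lambda) * W * H the initial draw implies, and lambda_param is therefore recomputed from the realized box before the targets are blended. The arithmetic in isolation (W, H and the initial draw are made-up values):

    import math
    import torch

    W, H, lam = 224, 224, 0.3  # hypothetical image size and Beta draw

    # Cutmix paper: cut ratio sqrt(1 - lam), stored here as half-extents.
    r = 0.5 * math.sqrt(1.0 - lam)
    r_w_half, r_h_half = int(r * W), int(r * H)

    # Random box centre; the box is clamped to the image bounds.
    r_x, r_y = int(torch.randint(W, (1,))), int(torch.randint(H, (1,)))
    x1, y1 = max(r_x - r_w_half, 0), max(r_y - r_h_half, 0)
    x2, y2 = min(r_x + r_w_half, W), min(r_y + r_h_half, H)

    # Weight the targets by the pixel fraction that was actually kept.
    lam_adjusted = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)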