Skip to content

Commit

Permalink
👕 Slightly improve transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Oct 15, 2023
1 parent 6f54bd5 commit d82d42c
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions torch_uncertainty/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
expand: bool = False,
center: Optional[List[int]] = None,
fill: Optional[List[int]] = None,
) -> Union[Tensor, Image.Image]:
) -> None:
super().__init__()
self.random_direction = random_direction
self.interpolation = interpolation
Expand Down Expand Up @@ -168,7 +168,7 @@ def __init__(
self.fill = fill

def forward(
self, img: Union[Tensor, Image.Image], level: float
self, img: Union[Tensor, Image.Image], level: int
) -> Union[Tensor, Image.Image]:
if (
self.random_direction and np.random.uniform() > 0.5
Expand Down Expand Up @@ -250,27 +250,49 @@ def forward(
if level < 0:
raise ValueError("Level must be greater than 0.")
if isinstance(img, Tensor):
img = F.to_pil_image(img)
img: Image.Image = F.to_pil_image(img)
return ImageEnhance.Color(img).enhance(level)


class RepeatTarget(nn.Module):
"""Repeat the targets for ensemble training.
Args:
num_repeats: Number of times to repeat the targets.
"""

def __init__(self, num_repeats: int) -> None:
super().__init__()

if not isinstance(num_repeats, int):
raise ValueError("num_repeats must be an integer.")
raise ValueError(
f"num_repeats must be an integer. Got {num_repeats}."
)
if num_repeats <= 0:
raise ValueError("num_repeats must be greater than 0.")
raise ValueError(
f"num_repeats must be greater than 0. Got {num_repeats}."
)

self.num_repeats = num_repeats

def forward(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
def forward(self, batch: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
inputs, targets = batch
return inputs, targets.repeat(self.num_repeats)


class MIMOBatchFormat(nn.Module):
"""Format the batch for MIMO training.
Args:
num_estimators: Number of estimators.
rho: Ratio of the correlation between the images for MIMO.
batch_repeat: Number of times to repeat the batch.
Reference:
Havasi, M., et al. Training independent subnetworks for robust
prediction. In ICLR, 2021.
"""

def __init__(
self, num_estimators: int, rho: float = 0.0, batch_repeat: int = 1
) -> None:
Expand All @@ -287,11 +309,11 @@ def __init__(
self.rho = rho
self.batch_repeat = batch_repeat

def shuffle(self, inputs: Tensor):
def shuffle(self, inputs: Tensor) -> Tensor:
idx = torch.randperm(inputs.nelement(), device=inputs.device)
return inputs.view(-1)[idx].view(inputs.size())

def forward(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
def forward(self, batch: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
inputs, targets = batch
indexes = torch.arange(
0, inputs.shape[0], device=inputs.device, dtype=torch.int64
Expand All @@ -304,7 +326,7 @@ def forward(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
self.shuffle(main_shuffle[:threshold_shuffle]),
main_shuffle[threshold_shuffle:],
],
axis=0,
dim=0,
)
for _ in range(self.num_estimators)
]
Expand All @@ -313,14 +335,14 @@ def forward(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
torch.index_select(inputs, dim=0, index=indices)
for indices in shuffle_indices
],
axis=0,
dim=0,
)
targets = torch.stack(
[
torch.index_select(targets, dim=0, index=indices)
for indices in shuffle_indices
],
axis=0,
dim=0,
)
inputs = rearrange(
inputs, "m b c h w -> (m b) c h w", m=self.num_estimators
Expand Down

0 comments on commit d82d42c

Please sign in to comment.