ADD: ApplySomeOf transforms to Random Choice #7586

Open · wants to merge 4 commits into base: main
4 changes: 4 additions & 0 deletions test/test_transforms_v2.py
@@ -5324,6 +5324,10 @@ def was_applied(output, inpt):
        for input, output in zip(other_inputs, other_outputs):
            assert transform.was_applied(output, input)

    def test_assertions_p_and_max_transforms(self):
        with pytest.raises(ValueError, match="Only one of `p` and `max_transforms` should be specified."):
            transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1], max_transforms=3)
Comment on lines +5327 to +5329 (Member):
Perhaps this is due to the fact that this file has changed a lot in the past, but this test is now part of test_pure_tensor_heuristic() and it shouldn't be there. We have a def test_random_choice() somewhere; maybe we should make it a class and put those tests together.
Also, it'd be nice to test the behaviour of the new max_transforms parameter.
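As a reply, a rough sketch of what such a test could look like. It assumes a hypothetical TestRandomChoice class is introduced as suggested, reuses the test module's existing imports (torch, transforms), and assumes the documented behaviour of applying between 1 and max_transforms of the listed transforms once the forward dispatch is fixed as discussed below; it is not part of this PR:

class TestRandomChoice:
    def test_max_transforms_applies_a_subset(self):
        # Pad(1) grows each spatial side by 2 pixels and Pad(2) by 4, so the
        # output size bounds how many of the two transforms were applied.
        img = torch.rand(3, 16, 16)
        transform = transforms.RandomChoice([transforms.Pad(1), transforms.Pad(2)], max_transforms=2)
        out = transform(img)
        assert out.shape[-1] >= 16 + 2      # at least one transform applied
        assert out.shape[-1] <= 16 + 2 + 4  # at most both transforms applied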



class TestRandomIoUCrop:
    @pytest.mark.parametrize("device", cpu_and_cuda())
19 changes: 18 additions & 1 deletion torchvision/transforms/v2/_container.py
@@ -1,3 +1,4 @@
import random
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import torch
@@ -129,16 +130,24 @@ class RandomChoice(Transform):
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.
        max_transforms (int, optional): The maximum number of transforms that can be applied.
            If specified, ``p`` is ignored and a random number of transforms sampled from
            [1, ``max_transforms``] is applied.
Member:
Let's sample from [0, max_transforms] as in the original feature request.
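For instance, an inclusive draw from [0, max_transforms] using only PyTorch's RNG could be a one-liner (a sketch, not part of this diff):

# k is drawn uniformly from {0, 1, ..., max_transforms}; k == 0 would apply no transform.
k = int(torch.randint(0, self.max_transforms + 1, size=()))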

"""

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: Optional[List[float]] = None,
        max_transforms: Optional[int] = None,
    ) -> None:
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")

        # p and max_transforms are mutually exclusive
        if p is not None and max_transforms is not None:
            raise ValueError("Only one of `p` and `max_transforms` should be specified.")

        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
@@ -149,11 +158,19 @@ def __init__(
        self.transforms = transforms
        total = sum(p)
        self.p = [prob / total for prob in p]
        self.max_transforms = max_transforms

    def forward(self, *inputs: Any) -> Any:
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        transform = self.transforms[idx]
Comment on lines 164 to 165 (Member):
These 2 lines here should probably be part of the if self.p block below.

        return transform(*inputs)

        if self.p:
            return transform(*inputs)

        else:
            selected_transforms = random.sample(self.transforms, k=random.randint(1, self.max_transforms))
            random.shuffle(selected_transforms)
            return Compose(selected_transforms)(*inputs)
Comment on lines +171 to +173 (Member):
Let's try to implement that without using the builtin random module and instead just rely on pytorch's RNG. torch.randperm can probably be used for that.
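Following up on the two comments above, a hedged sketch of how forward could be restructured so that the single-transform path stays with the p-based choice and the subset selection relies only on PyTorch's RNG (torch.randperm for the subset, torch.randint for its size). It assumes the module's existing imports (torch, Compose, Any) and samples from [0, max_transforms] as suggested; it is not the implementation in this PR:

def forward(self, *inputs: Any) -> Any:
    if self.max_transforms is None:
        # Original behaviour: pick exactly one transform according to self.p.
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        return self.transforms[idx](*inputs)

    # Draw how many transforms to apply, then take that many entries from a
    # random permutation of the list, so both the subset and its order are random.
    k = int(torch.randint(0, self.max_transforms + 1, size=()))
    if k == 0:
        # No transform selected: return the input(s) unchanged.
        return inputs if len(inputs) > 1 else inputs[0]
    indices = torch.randperm(len(self.transforms))[:k].tolist()
    return Compose([self.transforms[i] for i in indices])(*inputs)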



class RandomOrder(Transform):