ADD: ApplySomeOf transforms to Random Choice #7586
base: main
```diff
@@ -1,3 +1,4 @@
+import random
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union

 import torch
```
```diff
@@ -129,16 +130,24 @@ class RandomChoice(Transform):
         p (list of floats or None, optional): probability of each transform being picked.
             If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
             (default), all transforms have the same probability.
+        max_transforms (int, optional): The maximum number of transforms that can be applied.
+            If specified, ``p`` is ignored and a random number of transforms sampled from
+            [1, ``max_transforms``] is applied.
```
Review comment: Let's sample from [0, ``max_transforms``] as in the original feature request.
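If the lower bound moves to 0, "apply nothing" becomes a valid outcome, so the count draw has to include it. A minimal sketch of that draw using torch's RNG; the helper name is hypothetical and not part of the PR:

```python
import torch

def _num_transforms_to_apply(max_transforms: int) -> int:
    # Hypothetical helper: draw k uniformly from [0, max_transforms]
    # inclusive; k == 0 means no transform is applied at all.
    return int(torch.randint(0, max_transforms + 1, size=(1,)))
```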
""" | ||
|
||
def __init__( | ||
self, | ||
transforms: Sequence[Callable], | ||
p: Optional[List[float]] = None, | ||
max_transforms: Optional[int] = None, | ||
) -> None: | ||
if not isinstance(transforms, Sequence): | ||
raise TypeError("Argument transforms should be a sequence of callables") | ||
|
||
# p and max_transforms are mutually exclusive | ||
if p is not None and max_transforms is not None: | ||
raise ValueError("Only one of `p` and `max_transforms` should be specified.") | ||
|
||
if p is None: | ||
p = [1] * len(transforms) | ||
elif len(p) != len(transforms): | ||
|
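For context, the docstring above implies usage along these lines (a sketch assuming the documented semantics and that this lands in the v2 namespace; the exact sampling range is still under discussion above):

```python
from torchvision.transforms import v2

# Apply a random subset of at most 2 of these transforms, in random
# order; `p` is ignored when `max_transforms` is given.
augment = v2.RandomChoice(
    [v2.RandomHorizontalFlip(), v2.ColorJitter(brightness=0.5), v2.RandomGrayscale()],
    max_transforms=2,
)
```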
```diff
@@ -149,11 +158,19 @@ def __init__(
         self.transforms = transforms
         total = sum(p)
         self.p = [prob / total for prob in p]
+        self.max_transforms = max_transforms

     def forward(self, *inputs: Any) -> Any:
         idx = int(torch.multinomial(torch.tensor(self.p), 1))
         transform = self.transforms[idx]
```
Review comment on lines 164 to 165: These 2 lines here should probably be part of the `if` branch below.
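That is, the multinomial draw only belongs on the `p` path; the restructuring the comment seems to ask for might look like this (a sketch, not the PR's code):

```python
def forward(self, *inputs: Any) -> Any:
    if self.p:
        # Only sample from the categorical distribution when `p` is in use.
        idx = int(torch.multinomial(torch.tensor(self.p), 1))
        transform = self.transforms[idx]
        return transform(*inputs)
    # ...otherwise fall through to the max_transforms branch below.
```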
```diff
-        return transform(*inputs)
+        if self.p:
+            return transform(*inputs)
+
+        else:
+            selected_transforms = random.sample(self.transforms, k=random.randint(1, self.max_transforms))
+            random.shuffle(selected_transforms)
+            return Compose(selected_transforms)(*inputs)
```
Review comment on lines +171 to +173: Let's try to implement that without using the builtin `random` module.
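One way to follow that suggestion is to use torch's generator for both draws, so `torch.manual_seed` controls them; `torch.randperm` already returns a shuffled order, which also removes the separate shuffle step. A sketch with a hypothetical helper (these names are not in the PR):

```python
import torch

def _sample_transforms(transforms, max_transforms):
    # Hypothetical replacement for the random.sample / random.shuffle
    # lines: torch.randint picks how many transforms to apply, and
    # torch.randperm picks them in an already-random order.
    num = int(torch.randint(1, max_transforms + 1, size=(1,)))
    perm = torch.randperm(len(transforms))[:num]
    return [transforms[i] for i in perm.tolist()]
```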
```diff


 class RandomOrder(Transform):
```
Review comment: Perhaps this is due to the fact that this file has changed a lot in the past, but this test is now part of `test_pure_tensor_heuristic()` and it shouldn't be there. We have a `def test_random_choice()` somewhere; maybe we should make it a class and put those tests together. Also, it'd be nice to test the behaviour of the new `max_transforms` parameter.
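For that last point, a test could look roughly like this (a sketch: the import path and the counting trick are assumptions, and the expected bounds depend on whether [0, ...] or [1, ...] sampling wins the discussion above):

```python
import pytest
import torch
from torchvision.transforms import v2

def test_random_choice_max_transforms():
    # p and max_transforms are documented as mutually exclusive.
    with pytest.raises(ValueError, match="Only one of"):
        v2.RandomChoice([lambda x: x], p=[1.0], max_transforms=1)

    # Each transform adds 1, so the output counts how many were applied.
    transforms = [lambda x: x + 1, lambda x: x + 1, lambda x: x + 1]
    t = v2.RandomChoice(transforms, max_transforms=2)
    out = t(torch.zeros(()))
    assert 0 <= out <= 2
```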