-
Notifications
You must be signed in to change notification settings - Fork 7.2k
improve UX for v2 Compose #7758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,8 @@ | |
make_video, | ||
set_rng_seed, | ||
) | ||
|
||
from torch import nn | ||
from torch.testing import assert_close | ||
from torchvision import datapoints | ||
|
||
|
@@ -1634,3 +1636,64 @@ def test_transform_negative_degrees_error(self): | |
def test_transform_unknown_fill_error(self): | ||
with pytest.raises(TypeError, match="Got inappropriate fill arg"): | ||
transforms.RandomAffine(degrees=0, fill="fill") | ||
|
||
|
||
class TestCompose: | ||
class BuiltinTransform(transforms.Transform): | ||
def _transform(self, inpt, params): | ||
return inpt | ||
|
||
class PackedInputTransform(nn.Module): | ||
def forward(self, sample): | ||
image, label = sample | ||
return image, label | ||
|
||
class UnpackedInputTransform(nn.Module): | ||
def forward(self, image, label): | ||
return image, label | ||
|
||
@pytest.mark.parametrize( | ||
"transform_clss", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not call this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would have to be |
||
[ | ||
[BuiltinTransform], | ||
[PackedInputTransform], | ||
[UnpackedInputTransform], | ||
[BuiltinTransform, BuiltinTransform], | ||
[PackedInputTransform, PackedInputTransform], | ||
[UnpackedInputTransform, UnpackedInputTransform], | ||
[BuiltinTransform, PackedInputTransform, BuiltinTransform], | ||
[BuiltinTransform, UnpackedInputTransform, BuiltinTransform], | ||
[PackedInputTransform, BuiltinTransform, PackedInputTransform], | ||
[UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform], | ||
], | ||
) | ||
@pytest.mark.parametrize("unpack", [True, False]) | ||
def test_packed_unpacked(self, transform_clss, unpack): | ||
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss) | ||
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) | ||
assert not (needs_packed_inputs and needs_unpacked_inputs) | ||
|
||
transform = transforms.Compose([cls() for cls in transform_clss]) | ||
|
||
image = make_image() | ||
label = 3 | ||
packed_input = (image, label) | ||
|
||
def call_transform(): | ||
if unpack: | ||
return transform(*packed_input) | ||
else: | ||
return transform(packed_input) | ||
|
||
if needs_unpacked_inputs and not unpack: | ||
with pytest.raises(TypeError, match="missing 1 required positional argument"): | ||
call_transform() | ||
elif needs_packed_inputs and unpack: | ||
with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"): | ||
call_transform() | ||
else: | ||
output = call_transform() | ||
|
||
assert isinstance(output, tuple) and len(output) == 2 | ||
assert output[0] is image | ||
assert output[1] is label |
Uh oh!
There was an error while loading. Please reload this page.