Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
make_video,
set_rng_seed,
)

from torch import nn
from torch.testing import assert_close
from torchvision import datapoints

Expand Down Expand Up @@ -1634,3 +1636,64 @@ def test_transform_negative_degrees_error(self):
def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill")


class TestCompose:
class BuiltinTransform(transforms.Transform):
def _transform(self, inpt, params):
return inpt

class PackedInputTransform(nn.Module):
def forward(self, sample):
image, label = sample
return image, label

class UnpackedInputTransform(nn.Module):
def forward(self, image, label):
return image, label

@pytest.mark.parametrize(
"transform_clss",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not call this transform_class ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would have to be transform_classes, since it is a list of classes. And since we use cls for singular, I'm usually just append an s to it. I'll leave it up to you.

[
[BuiltinTransform],
[PackedInputTransform],
[UnpackedInputTransform],
[BuiltinTransform, BuiltinTransform],
[PackedInputTransform, PackedInputTransform],
[UnpackedInputTransform, UnpackedInputTransform],
[BuiltinTransform, PackedInputTransform, BuiltinTransform],
[BuiltinTransform, UnpackedInputTransform, BuiltinTransform],
[PackedInputTransform, BuiltinTransform, PackedInputTransform],
[UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform],
],
)
@pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
assert not (needs_packed_inputs and needs_unpacked_inputs)

transform = transforms.Compose([cls() for cls in transform_clss])

image = make_image()
label = 3
packed_input = (image, label)

def call_transform():
if unpack:
return transform(*packed_input)
else:
return transform(packed_input)

if needs_unpacked_inputs and not unpack:
with pytest.raises(TypeError, match="missing 1 required positional argument"):
call_transform()
elif needs_packed_inputs and unpack:
with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"):
call_transform()
else:
output = call_transform()

assert isinstance(output, tuple) and len(output) == 2
assert output[0] is image
assert output[1] is label
9 changes: 6 additions & 3 deletions torchvision/transforms/v2/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
elif not transforms:
raise ValueError("Pass at least one transform")
self.transforms = transforms

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1
for transform in self.transforms:
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs

def extra_repr(self) -> str:
format_string = []
Expand Down