Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def get_module(use_v2):


class ClassificationPresetTrain:
# Note: this transform assumes that the input to forward() is always a PIL
# image, regardless of the backend parameter. We may change that in the
# future though, if we change the output type from the dataset.
def __init__(
self,
*,
Expand All @@ -30,42 +33,42 @@ def __init__(
backend="pil",
use_v2=False,
):
module = get_module(use_v2)
T = get_module(use_v2)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just did s/module/T/ in the file to make it consistent with the detection one


transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
if hflip_prob > 0:
transforms.append(module.RandomHorizontalFlip(hflip_prob))
transforms.append(T.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
elif auto_augment_policy == "ta_wide":
transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
else:
aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))

if backend == "pil":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())

transforms.extend(
[
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]
)
if random_erase_prob > 0:
transforms.append(module.RandomErasing(p=random_erase_prob))
transforms.append(T.RandomErasing(p=random_erase_prob))

self.transforms = module.Compose(transforms)
self.transforms = T.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
Expand All @@ -83,28 +86,28 @@ def __init__(
backend="pil",
use_v2=False,
):
module = get_module(use_v2)
T = get_module(use_v2)
transforms = []
backend = backend.lower()
if backend == "tensor":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")

transforms += [
module.Resize(resize_size, interpolation=interpolation, antialias=True),
module.CenterCrop(crop_size),
T.Resize(resize_size, interpolation=interpolation, antialias=True),
T.CenterCrop(crop_size),
]

if backend == "pil":
transforms.append(module.PILToTensor())
transforms.append(T.PILToTensor())

transforms += [
module.ConvertImageDtype(torch.float),
module.Normalize(mean=mean, std=std),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
]

self.transforms = module.Compose(transforms)
self.transforms = T.Compose(transforms)

def __call__(self, img):
return self.transforms(img)
142 changes: 89 additions & 53 deletions references/detection/presets.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,109 @@
from collections import defaultdict

import torch
import transforms as T
import transforms as reference_transforms


def get_modules(use_v2):
    """Select the transforms implementation used to build the detection presets.

    Args:
        use_v2: When truthy, use the torchvision V2 transforms; otherwise use
            the local reference transforms.

    Returns:
        A ``(transforms_module, datapoints_module)`` pair. The second element
        is ``None`` when the V2 transforms are not requested.
    """
    if not use_v2:
        return reference_transforms, None

    # Keep these imports local (not at module level) so that the V2 warning is
    # not triggered when only the V1 transforms are used.
    import torchvision.datapoints
    import torchvision.transforms.v2

    return torchvision.transforms.v2, torchvision.datapoints


class DetectionPresetTrain:
def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
def __init__(
self,
*,
data_augmentation,
hflip_prob=0.5,
mean=(123.0, 117.0, 104.0),
backend="pil",
use_v2=False,
):

T, datapoints = get_modules(use_v2)

transforms = []
backend = backend.lower()
if backend == "datapoint":
transforms.append(T.ToImageTensor())
elif backend == "tensor":
transforms.append(T.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

if data_augmentation == "hflip":
self.transforms = T.Compose(
[
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
elif data_augmentation == "lsj":
self.transforms = T.Compose(
[
T.ScaleJitter(target_size=(1024, 1024)),
T.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.ScaleJitter(target_size=(1024, 1024), antialias=True),
# TODO: FixedSizeCrop below doesn't work on tensors!
reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
Comment on lines +47 to +48
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In v2 we have RandomCrop, which does what FixedSizeCrop does, minus clamping and sanitizing the bounding boxes.

T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "multiscale":
self.transforms = T.Compose(
[
T.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "ssd":
self.transforms = T.Compose(
[
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=list(mean)),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean)
transforms += [
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=fill),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
]
elif data_augmentation == "ssdlite":
self.transforms = T.Compose(
[
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
transforms += [
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
]
else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')

if backend == "pil":
# Note: we could just convert to pure tensors even in v2.
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]

transforms += [T.ConvertImageDtype(torch.float)]

if use_v2:
transforms += [
T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
T.SanitizeBoundingBox(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need ClampBoundingBox here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so since we established that all transforms should clamp already (those that need to, at least)?

]

self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)


class DetectionPresetEval:
def __init__(self):
self.transforms = T.Compose(
[
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
def __init__(self, backend="pil", use_v2=False):
T, _ = get_modules(use_v2)
transforms = []
backend = backend.lower()
# Conversion may look a bit weird but the assumption of this transform is that the input is always a PIL image
# TODO: Is that still true when using v2, from the dataset???????
if backend == "pil":
# Note: we could just convert to pure tensors even in v2?
transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
elif backend == "tensor":
transforms += [T.PILToTensor()]
elif backend == "datapoint":
transforms += [T.ToImageTensor()]
else:
raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")

transforms += [T.ConvertImageDtype(torch.float)]
self.transforms = T.Compose(transforms)

def __call__(self, img, target):
return self.transforms(img, target)
12 changes: 10 additions & 2 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ def get_dataset(name, image_set, transform, data_path):

def get_transform(train, args):
if train:
return presets.DetectionPresetTrain(data_augmentation=args.data_augmentation)
return presets.DetectionPresetTrain(
data_augmentation=args.data_augmentation, backend=args.backend, use_v2=args.use_v2
)
elif args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
trans = weights.transforms()
return lambda img, target: (trans(img), target)
else:
return presets.DetectionPresetEval()
return presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2)


def get_args_parser(add_help=True):
Expand Down Expand Up @@ -159,10 +161,16 @@ def get_args_parser(add_help=True):
help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
)

parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")

return parser


def main(args):
if args.backend.lower() == "datapoint" and not args.use_v2:
raise ValueError("Use --use-v2 if you want to use the datapoint backend.")

if args.output_dir:
utils.mkdir(args.output_dir)

Expand Down
9 changes: 7 additions & 2 deletions references/detection/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,13 @@ def __init__(
target_size: Tuple[int, int],
scale_range: Tuple[float, float] = (0.1, 2.0),
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
antialias=True,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to add antialias support because it would otherwise default to False for tensors. There are no BC requirements, so we could just hard-code antialias=True below in the calls to resize() instead of adding a parameter here, but it doesn't change much. LMK what you prefer.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDC. Unless there is some other opinion, let's keep it the way it is.

):
super().__init__()
self.target_size = target_size
self.scale_range = scale_range
self.interpolation = interpolation
self.antialias = antialias

def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
Expand All @@ -315,14 +317,17 @@ def forward(
new_width = int(orig_width * r)
new_height = int(orig_height * r)

image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)
image = F.resize(image, [new_height, new_width], interpolation=self.interpolation, antialias=self.antialias)

if target is not None:
target["boxes"][:, 0::2] *= new_width / orig_width
target["boxes"][:, 1::2] *= new_height / orig_height
if "masks" in target:
target["masks"] = F.resize(
target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
target["masks"],
[new_height, new_width],
interpolation=InterpolationMode.NEAREST,
antialias=self.antialias,
)

return image, target
Expand Down