Skip to content

Commit

Permalink
Add in-plane rotation augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-nguyen committed May 11, 2023
1 parent d18a6f7 commit e078107
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 24 deletions.
3 changes: 2 additions & 1 deletion configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ save_dir: ${machine.root_dir}/results/${name_exp}
name_exp: train
use_pretrained: True
use_augmentation: True
use_random_geometric: True
use_random_rotation: True
use_random_scale_translation: True
train_datasets:
- tless_train
- hb
Expand Down
54 changes: 34 additions & 20 deletions src/dataloader/bop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging
import cv2
import os.path as osp
from src.utils.augmentation import Augmentator, CenterCropRandomResizedCrop
from src.utils.augmentation import Augmentator, CenterCropRandomResizedCrop, RandomRotation
from tqdm import tqdm
from src.poses.utils import (
get_obj_poses_from_template_level,
Expand All @@ -31,7 +31,8 @@ def __init__(
obj_ids,
img_size,
use_augmentation=False,
use_random_geometric=False,
use_random_rotation = False,
use_random_scale_translation = False,
cropping_with_bbox=True,
reset_metaData=False,
isTesting=False,
Expand All @@ -45,9 +46,11 @@ def __init__(
self.mask_size = 25 if img_size == 64 else int(img_size // 8)
self.cropping_with_bbox = cropping_with_bbox
self.use_augmentation = use_augmentation
self.use_random_geometric = use_random_geometric
self.use_random_rotation = use_random_rotation
self.use_random_scale_translation = use_random_scale_translation
self.augmentator = Augmentator()
self.random_cropper = CenterCropRandomResizedCrop()
self.random_rotator = RandomRotation()

self.load_template_poses(template_dir=template_dir)
self.load_testing_indexes()
Expand All @@ -74,7 +77,8 @@ def __init__(
self.metaData = self.subsample(self.metaData, 10)
self.isTesting = True
self.use_augmentation = False
self.use_random_geometric = False
self.use_random_rotation = False
self.use_random_scale_translation = False
else:
logging.warning(f"Split {split} and mode {isTesting} not recognized")
raise NotImplementedError
Expand All @@ -94,6 +98,9 @@ def __init__(
transforms.Lambda(lambda mask: torch.from_numpy(mask).unsqueeze(0)),
]
)
self.random_rotation_transfrom = transforms.Compose([
transforms.RandomRotation(degrees = (-90,90))
])
logging.info(
f"Length of dataloader: {self.__len__()} with mode {self.isTesting} containing objects {np.unique(self.metaData['obj_id'])}"
)
Expand Down Expand Up @@ -168,7 +175,7 @@ def get_bbox(self, img, idx=None):

def crop(self, imgs, bboxes):
if self.cropping_with_bbox:
if self.use_random_geometric and not self.isTesting:
if self.use_random_scale_translation and not self.isTesting:
imgs_cropped = self.random_cropper(imgs, bboxes)
else:
imgs_cropped = []
Expand All @@ -178,6 +185,7 @@ def crop(self, imgs, bboxes):

def load_testing_indexes(self):
self.testing_indexes = load_index_level0_in_level2("all")


def __getitem__(self, idx):
if not self.isTesting:
Expand All @@ -193,6 +201,8 @@ def __getitem__(self, idx):
template = self.rgb_transform(template)
template_mask = self.mask_transform(template_mask)

if self.use_random_rotation:
[query, template, template_mask] = self.random_rotator([query, template, template_mask])
# generate a random resized crop parameters
return {
"query": query,
Expand Down Expand Up @@ -294,8 +304,9 @@ def __getitem__(self, idx):
cropping_with_bbox=True,
reset_metaData=False,
use_augmentation=True,
use_random_geometric=True,
isTesting=True,
use_random_scale_translation=True,
use_random_rotation=True,
isTesting=False,
)
# train_data = DataLoader(
# dataset, batch_size=16, shuffle=False, num_workers=10
Expand All @@ -308,16 +319,19 @@ def __getitem__(self, idx):
# logging.info(f"{dataset_name} is running correctly!")
for idx in range(len(dataset)):
sample = dataset[idx]
for k in sample:
print(k, sample[k].shape)
break
# query = transform_inverse(sample["query"])
# template = transform_inverse(sample["template"])
# query = query.permute(1, 2, 0).numpy()
# query = Image.fromarray(np.uint8(query * 255))
# query.save(f"./tmp/{dataset_name}_{split}_{idx}_query.png")
# template = template.permute(1, 2, 0).numpy()
# template = Image.fromarray(np.uint8(template * 255))
# template.save(f"./tmp/{dataset_name}_{split}_{idx}_template.png")
# if idx == 10:
# break
# for k in sample:
# print(k, sample[k].shape)
# break
query = transform_inverse(sample["query"])
template = transform_inverse(sample["template"])
query = query.permute(1, 2, 0).numpy()
query = Image.fromarray(np.uint8(query * 255))
query.save(f"./tmp/{dataset_name}_{split}_{idx}_query.png")
template = template.permute(1, 2, 0).numpy()
template = Image.fromarray(np.uint8(template * 255))
template.save(f"./tmp/{dataset_name}_{split}_{idx}_template.png")
mask = sample["template_mask"].permute(1, 2, 0).numpy()[:, :, 0]
mask = Image.fromarray(np.uint8(mask * 255))
mask.save(f"./tmp/{dataset_name}_{split}_{idx}_mask.png")
if idx == 10:
break
26 changes: 25 additions & 1 deletion src/utils/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import random
import logging
from torchvision.transforms import RandomResizedCrop, ToTensor
from torchvision import transforms


class PillowRGBAugmentation:
Expand Down Expand Up @@ -160,7 +161,7 @@ def transform_bbox(self, bbox, scale, aspect_ratio):
# Scale the bbox around the center point
width = bbox[2] - bbox[0]
height = bbox[3] - bbox[1]

scaled_width = width * scale
scaled_height = height * scale * aspect_ratio
scaled_bbox = [
Expand Down Expand Up @@ -192,3 +193,26 @@ def __call__(self, imgs, bboxes):
# crop image with bbox_transformed
imgs_cropped_transformed.append(imgs[idx].crop(bbox_transformed))
return imgs_cropped_transformed


class RandomRotation:
    """Rotate a group of images in-plane by one shared random angle.

    Every image passed in a single call is rotated by the same angle, so
    paired inputs (e.g. a query, its template and the template mask) stay
    aligned with each other.

    Args:
        degrees_range: ``(low, high)`` pair, in degrees, from which the
            rotation angle is drawn uniformly on each call.
    """

    def __init__(self, degrees_range=(-90, 90)):
        # Tuple default instead of a list: a mutable default would be shared
        # (and could be mutated) across every instance of the class.
        self.degrees_range = degrees_range

    def __call__(self, imgs):
        """Rotate ``imgs`` by a freshly sampled angle.

        Args:
            imgs: a single image or a list of images. Each image must expose
                its channel count as ``shape[0]`` (channel-first layout);
                presumably these are torch tensors -- TODO confirm at caller.

        Returns:
            A list with the rotated images, in input order. A single
            non-list input still yields a one-element list.
        """
        angle = random.uniform(*self.degrees_range)
        # One random fill colour per call, shared by all multi-channel images
        # so the exposed corners look the same across the group.
        # NOTE(review): values are drawn from [0, 255) (upper bound exclusive)
        # regardless of the images' value range -- confirm this is intended
        # for normalized inputs.
        random_fill = np.random.randint(0, 255, size=4)
        if not isinstance(imgs, list):
            imgs = [imgs]

        imgs_rotated = []
        for img in imgs:
            num_channels = img.shape[0]
            if num_channels == 1:
                # Single-channel inputs (masks) are padded with zeros. The
                # fill is chosen per image so that a mask appearing early in
                # the list does not force zero fill on the images after it.
                fill = [0]
            else:
                fill = random_fill[:num_channels].tolist()
            imgs_rotated.append(
                transforms.functional.rotate(img, angle, fill=fill)
            )
        return imgs_rotated
11 changes: 9 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def train(cfg: DictConfig):

val_dataloaders = {}
for data_name in cfg.train_datasets:
if data_name == "hope":
continue
config_dataloader = cfg.data[data_name].dataloader
splits = [
split
Expand Down Expand Up @@ -104,7 +106,10 @@ def train(cfg: DictConfig):
config_dataloader.reset_metaData = False
config_dataloader.isTesting = False
config_dataloader.use_augmentation = cfg.use_augmentation
config_dataloader.use_random_geometric = cfg.use_random_geometric
config_dataloader.use_random_rotation = cfg.use_random_rotation
config_dataloader.use_random_scale_translation = (
cfg.use_random_scale_translation
)
train_dataloader = DataLoader(
instantiate(config_dataloader),
batch_size=cfg.machine.batch_size,
Expand All @@ -118,7 +123,9 @@ def train(cfg: DictConfig):
train_dataloaders[data_name] = train_dataloader
train_dataloaders = concat_dataloader(train_dataloaders)

logging.info(f"Fitting the model: train_size={len(train_dataloaders)}, val_size={len(val_dataloaders)}")
logging.info(
f"Fitting the model: train_size={len(train_dataloaders)}, val_size={len(val_dataloaders)}"
)
trainer.fit(
model,
train_dataloaders=train_dataloaders,
Expand Down

0 comments on commit e078107

Please sign in to comment.