Skip to content

Commit

Permalink
Add geometric augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-nguyen committed May 9, 2023
1 parent 01499a2 commit d392bdb
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 47 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ Optional: This pre-rendered template set can be manually downloaded from [here](
./src/scripts/render_all.sh
```

<details><summary>Click to expand</summary>

It is important to verify that all the datasets are correctly downloaded and processed, for example by counting the number of images in each folder:

<details><summary>Click to expand</summary>

```
for dir in $ROOT_DIR/datasets/*
Expand Down
1 change: 1 addition & 0 deletions configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ save_dir: ${machine.root_dir}/results/${name_exp}
name_exp: train
use_pretrained: True
use_augmentation: True
use_random_geometric: True
train_datasets:
- tless_train
- hb
Expand Down
110 changes: 64 additions & 46 deletions src/dataloader/bop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging
import cv2
import os.path as osp
from src.utils.augmentation import Augmentator
from src.utils.augmentation import Augmentator, CenterCropRandomResizedCrop
from tqdm import tqdm

# set level logging
Expand All @@ -26,6 +26,7 @@ def __init__(
obj_ids,
img_size,
use_augmentation=False,
use_random_geometric=False,
cropping_with_bbox=True,
reset_metaData=False,
**kwargs,
Expand All @@ -39,6 +40,8 @@ def __init__(
self.cropping_with_bbox = cropping_with_bbox
self.use_augmentation = use_augmentation
self.augmentator = Augmentator()
self.use_random_geometric = use_random_geometric
self.random_cropper = CenterCropRandomResizedCrop()

self.load_template_poses(template_dir=template_dir)
if isinstance(obj_ids, str):
Expand Down Expand Up @@ -92,8 +95,9 @@ def subsample(self, df, percentage):
selected_index = []
index_dataframe = np.arange(0, len(df))
for obj_id in selected_obj_id:
selected_index_obj = index_dataframe[# df["obj_id"] == obj_id]
np.logical_and(df["obj_id"] == obj_id, df["visib_fract"] >= 0.5)]
selected_index_obj = index_dataframe[ # df["obj_id"] == obj_id]
np.logical_and(df["obj_id"] == obj_id, df["visib_fract"] >= 0.5)
]
if percentage > 50:
selected_index_obj = selected_index_obj[
: int(percentage / 100 * len(selected_index_obj))
Expand All @@ -104,7 +108,9 @@ def subsample(self, df, percentage):
] # keep last
selected_index.extend(selected_index_obj.tolist())
df = df.iloc[selected_index]
logging.info(f"Subsampled from {len(index_dataframe)} to {len(df)} ({percentage}%) images")
logging.info(
f"Subsampled from {len(index_dataframe)} to {len(df)} ({percentage}%) images"
)
return df

def __len__(self):
Expand All @@ -116,11 +122,10 @@ def load_image(self, idx, type_img):
inplane = self.metaData.iloc[idx]["inplane"]
rgb = Image.open(template_path)
rgb = rgb.rotate(inplane)
return self.crop(rgb)
return rgb
else:
rgb_path = self.metaData.iloc[idx]["rgb_path"]
rgb = Image.open(rgb_path).convert("RGB")
rgb = self.crop(rgb, idx=idx)
if self.use_augmentation:
rgb = self.augmentator([rgb])[0]
return rgb
Expand All @@ -139,28 +144,39 @@ def make_bbox_square(self, old_bbox):
new_bbox[2] = old_bbox[2] + displacement
return new_bbox

def crop(self, img, idx=None):
def get_bbox(self, img, idx=None):
    """Return the square bounding box used to crop one image.

    Args:
        img: Image whose ``getbbox()`` is used when ``idx`` is ``None``.
            It is not accessed when ``idx`` is given.
        idx: Optional row index into ``self.metaData``. When given, the
            bounding box is read from the file stored in that row's
            ``mask_path`` column instead of from ``img``.

    Returns:
        The result of ``self.make_bbox_square`` applied to the raw
        bounding box.
    """
    if idx is not None:
        mask_path = self.metaData.iloc[idx]["mask_path"]
        # Close the mask file as soon as its bounding box is known, so the
        # file handle is not left open until garbage collection.
        with Image.open(mask_path) as mask:
            raw_bbox = mask.getbbox()
    else:
        raw_bbox = img.getbbox()
    return self.make_bbox_square(raw_bbox)

def crop(self, imgs, bboxes):
if self.cropping_with_bbox:
if np.array(img).shape[2] == 4:
bbox = self.make_bbox_square(img.getbbox())
return_mask = True
if self.use_random_geometric:
imgs_cropped = self.random_cropper(imgs, bboxes)
else:
mask_path = self.metaData.iloc[idx]["mask_path"]
bbox = self.make_bbox_square(Image.open(mask_path).getbbox())
return_mask = False
rgb = img.crop(bbox)
if return_mask:
return rgb.convert("RGB"), rgb.getchannel("A")
else:
return rgb.convert("RGB")
imgs_cropped = []
for i in range(len(imgs)):
imgs_cropped.append(imgs[i].crop(bboxes[i]))
return imgs_cropped

def __getitem__(self, idx):
query = self.load_image(idx, type_img="real")
query = self.rgb_transform(query)
if not self.isTesting:
template, template_mask = self.load_image(idx, type_img="synth")
query = self.load_image(idx, type_img="real")
template = self.load_image(idx, type_img="synth")
bboxes = [self.get_bbox(None, idx=idx), self.get_bbox(template)]

[query, template] = self.crop([query, template], bboxes)
template_mask = template.getchannel("A")
template = template.convert("RGB")

query = self.rgb_transform(query)
template = self.rgb_transform(template)
template_mask = self.mask_transform(template_mask)

# generate a random resized crop parameters
return {
"query": query,
"template": template,
Expand All @@ -174,12 +190,12 @@ def __getitem__(self, idx):

root_dir = "/gpfsscratch/rech/xjd/uyb58rn/datasets/template-pose-released/datasets"
dataset_names = [
"tudl",
"hb",
"hope",
"icmi",
"icbin",
"ruapc",
# "tudl",
# "hope",
# "icmi",
# "icbin",
# "ruapc",
]

# tless is special
Expand Down Expand Up @@ -222,26 +238,28 @@ def __getitem__(self, idx):
obj_ids=None,
img_size=256,
cropping_with_bbox=True,
reset_metaData=True,
reset_metaData=False,
use_augmentation=True,
use_random_geometric=True,
)
train_data = DataLoader(
dataset, batch_size=16, shuffle=False, num_workers=10
)
train_size, train_loader = len(train_data), iter(train_data)
for idx in tqdm(range(train_size)):
batch = next(train_loader)
if idx >= 500:
# train_data = DataLoader(
# dataset, batch_size=16, shuffle=False, num_workers=10
# )
# train_size, train_loader = len(train_data), iter(train_data)
# for idx in tqdm(range(train_size)):
# batch = next(train_loader)
# if idx >= 500:
# break
# logging.info(f"{dataset_name} is running correctly!")
for idx in range(len(dataset)):
sample = dataset[idx]
query = transform_inverse(sample["query"])
template = transform_inverse(sample["template"])
query = query.permute(1, 2, 0).numpy()
query = Image.fromarray(np.uint8(query * 255))
query.save(f"./tmp/{dataset_name}_{split}_{idx}_query.png")
template = template.permute(1, 2, 0).numpy()
template = Image.fromarray(np.uint8(template * 255))
template.save(f"./tmp/{dataset_name}_{split}_{idx}_template.png")
if idx == 10:
break
logging.info(f"{dataset_name} is running correctly!")
# for idx in range(len(dataset)):
# sample = dataset[idx]
# query = transform_inverse(sample["query"])
# template = transform_inverse(sample["template"])
# query = query.permute(1, 2, 0).numpy()
# query = Image.fromarray(np.uint8(query * 255))
# query.save(f"./tmp/{dataset_name}_{split}_{idx}_query.png")
# template = template.permute(1, 2, 0).numpy()
# template = Image.fromarray(np.uint8(template * 255))
# template.save(f"./tmp/{dataset_name}_{split}_{idx}_template.png")
# break
56 changes: 56 additions & 0 deletions src/utils/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import random
import logging
from torchvision.transforms import RandomResizedCrop, ToTensor


class PillowRGBAugmentation:
Expand Down Expand Up @@ -136,3 +137,58 @@ def __call__(self, imgs):
img_aug = self.blur(img_aug)
img_aug = self.gaussian_noise(img_aug)
return img_aug


class CenterCropRandomResizedCrop:
    """Crop images around their bounding-box centres with a random rescale.

    One scale factor and one aspect-ratio factor are drawn per call and
    applied to every (image, bbox) pair of that call, so all images passed
    together receive the same geometric perturbation.

    Args:
        scale_range: ``(low, high)`` bounds for the uniform scale factor
            applied to both bbox width and height.
        ratio_range: ``(low, high)`` bounds for the uniform factor that
            additionally multiplies the bbox height.
        translation_x: ``(low, high)`` horizontal translation bounds.
            Stored on the instance but not applied by ``__call__``.
        translation_y: ``(low, high)`` vertical translation bounds.
            Stored on the instance but not applied by ``__call__``.
    """

    def __init__(
        self,
        scale_range=(0.8, 1.0),
        ratio_range=(3.0 / 4, 4.0 / 3),
        translation_x=(-0.02, 0.02),
        translation_y=(-0.02, 0.02),
    ):
        # Immutable tuples: the defaults can no longer be shared and mutated
        # across instances, and caller-supplied sequences are copied.
        self.scale_range = tuple(scale_range)
        self.ratio_range = tuple(ratio_range)
        self.translation_x = tuple(translation_x)
        self.translation_y = tuple(translation_y)

    def transform_bbox(self, bbox, scale, aspect_ratio):
        """Rescale ``bbox`` about its centre.

        Args:
            bbox: ``[x0, y0, x1, y1]`` box.
            scale: Factor applied to both width and height.
            aspect_ratio: Extra factor applied to the height only.

        Returns:
            A list ``[x0, y0, x1, y1]`` with the same centre as ``bbox``.
        """
        # Centre point of the original box.
        cx = (bbox[0] + bbox[2]) / 2.0
        cy = (bbox[1] + bbox[3]) / 2.0

        # New half-extents; only the height is stretched by aspect_ratio.
        half_width = (bbox[2] - bbox[0]) * scale / 2.0
        half_height = (bbox[3] - bbox[1]) * scale * aspect_ratio / 2.0

        return [
            cx - half_width,
            cy - half_height,
            cx + half_width,
            cy + half_height,
        ]

    def __call__(self, imgs, bboxes):
        """Crop every image with its randomly rescaled bounding box.

        Args:
            imgs: A list of objects exposing ``crop(box)``, or a single such
                object (in which case ``bboxes`` is a single box).
            bboxes: One ``[x0, y0, x1, y1]`` box per image.

        Returns:
            A list with one cropped image per input image (a list even when
            a single image was passed).

        Raises:
            ValueError: If the number of images and boxes differ.
        """
        # Drawn once so all images of this call share the same perturbation.
        scale = random.uniform(*self.scale_range)
        aspect_ratio = random.uniform(*self.ratio_range)

        if not isinstance(imgs, list):
            imgs = [imgs]
            bboxes = [bboxes]

        if len(imgs) != len(bboxes):
            raise ValueError(
                f"Expected one bbox per image, got {len(imgs)} images "
                f"and {len(bboxes)} bboxes"
            )

        return [
            img.crop(
                self.transform_bbox(
                    bbox=bbox,
                    scale=scale,
                    aspect_ratio=aspect_ratio,
                )
            )
            for img, bbox in zip(imgs, bboxes)
        ]
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def train(cfg: DictConfig):
config_dataloader.split = split
config_dataloader.reset_metaData = False
config_dataloader.use_augmentation = cfg.use_augmentation
config_dataloader.use_random_geometric = cfg.use_random_geometric
train_dataloader = DataLoader(
instantiate(config_dataloader),
batch_size=cfg.machine.batch_size,
Expand Down

0 comments on commit d392bdb

Please sign in to comment.