Skip to content

Commit

Permalink
Add validation set for training
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-nguyen committed May 10, 2023
1 parent d392bdb commit d18a6f7
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 48 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ It is important to verify that all the datasets are correctly downloaded and pro
for dir in $ROOT_DIR/datasets/*
do
echo ${dir}
find ${dir} -name "*.png"|wc -l
find ${dir} -name "*.png" | wc -l
done
```

Expand Down
2 changes: 1 addition & 1 deletion configs/user/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
wandb_api_key:
wandb_project_name: template-pose-released
local_root_dir:
local_root_dir: ./datasets/
slurm_root_dir: # if not available, put None
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ dependencies:
- pyrender==0.1.45
- python-dateutil==2.8.1
- pytorch-lightning==1.8.1
- pytorch3d==0.3.0
- pytorch3d==0.7.1
- pytz==2021.1
- pyyaml==5.4.1
- requests-oauthlib==1.3.0
Expand Down
90 changes: 74 additions & 16 deletions src/dataloader/bop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
import os.path as osp
from src.utils.augmentation import Augmentator, CenterCropRandomResizedCrop
from tqdm import tqdm
from src.poses.utils import (
get_obj_poses_from_template_level,
load_index_level0_in_level2,
crop_frame,
)

# set level logging
logging.basicConfig(level=logging.INFO)
Expand All @@ -29,6 +34,7 @@ def __init__(
use_random_geometric=False,
cropping_with_bbox=True,
reset_metaData=False,
isTesting=False,
**kwargs,
):
self.root_dir = root_dir
Expand All @@ -39,11 +45,12 @@ def __init__(
self.mask_size = 25 if img_size == 64 else int(img_size // 8)
self.cropping_with_bbox = cropping_with_bbox
self.use_augmentation = use_augmentation
self.augmentator = Augmentator()
self.use_random_geometric = use_random_geometric
self.augmentator = Augmentator()
self.random_cropper = CenterCropRandomResizedCrop()

self.load_template_poses(template_dir=template_dir)
self.load_testing_indexes()
if isinstance(obj_ids, str):
obj_ids = [int(obj_id) for obj_id in obj_ids.split(",")]
logging.info(f"ATTENTION: Loading {len(obj_ids)} objects!")
Expand All @@ -57,13 +64,20 @@ def __init__(
if obj_ids is not None
else np.unique(self.metaData["obj_id"]).tolist()
)
if self.split.startswith("train") or self.split.startswith("val"):
if (
self.split.startswith("train") or self.split.startswith("val")
) and not isTesting:
# keep only 90% of the data for training for each object
self.metaData = self.subsample(self.metaData, 90)
self.isTesting = False
elif self.split.startswith("test"):
elif self.split.startswith("test") or isTesting:
self.metaData = self.subsample(self.metaData, 10)
self.isTesting = True
self.use_augmentation = False
self.use_random_geometric = False
else:
logging.warning(f"Split {split} and mode {isTesting} not recognized")
raise NotImplementedError
self.rgb_transform = transforms.Compose(
[
transforms.Resize((self.img_size, self.img_size)),
Expand All @@ -81,7 +95,7 @@ def __init__(
]
)
logging.info(
f"Length of dataloader: {self.__len__()} containing objects {np.unique(self.metaData['obj_id'])}"
f"Length of dataloader: {self.__len__()} with mode {self.isTesting} containing objects {np.unique(self.metaData['obj_id'])}"
)

def load_template_poses(self, template_dir):
Expand Down Expand Up @@ -154,14 +168,17 @@ def get_bbox(self, img, idx=None):

def crop(self, imgs, bboxes):
    """Crop every image in ``imgs`` with the matching entry of ``bboxes``.

    Args:
        imgs: sequence of images; each must expose ``crop(bbox)`` when the
            deterministic path is taken.
        bboxes: sequence of boxes, indexed in lockstep with ``imgs``.

    Returns:
        The cropped images. When ``self.cropping_with_bbox`` is false the
        input ``imgs`` is returned unchanged.
    """
    if not self.cropping_with_bbox:
        # Previously this path fell through to ``return imgs_cropped`` with
        # the name never bound, raising UnboundLocalError.
        return imgs
    if self.use_random_geometric and not self.isTesting:
        # Random geometric cropping is only applied outside of testing mode.
        return self.random_cropper(imgs, bboxes)
    # Index into bboxes (rather than zip) so a too-short bboxes sequence
    # still raises instead of silently dropping images.
    return [img.crop(bboxes[i]) for i, img in enumerate(imgs)]

def load_testing_indexes(self):
    """Store the result of ``load_index_level0_in_level2("all")`` on the instance.

    The value is kept as ``self.testing_indexes``.
    NOTE(review): judging by the helper's name these are presumably the
    indexes of level-0 templates inside the level-2 template set -- confirm
    against ``src.poses.utils``.
    """
    indexes = load_index_level0_in_level2("all")
    self.testing_indexes = indexes

def __getitem__(self, idx):
if not self.isTesting:
query = self.load_image(idx, type_img="real")
Expand All @@ -182,6 +199,43 @@ def __getitem__(self, idx):
"template": template,
"template_mask": template_mask,
}
else:
query_pose = self.metaData.iloc[idx]["pose"]
obj_id = self.metaData.iloc[idx]["obj_id"]
query = self.load_image(idx, type_img="real")
query_bbox = self.get_bbox(None, idx=idx)
imgs, bboxes = [query], [query_bbox]

# load all templates
for idx in self.testing_indexes:
tmp = Image.open(f"{self.template_dir}/obj_{obj_id:06d}/{idx:06d}.png")
imgs.append(tmp)
bboxes.append(self.get_bbox(tmp))
# crop and normalize image
imgs = self.crop(imgs, bboxes)
query = self.rgb_transform(imgs[0])
templates = [
self.rgb_transform(imgs[i].convert("RGB")) for i in range(1, len(imgs))
]
template_masks = [
self.mask_transform(imgs[i].getchannel("A"))
for i in range(1, len(imgs))
]
templates = torch.stack(templates, dim=0)
template_masks = torch.stack(template_masks, dim=0)

# loading poses
query_pose = torch.from_numpy(np.array(query_pose).reshape(4, 4)[:3, :3])
template_poses = torch.from_numpy(
self.templates_poses[self.testing_indexes]
)[:, :3, :3]
return {
"query": query,
"query_pose": query_pose,
"templates": templates,
"template_masks": template_masks,
"template_poses": template_poses,
}


if __name__ == "__main__":
Expand All @@ -191,7 +245,7 @@ def __getitem__(self, idx):
root_dir = "/gpfsscratch/rech/xjd/uyb58rn/datasets/template-pose-released/datasets"
dataset_names = [
"hb",
# "tudl",
# "tudl",
# "hope",
# "icmi",
# "icbin",
Expand Down Expand Up @@ -241,6 +295,7 @@ def __getitem__(self, idx):
reset_metaData=False,
use_augmentation=True,
use_random_geometric=True,
isTesting=True,
)
# train_data = DataLoader(
# dataset, batch_size=16, shuffle=False, num_workers=10
Expand All @@ -253,13 +308,16 @@ def __getitem__(self, idx):
# logging.info(f"{dataset_name} is running correctly!")
for idx in range(len(dataset)):
sample = dataset[idx]
query = transform_inverse(sample["query"])
template = transform_inverse(sample["template"])
query = query.permute(1, 2, 0).numpy()
query = Image.fromarray(np.uint8(query * 255))
query.save(f"./tmp/{dataset_name}_{split}_{idx}_query.png")
template = template.permute(1, 2, 0).numpy()
template = Image.fromarray(np.uint8(template * 255))
template.save(f"./tmp/{dataset_name}_{split}_{idx}_template.png")
if idx == 10:
break
for k in sample:
print(k, sample[k].shape)
break
# query = transform_inverse(sample["query"])
# template = transform_inverse(sample["template"])
# query = query.permute(1, 2, 0).numpy()
# query = Image.fromarray(np.uint8(query * 255))
# query.save(f"./tmp/{dataset_name}_{split}_{idx}_query.png")
# template = template.permute(1, 2, 0).numpy()
# template = Image.fromarray(np.uint8(template * 255))
# template.save(f"./tmp/{dataset_name}_{split}_{idx}_template.png")
# if idx == 10:
# break
66 changes: 62 additions & 4 deletions src/model/base_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import wandb
import torchvision.transforms as transforms
from src.model.loss import GeodesicError


def conv1x1(in_planes, out_planes, stride=1):
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(self, descriptor_size, threshold, **kwargs):
# define the network
super(BaseFeatureExtractor, self).__init__()
self.loss = InfoNCE()
self.metric = GeodesicError()
self.occlusion_sim = OcclusionAwareSimilarity(threshold=threshold)
self.sim_distance = nn.CosineSimilarity(dim=1) # eps=1e-2

Expand Down Expand Up @@ -293,10 +295,66 @@ def training_step(self, batch, idx):
)
return loss

def validation_step(
self,
):
return
def validation_step(self, batch, idx):
    """Run ``eval_batch`` once for each entry of ``batch``.

    Args:
        batch: mapping from dataset name to that dataset's sub-batch.
        idx: batch index; not used here.
    """
    for name, sub_batch in batch.items():
        self.eval_batch(sub_batch, name)

def eval_batch(self, batch, dataset_name, k=5):
    """Retrieve the top-``k`` templates for each query and log the outcome.

    Each query is compared against the templates of its own sample with
    ``calculate_similarity_for_search``. The best-ranked template is drawn
    next to the query in an image grid logged under
    ``retrieval/<dataset_name>``. The poses of the ``k`` best templates are
    handed to ``self.metric``, whose second output is forwarded to
    ``monitoring_score``.

    Args:
        batch: mapping with the keys "query", "templates",
            "template_masks", "template_poses" and "query_pose".
        dataset_name: name used in the logging keys.
        k: number of templates kept per query.
    """
    queries = batch["query"]  # B x C x W x H
    all_templates = batch["templates"]  # B x N x C x W x H
    all_masks = batch["template_masks"]
    all_poses = batch["template_poses"]
    query_features = self.forward(queries)

    # Rank the templates of every sample against its own query.
    num_queries = queries.shape[0]
    topk_ids = torch.zeros(num_queries, k, device=self.device).long()
    for row in range(num_queries):
        template_features = self.forward(all_templates[row, :])
        similarity = self.calculate_similarity_for_search(
            query_features[row].unsqueeze(0),
            template_features,
            all_masks[row, :],
            training=False,
        )
        _, best = similarity.topk(k=k)
        topk_ids[row] = best.reshape(-1)

    # Gather the best template and the k best poses per query.
    rows = torch.arange(0, num_queries, device=queries.device)
    best_templates = all_templates[rows, topk_ids[:, 0]]
    topk_poses = all_poses[rows.unsqueeze(1).repeat(1, k), topk_ids]

    # Visualize queries next to their best-ranked template.
    image_path = os.path.join(
        self.log_dir,
        f"retrieved_val_step{self.global_step}_rank{self.global_rank}.png",
    )
    grid, ncol = put_image_to_grid(
        [
            self.transform_inverse(queries),
            self.transform_inverse(best_templates),
        ]
    )
    grid_small = F.interpolate(
        grid.clone(), (64, 64), mode="bilinear", align_corners=False
    )
    save_image(grid_small, image_path, nrow=ncol * 4)
    self.logger.experiment.log(
        {f"retrieval/{dataset_name}": wandb.Image(image_path)},
    )

    # Score the retrieved poses; an all-zero symmetry tensor is passed.
    _, accuracy = self.metric(
        predR=topk_poses,
        gtR=batch["query_pose"],
        symmetry=torch.zeros(num_queries, device=self.device).long(),
    )
    self.monitoring_score(dict_scores=accuracy, split_name=f"{dataset_name}")

def monitoring_score(self, dict_scores, split_name):
for key, value in dict_scores.items():
Expand Down
3 changes: 2 additions & 1 deletion src/model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ def so3_relative_angle_with_symmetry(pred, gt, symmetry):

class GeodesicError(nn.Module):
# credit https://github.com/martius-lab/beta-nll
def __init__(self, thresholds=None, topk=None):
    """Store the evaluation settings of the metric.

    Args:
        thresholds: list kept as ``self.thresholds``; defaults to ``[15]``.
        topk: list kept as ``self.topk``; defaults to ``[0, 2, 4]``.
    """
    super(GeodesicError, self).__init__()
    # Build the defaults per instance: list literals in the signature are
    # created once and would be shared (and mutable) across all instances.
    self.thresholds = [15] if thresholds is None else thresholds
    self.topk = [0, 2, 4] if topk is None else topk

def forward(self, predR, gtR, symmetry):
if len(predR.shape) == 3: # top 1 Bx3x3
Expand Down
3 changes: 1 addition & 2 deletions src/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ def __init__(self, descriptor_size, threshold, **kwargs):
nn.ReLU(inplace=False),
conv1x1(256, descriptor_size),
)

self.metric = GeodesicError()
self.loss = InfoNCE()
self.occlusion_sim = OcclusionAwareSimilarity(threshold=threshold)
self.sim_distance = nn.CosineSimilarity(dim=1) # eps=1e-2
self.geodesic_distance = GeodesicError()

# define optimizer
self.weight_decay = float(kwargs["weight_decay"])
Expand Down
Binary file not shown.
Loading

0 comments on commit d18a6f7

Please sign in to comment.