Skip to content

Commit

Permalink
Update code to newer version
Browse files Browse the repository at this point in the history
* Change keypoint convention to `.yaml` files
* Add visibility
* Remove parts, sticks and other legacy stuff
  • Loading branch information
ndrplz committed Oct 7, 2019
1 parent bdeb403 commit 715c3f1
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 214 deletions.
199 changes: 90 additions & 109 deletions datasets/dataset_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,32 @@
from PIL import Image
from torchvision.transforms import ToTensor

from datasets.interop import pascal_parts_colors
from datasets.dataset_stick import StickDataset
from datasets.interop import pascal_texture_planes
from utils.augmentation import MyRandomAffine
from utils.dataset_common import mask_to_torch
from utils.dataset_common import seg_to_image
from utils.visibility import car_planes_visibility
from datasets.dataset_stick import StickDataset
from utils.visibility import VisibilityOracle


def get_planes(image: np.ndarray, meta, pascal_class: str):
def get_planes(image: np.ndarray, meta, pascal_class: str,
vis_oracle: VisibilityOracle):
src_kpoint_dict = meta['kpoints_2d']
az = meta['vpoint'][0]
try:
az = meta['vpoint'][0]
el = meta['vpoint'][1]
except KeyError:
az = meta['azimuth']
el = meta['elevation']

h, w = image.shape[:2]
visible_planes = car_planes_visibility(pascal_class, int(az))

visible_planes = vis_oracle.get_planes_visibility(meta['cad_idx'],
az_deg=int(az),
el_deg=int(el))

planes = []
kpoints_planes = []
visibilities = []
for pl_name in pascal_texture_planes[pascal_class].keys():

pl_kp_names = pascal_texture_planes[pascal_class][pl_name]

src_p2d = np.asarray(
Expand All @@ -44,82 +50,80 @@ def get_planes(image: np.ndarray, meta, pascal_class: str):
return np.stack(planes, 0), kpoints_planes, np.stack(visibilities).astype(np.uint8)


def warp_unwarp_planes(src_planes: np.ndarray, src_planes_kpoints: List[np.ndarray], dst_planes_kpoints: List[np.ndarray],
def warp_unwarp_planes(src_planes: np.ndarray, src_planes_kpoints: List[np.ndarray],
dst_planes_kpoints: List[np.ndarray],
src_visibilities: np.ndarray, dst_visibilities: np.ndarray, pascal_class: str):
h, w = src_planes[0].shape[0:2]
planes_warped = np.zeros_like(src_planes, dtype=src_planes.dtype)
planes_unwarped = np.zeros_like(src_planes, dtype=src_planes.dtype)

keys = list(pascal_texture_planes[pascal_class].keys())
symmetry_set = [keys.index('left'), keys.index('right')]

for i, pl_name in enumerate(keys):
"""
Conditions to skip:
- pl not visible in src
- pl not in symmetry and not visible in dst
- pl in symmetry and neither one from the symmetry visible in dst
"""
if not src_visibilities[i]:
continue
elif not dst_visibilities[i] and i not in symmetry_set:
if i not in symmetry_set and not dst_visibilities[i]:
continue
if i in symmetry_set and 1 not in [dst_visibilities[j] for j in symmetry_set]:
continue

src_plane = src_planes[i]
src_kpoints_plane = src_planes_kpoints[i]
src_plane_kpoints = src_planes_kpoints[i]
j = i
if i in symmetry_set and not dst_visibilities[i]:
j = symmetry_set[0] if i == symmetry_set[1] else symmetry_set[1]

dst_kpoints_plane = dst_planes_kpoints[j]
H12, _ = cv2.findHomography(src_kpoints_plane, dst_kpoints_plane)
H21, _ = cv2.findHomography(dst_kpoints_plane, src_kpoints_plane)
dst_plane_kpoints = dst_planes_kpoints[j]
H12, _ = cv2.findHomography(src_plane_kpoints, dst_plane_kpoints)
H21, _ = cv2.findHomography(dst_plane_kpoints, src_plane_kpoints)

if H12 is not None and H21 is not None:
h, w = src_planes[0].shape[0:2]
src_warped = cv2.warpPerspective(src_plane, H12, dsize=(h, w))
src_unwarped = cv2.warpPerspective(src_warped, H21, dsize=(h, w))
if TextureDatasetWithNormal.is_valid(src_warped, pascal_class) and TextureDatasetWithNormal.is_valid(src_unwarped, pascal_class):
planes_warped[j] = src_warped
planes_unwarped[i] = src_unwarped

planes_warped[j] = src_warped
planes_unwarped[i] = src_unwarped

return planes_warped, planes_unwarped


class TextureDatasetWithNormal(StickDataset):
def __init__(self, folder: Path, ext: str='*.png', resize_factor: float=1.0,
demo_mode: bool=False, do_augmentation: bool=False,
use_LAB: bool=True, use_parts: bool=True,
quantize_central: bool=False):
super(TextureDatasetWithNormal, self).__init__(folder, ext,
resize_factor,
demo_mode=demo_mode,
do_augmentation=do_augmentation,
use_LAB=use_LAB)
def __init__(self, dataset_dir: Path,
visibility_dir: Path,
ext: str = '*.png', resize_factor: float = 1.0,
demo_mode: bool = False, do_augmentation: bool = False,
use_LAB: bool = True, quantize_central: bool = False,
):

super(TextureDatasetWithNormal, self).__init__(dataset_dir, ext, resize_factor,
demo_mode, do_augmentation, use_LAB)

self.quantize_central = quantize_central
self.use_parts = use_parts

self.vis_oracle = VisibilityOracle(self.dataset_meta['pascal_class'],
visibility_dir=visibility_dir)

def __getitem__(self, idx):
return self.prepare_example(image=self.data[self.mode + '_images'][idx],
image_stick=self.data[self.mode + '_sticks'][idx],
image_meta=self.data[self.mode + '_meta'][idx],
image_normal=self.data[self.mode + '_normal'][idx],
image_part=self.data[self.mode + '_part'][idx],)
image_normal=self.data[self.mode + '_normal'][idx])

def _load_data(self):
# Images are loaded only if both normal and part files are found
train_normal_names = set([p.name for p in self.folder.joinpath('normal_train').glob(self.ext)])
test_normal_names = set([p.name for p in self.folder.joinpath('normal_test').glob(self.ext)])

train_part_names = set([p.name for p in self.folder.joinpath('part_train').glob(self.ext)])
test_part_names = set([p.name for p in self.folder.joinpath('part_test').glob(self.ext)])

train_names = sorted(list(train_normal_names.intersection(train_part_names)))
test_names = sorted(list(test_normal_names.intersection(test_part_names)))
# Notice: Images are loaded only if normals are found
train_names = sorted([p.name for p in self.folder.joinpath('normal_train').glob(self.ext)])
test_names = sorted([p.name for p in self.folder.joinpath('normal_test').glob(self.ext)])

train_normal_paths = sorted([self.folder.joinpath('normal_train', n) for n in train_names])
test_normal_paths = sorted([self.folder.joinpath('normal_test', n) for n in test_names])

train_part_paths = sorted([self.folder.joinpath('part_train', n) for n in train_names])
test_part_paths = sorted([self.folder.joinpath('part_test', n) for n in test_names])

train_stick_paths = sorted([self.folder.joinpath('stick_train', n) for n in train_names])
test_stick_paths = sorted([self.folder.joinpath('stick_test', n) for n in test_names])

train_image_paths = sorted([self.folder.joinpath('train', n) for n in train_names])
test_image_paths = sorted([self.folder.joinpath('test', n) for n in test_names])

Expand All @@ -130,10 +134,6 @@ def _load_data(self):
top_n = 100
train_normal_paths = train_normal_paths[:top_n]
test_normal_paths = test_normal_paths[:top_n]
train_part_paths = train_part_paths[:top_n]
test_part_paths = test_part_paths[:top_n]
test_stick_paths = test_stick_paths[:top_n]
train_stick_paths = train_stick_paths[:top_n]
train_image_paths = train_image_paths[:top_n]
test_image_paths = test_image_paths[:top_n]
train_meta_paths = train_meta_paths[:top_n]
Expand All @@ -142,12 +142,8 @@ def _load_data(self):
return {
'train_normal': [self._preprocess(cv2.imread(str(f))) for f in train_normal_paths],
'eval_normal': [self._preprocess(cv2.imread(str(f))) for f in test_normal_paths],
'train_part': [self._preprocess(cv2.imread(str(f), cv2.IMREAD_UNCHANGED), interp=cv2.INTER_NEAREST) for f in train_part_paths],
'eval_part': [self._preprocess(cv2.imread(str(f), cv2.IMREAD_UNCHANGED), interp=cv2.INTER_NEAREST) for f in test_part_paths],
'train_images': [self._preprocess(cv2.imread(str(f))) for f in train_image_paths],
'eval_images': [self._preprocess(cv2.imread(str(f))) for f in test_image_paths],
'train_sticks': [self._preprocess(cv2.imread(str(f))) for f in train_stick_paths],
'eval_sticks': [self._preprocess(cv2.imread(str(f))) for f in test_stick_paths],
'train_meta': [self._load_metadata(f) for f in train_meta_paths],
'eval_meta': [self._load_metadata(f) for f in test_meta_paths]
}
Expand All @@ -165,10 +161,10 @@ def quantization(img):
# Now convert back into uint8, and make original image
center = np.uint8(center)
res = center[label.flatten()]
res2 = res.reshape((img.shape))
res2 = res.reshape(img.shape)
return res2

def prepare_example(self, image, image_stick, image_meta, image_normal, image_part):
def prepare_example(self, image, image_meta, image_normal):

src_image = image
h, w = src_image.shape[:2]
Expand All @@ -177,13 +173,10 @@ def prepare_example(self, image, image_stick, image_meta, image_normal, image_pa
pascal_class = self.dataset_meta['pascal_class']

src_meta = image_meta
src_stick = image_stick
src_normal = image_normal
src_part = image_part
if self.use_parts:
src_part = seg_to_image(image_part, pascal_parts_colors[pascal_class])

src_log_image = src_image.copy()
src_central_crop = src_image[h//2 - offset:h//2 + offset, w//2 - offset:w//2 + offset].copy()
src_central_crop = src_image[h // 2 - offset:h // 2 + offset, w // 2 - offset:w // 2 + offset].copy()
src_central_crop = cv2.resize(src_central_crop, (w, h))
if self.quantize_central:
src_central_crop = self.quantization(src_central_crop)
Expand All @@ -193,24 +186,29 @@ def prepare_example(self, image, image_stick, image_meta, image_normal, image_pa
dst_idx = np.random.randint(len(self))
dst_image = self.data[self.mode + '_images'][dst_idx]
dst_meta = self.data[self.mode + '_meta'][dst_idx]
dst_stick = self.data[self.mode + '_sticks'][dst_idx]
dst_normal = self.data[self.mode + '_normal'][dst_idx]
dst_part = self.data[self.mode + '_part'][dst_idx]
if self.use_parts:
dst_part = seg_to_image(dst_part, pascal_parts_colors[pascal_class])

dst_log_image = dst_image.copy()
dst_central_crop = dst_image[h//2 - offset:h//2 + offset, w//2 - offset:w//2 + offset].copy()
dst_central_crop = dst_image[h // 2 - offset:h // 2 + offset, w // 2 - offset:w // 2 + offset].copy()
dst_central_crop = cv2.resize(dst_central_crop, (w, h))
if self.quantize_central:
dst_central_crop = self.quantization(dst_central_crop)
values, counts = np.unique(dst_central_crop.reshape(-1, 3), return_counts=True, axis=0)
dst_central_crop = np.ones_like(dst_central_crop) * values[np.argmax(counts)]

src_planes, src_kpoints_planes, src_visibilities = get_planes(src_image, src_meta, pascal_class)
dst_planes, dst_kpoints_planes, dst_visibilities = get_planes(dst_image, dst_meta, pascal_class)
src_pl_info = get_planes(src_image, src_meta, pascal_class, self.vis_oracle)
src_planes, src_kpoints_planes, src_visibilities = src_pl_info

dst_pl_info = get_planes(dst_image, dst_meta, pascal_class, self.vis_oracle)
dst_planes, dst_kpoints_planes, dst_visibilities = dst_pl_info

planes_warped, planes_unwarped = warp_unwarp_planes(src_planes, src_kpoints_planes, dst_kpoints_planes,
src_visibilities, dst_visibilities, pascal_class)

# todo: two rows below are necessary for `augment_pascal.py` script
# planes_warped, planes_unwarped = warp_unwarp_planes(dst_planes, dst_kpoints_planes, src_kpoints_planes,
# dst_visibilities, src_visibilities, pascal_class)

# Compute masks for both src and dst images
bg_thresh = 20
bg_color = 255
Expand All @@ -229,6 +227,7 @@ def prepare_example(self, image, image_stick, image_meta, image_normal, image_pa

# Data augmentation: all images and planes undergo the same warp
if self.do_augmentation:
# todo: notice that kpoints are not transformed
affine = MyRandomAffine(degrees=10, translate=(0.1, 0.1),
fillcolor=(0, 0, 0), shear=10,
resample=Image.BICUBIC)
Expand All @@ -237,76 +236,67 @@ def prepare_example(self, image, image_stick, image_meta, image_normal, image_pa
src_planes = affine(*src_planes)
planes_unwarped = affine(*planes_unwarped)
src_normal = affine(src_normal)
src_part = affine(src_part)
src_image = affine(src_image)
src_stick = affine(src_stick)
src_image_masked = affine(src_image_masked,
fillcolor=(255, 255, 255))
src_fg_mask = affine(src_fg_mask)

# Compute masks for customizing losses
kernel = np.ones((15, 15), np.uint8)
src_inner_mask = cv2.erode(src_fg_mask, kernel)
src_border_mask = cv2.dilate(src_fg_mask - src_inner_mask,
np.ones((5, 5), np.uint8))
src_inner_mask = mask_to_torch(src_inner_mask)
src_border_mask = mask_to_torch(src_border_mask)
dst_inner_mask = cv2.erode(dst_fg_mask, kernel)
dst_inner_mask = mask_to_torch(dst_inner_mask)

# Achtung! Make sure the same normalization is used for training
planes_unwarped = self.planes_to_torch(planes_unwarped, to_LAB=self.use_LAB)
planes_warped = self.planes_to_torch(planes_warped, to_LAB=self.use_LAB)
src_planes = self.planes_to_torch(src_planes, to_LAB=self.use_LAB)

src_stick = self.to_torch(src_stick, to_LAB=self.use_LAB)
src_image = self.to_torch(src_image, to_LAB=self.use_LAB)
src_normal = self.to_torch(src_normal, to_LAB=self.use_LAB)
src_part = self.to_torch(src_part, to_LAB=self.use_LAB)

src_log_image = self.to_torch(src_log_image, to_LAB=self.use_LAB)

src_central_crop = self.to_torch(src_central_crop, to_LAB=self.use_LAB)
dst_image = self.to_torch(dst_image, to_LAB=self.use_LAB)
dst_stick = self.to_torch(dst_stick, to_LAB=self.use_LAB)
dst_normal = self.to_torch(dst_normal, to_LAB=self.use_LAB)
dst_part = self.to_torch(dst_part, to_LAB=self.use_LAB)

src_image_masked = self.to_torch(src_image_masked, to_LAB=self.use_LAB)
dst_image_masked = self.to_torch(dst_image_masked, to_LAB=self.use_LAB)
dst_log_image = self.to_torch(dst_log_image, to_LAB=self.use_LAB)
dst_central_crop = self.to_torch(dst_central_crop, to_LAB=self.use_LAB)

# Cross-Entropy Target
dst_cycle = self.data[self.mode + '_part'][dst_idx]
dst_cycle = torch.LongTensor(dst_cycle)
# This is the CAD of src image - that is, the one that must be used to
# generate novel views of the source. However, the cad_idx is only
# available for Pascal3D+
src_cad_idx = -1
try:
src_cad_idx = src_meta['cad_idx']
except KeyError:
pass
if type(src_cad_idx) == str:
src_cad_idx = -1 # todo: for shapenet cads

# Pre-process foreground mask
src_fg_mask = np.any(src_fg_mask, axis=-1, keepdims=True)
src_fg_mask = torch.from_numpy(src_fg_mask.transpose(2, 0, 1)).float()

return {
'src_kpoints_planes': src_kpoints_planes,
'dst_kpoints_planes': dst_kpoints_planes,
'src_vs': src_visibilities,
'dst_vs': dst_visibilities,
'src_normal': src_normal,
'dst_normal': dst_normal,
'src_part': src_part,
'dst_part': dst_part,
'src_part_raw': image_part, # todo: only for augment_pascal.py
'dst_cycle': dst_cycle,
'planes': src_planes, # Texture planes
'planes_warped': planes_warped, # Texture planes
'planes_unwarped': planes_unwarped, # Texture planes
'src_image': src_image, # source image
'dst_image': dst_image, # Destination image
'src_stick': src_stick,
'dst_stick': dst_stick,
'image_meta': src_meta, # for compatibility with StickDataset
'src_meta': src_meta,
'dst_meta': dst_meta,
'src_cad_idx': src_cad_idx,
# todo: `image_meta` breaks mixed training but is still needed for competitors_predict.py
# 'image_meta': src_meta,
'src_image_masked': src_image_masked,
'dst_image_masked': dst_image_masked,
'src_inner_mask': src_inner_mask,
'dst_inner_mask': dst_inner_mask,
'src_border_mask': src_border_mask,
'dst_log_image': dst_log_image,
'src_log_image': src_log_image,
'src_central': src_central_crop,
'dst_central': dst_central_crop,
'src_fg_mask': src_fg_mask
}

def to_torch(self, x: np.ndarray, to_LAB: bool):
Expand All @@ -318,15 +308,6 @@ def to_torch(self, x: np.ndarray, to_LAB: bool):
x = Image.fromarray(x)
return self.normalizer(ToTensor()(x))

@staticmethod
def is_valid(warped_image: np.ndarray, pascal_class: str):
if pascal_class == 'car':
error = np.any(warped_image[10, :]) or np.any(warped_image[-10, :])
else:
error = np.any(warped_image[:, 10]) or np.any(warped_image[:, -10])

return not error

@staticmethod
def planes_to_torch(planes, to_LAB: bool):
planes = [p for p in planes]
Expand Down
Loading

0 comments on commit 715c3f1

Please sign in to comment.