Skip to content

Commit

Permalink
add training code
Browse files Browse the repository at this point in the history
  • Loading branch information
baegwangbin committed May 2, 2022
1 parent 9dd04a1 commit 930abed
Show file tree
Hide file tree
Showing 11 changed files with 2,378 additions and 24 deletions.
27 changes: 23 additions & 4 deletions ReadMe.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ Official implementation of the paper
> **Estimating and Exploiting the Aleatoric Uncertainty in Surface Normal Estimation** \
> ICCV 2021 [oral] \
> [Gwangbin Bae](https://baegwangbin.com), [Ignas Budvytis](https://mi.eng.cam.ac.uk/~ib255/), and [Roberto Cipolla](https://mi.eng.cam.ac.uk/~cipolla/) \
> [[arXiv]](https://arxiv.org/abs/2109.09881)
> [[arXiv]](https://arxiv.org/abs/2109.09881) [[youtube]](https://youtu.be/mTy85tJ2oAQ)
<p align="center">
<img width=50% src="https://github.com/baegwangbin/surface_normal_uncertainty/blob/main/figs/readme_scannet.png?raw=true">
<img width=70% src="https://github.com/baegwangbin/surface_normal_uncertainty/blob/main/figs/readme_scannet.png?raw=true">
</p>

The proposed method estimates the per-pixel surface normal probability distribution, from which the expected angular error can be inferred to quantify the aleatoric uncertainty.
Expand All @@ -33,7 +33,7 @@ Download the pre-trained model weights and sample images.
```python
python download.py && cd examples && unzip examples.zip && cd ..
```
[25 Apr 2022] ***The above script does not work anymore. Please download the models and example images directly from [this link](https://drive.google.com/drive/folders/1Ku25Am69h_HrbtcCptXn4aetjo7sB33F?usp=sharing), and unzip them under `./checkpoints/` and `./examples`.***
[25 Apr 2022] ***`download.py` does not work anymore. Please download the models and example images directly from [this link](https://drive.google.com/drive/folders/1Ku25Am69h_HrbtcCptXn4aetjo7sB33F?usp=sharing), and unzip them under `./checkpoints/` and `./examples/`.***

Running the above will download
* `./checkpoints/nyu.pt` (model trained on NYUv2)
Expand All @@ -60,11 +60,30 @@ python test.py --pretrained scannet --architecture BN
Running the above will save the predicted surface normal and uncertainty under `./examples/results/`. If successful, you will obtain images like below.

<p align="center">
<img width=70% src="https://github.com/baegwangbin/surface_normal_uncertainty/blob/main/figs/readme_generalize.png?raw=true">
<img width=100% src="https://github.com/baegwangbin/surface_normal_uncertainty/blob/main/figs/readme_generalize.png?raw=true">
</p>

The predictions in the figure above are obtained by the network trained only on ScanNet. The network generalizes well to objects unseen during training (e.g., humans, cars, animals). The last row shows interesting examples where the input image only contains edges.

## Training

### Step 1. Download dataset

* **NYUv2 (official)**: The official train/test split contains 795/654 images. The dataset can be downloaded from [this link](https://drive.google.com/drive/folders/1Ku25Am69h_HrbtcCptXn4aetjo7sB33F?usp=sharing). Unzip the file `nyu_dataset.zip` under `./datasets/`, so that `./datasets/nyu/train/` and `./datasets/nyu/test/` exist.

* **NYUv2 (big)**: Please visit [GeoNet](https://github.com/xjqi/GeoNet) to download a larger training set consisting of 30907 images. This is the training set used to train our model.

* **ScanNet**: Please visit [FrameNet](https://github.com/hjwdzh/framenet/tree/master/src) to download ScanNet with ground truth surface normals.

### Step 2. Train

* If you wish to train on the official NYUv2 split, simply run
```python
python train.py
```
* If you wish to train your own model, modify the file `./models/baseline.py` and add the `--use_baseline` flag. The default loss function `UG_NLL_ours` assumes uncertainty-guided sampling, so it should be replaced with a different loss (e.g., try `--loss_fn NLL_ours`).


## Citation

If you find our work useful in your research, please consider citing our paper:
Expand Down
137 changes: 137 additions & 0 deletions data/dataloader_nyu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import random
import numpy as np
from PIL import Image

import torch
import torch.utils.data.distributed
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF

import data.utils as data_utils


# Root directory of the NYUv2 dataset. The image / normal paths listed in the
# split files are appended to this prefix when samples are loaded.
# Modify the following to match your local setup.
NYU_PATH = './datasets/nyu/'


class NyuLoader(object):
    """Wrap a ``NyuLoadPreprocess`` dataset in a PyTorch ``DataLoader``.

    Attributes set by the constructor:
        t_samples:     the underlying ``NyuLoadPreprocess`` dataset
        data:          the ``DataLoader`` iterating over ``t_samples``
        train_sampler: only set for training modes; a ``DistributedSampler``
                       when ``args.distributed`` is truthy, otherwise ``None``
    """

    def __init__(self, args, mode):
        """mode: {'train_big',  # training set used by GeoNet (CVPR18, 30907 images)
                  'train',      # official train set (795 images)
                  'test'}       # official test set (654 images)
        """
        dataset = NyuLoadPreprocess(args, mode)
        self.t_samples = dataset

        # evaluation: one sample at a time, fixed order, nothing dropped
        if 'train' not in mode:
            self.data = DataLoader(dataset, 1,
                                   shuffle=False,
                                   num_workers=1,
                                   pin_memory=False)
            return

        # train, train_big
        sampler = None
        if args.distributed:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        self.train_sampler = sampler

        # the loader shuffles by itself only when no sampler was created
        self.data = DataLoader(dataset, args.batch_size,
                               shuffle=(sampler is None),
                               num_workers=args.num_threads,
                               pin_memory=True,
                               drop_last=True,
                               sampler=sampler)


class NyuLoadPreprocess(Dataset):
    """NYUv2 dataset yielding an image, its surface normals and a validity mask.

    The sample list is read from ``./data_split/nyu_<mode>.txt``. Every
    non-blank line is expected to hold two whitespace-separated paths
    (image first, normal map second) that are resolved under ``NYU_PATH``.
    """

    def __init__(self, args, mode):
        """Read the split file for ``mode`` and store the preprocessing settings.

        args must provide ``input_height`` and ``input_width``; for training
        modes ``__getitem__`` additionally reads ``data_augmentation_hflip``,
        ``data_augmentation_random_crop`` and ``data_augmentation_color``.
        """
        self.args = args
        # train, train_big, test, test_new
        with open("./data_split/nyu_%s.txt" % mode, 'r') as f:
            # skip blank lines: they would otherwise count towards __len__
            # and raise IndexError when indexed in __getitem__
            self.filenames = [line for line in f if line.strip()]
        self.mode = mode
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.dataset_path = NYU_PATH
        self.input_height = args.input_height
        self.input_width = args.input_width

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.filenames)

    @staticmethod
    def _decode_normals(norm_img):
        """Turn an RGB-encoded normal map into (float32 normals, bool valid mask).

        Normals are mapped from [0, 255] to [-1, 1]. The mask has shape
        (H, W, 1) and is False exactly where all three channels are 0.
        """
        encoded = np.array(norm_img).astype(np.uint8)
        # a pixel is valid unless it is encoded as (0, 0, 0)
        valid_mask = np.any(encoded != 0, axis=2, keepdims=True)
        normals = ((encoded.astype(np.float32) / 255.0) * 2.0) - 1.0
        return normals, valid_mask

    def __getitem__(self, idx):
        """Load, resize, optionally augment and tensorize sample ``idx``.

        Returns a dict with keys 'img' (3, H, W), 'norm' (3, H, W),
        'norm_valid_mask' (1, H, W), 'scene_name' and 'img_name'.
        Augmentation is applied only when the mode name contains 'train'.
        """
        fields = self.filenames[idx].split()

        # img path and norm path
        img_path = self.dataset_path + '/' + fields[0]
        norm_path = self.dataset_path + '/' + fields[1]
        scene_name = self.mode
        img_name = img_path.split('/')[-1].split('.png')[0]

        # read img / normal
        # (nearest-neighbour for the normals, so encoded values are not blended)
        target_size = (self.input_width, self.input_height)
        img = Image.open(img_path).convert("RGB").resize(size=target_size,
                                                         resample=Image.BILINEAR)
        norm_gt = Image.open(norm_path).convert("RGB").resize(size=target_size,
                                                              resample=Image.NEAREST)

        is_train = 'train' in self.mode

        # horizontal flip (default: True)
        DA_hflip = False
        if is_train and self.args.data_augmentation_hflip:
            DA_hflip = random.random() > 0.5
        if DA_hflip:
            img = TF.hflip(img)
            norm_gt = TF.hflip(norm_gt)

        # to array (shared by the train and test paths)
        img = np.array(img).astype(np.float32) / 255.0
        norm_gt, norm_valid_mask = self._decode_normals(norm_gt)

        if is_train:
            if DA_hflip:
                # mirroring the image also mirrors the normals: negate channel 0
                # (presumably the x component -- confirm against the dataset)
                norm_gt[:, :, 0] = - norm_gt[:, :, 0]

            # random crop (default: False)
            if self.args.data_augmentation_random_crop:
                img, norm_gt, norm_valid_mask = data_utils.random_crop(img, norm_gt, norm_valid_mask,
                                                                       height=416, width=544)

            # color augmentation (default: True), applied to half of the samples
            if self.args.data_augmentation_color and random.random() > 0.5:
                img = data_utils.color_augmentation(img, indoors=True)

        # to tensors
        img = self.normalize(torch.from_numpy(img).permute(2, 0, 1))                    # (3, H, W)
        norm_gt = torch.from_numpy(norm_gt).permute(2, 0, 1)                            # (3, H, W)
        norm_valid_mask = torch.from_numpy(norm_valid_mask).permute(2, 0, 1)            # (1, H, W)

        sample = {'img': img,
                  'norm': norm_gt,
                  'norm_valid_mask': norm_valid_mask,
                  'scene_name': scene_name,
                  'img_name': img_name}

        return sample
39 changes: 39 additions & 0 deletions data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import random
import numpy as np


def random_crop(img, norm, norm_mask, height, width):
    """Cut one randomly placed (height, width) window out of all three arrays.

    The window position is drawn once (column offset first, then row offset)
    and the identical window is taken from ``img``, ``norm`` and ``norm_mask``,
    which are indexed as (H, W, C). ``img`` must be at least as large as the
    requested window.
    """
    full_h = img.shape[0]
    assert full_h >= height
    full_w = img.shape[1]
    assert full_w >= width

    left = random.randint(0, full_w - width)
    top = random.randint(0, full_h - height)

    rows = slice(top, top + height)
    cols = slice(left, left + width)
    return img[rows, cols, :], norm[rows, cols, :], norm_mask[rows, cols, :]


def color_augmentation(image, indoors=True):
    """Randomly perturb gamma, brightness and per-channel gain of an image.

    image:   array indexed as (H, W, 3); the result is clipped to [0, 1], so
             values are presumably expected in that range already
    indoors: selects the wider brightness range (0.75-1.25) instead of 0.9-1.1

    Returns a new array; ``image`` itself is not modified.

    All random draws come from the ``random`` module, so seeding it (as
    PyTorch does per DataLoader worker) makes the result reproducible.
    """
    # gamma augmentation
    gamma = random.uniform(0.9, 1.1)
    image_aug = image ** gamma

    # brightness augmentation
    low, high = (0.75, 1.25) if indoors else (0.9, 1.1)
    brightness = random.uniform(low, high)
    image_aug = image_aug * brightness

    # color augmentation: one gain per channel, broadcast over H and W
    colors = np.array([random.uniform(0.9, 1.1) for _ in range(3)])
    # multiply in place so the dtype of image_aug is kept
    image_aug *= colors

    return np.clip(image_aug, 0, 1)

Loading

0 comments on commit 930abed

Please sign in to comment.