Skip to content

Commit

Permalink
add test code
Browse files Browse the repository at this point in the history
  • Loading branch information
Gwangbin Bae committed Sep 22, 2021
1 parent 8d5a2ef commit b46483e
Show file tree
Hide file tree
Showing 11 changed files with 748 additions and 6 deletions.
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Gwangbin Bae

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
67 changes: 61 additions & 6 deletions ReadMe.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,68 @@

# Estimating and Exploiting the Aleatoric Uncertainty in Surface Normal Estimation

Official implementation of the paper

> **Estimating and Exploiting the Aleatoric Uncertainty in Surface Normal Estimation**
>
> ICCV 2021 [oral]
>
> [Gwangbin Bae](https://baegwangbin.com), [Ignas Budvytis](https://mi.eng.cam.ac.uk/~ib255/), and [Roberto Cipolla](https://mi.eng.cam.ac.uk/~cipolla/)
> **Estimating and Exploiting the Aleatoric Uncertainty in Surface Normal Estimation** \
> ICCV 2021 [oral] \
> [Gwangbin Bae](https://baegwangbin.com), [Ignas Budvytis](https://mi.eng.cam.ac.uk/~ib255/), and [Roberto Cipolla](https://mi.eng.cam.ac.uk/~cipolla/) \
> [[arXiv]](https://arxiv.org/abs/2109.09881)
<p align="center">
<img width=50% src="https://github.com/baegwangbin/surface_normal_uncertainty/blob/main/figs/readme_scannet.png?raw=true">
</p>

The proposed method estimates the per-pixel surface normal probability distribution, from which the expected angular error can be inferred to quantify the aleatoric uncertainty.
We also introduce a novel decoder framework where pixel-wise MLPs are trained on a subset of pixels selected based on the uncertainty.
Such uncertainty-guided sampling prevents the bias in training towards large planar surfaces, thereby improving the level of the detail in the prediction.

## Getting Started

We recommend using a virtual environment.
```
python3.6 -m venv --system-site-packages ./venv
source ./venv/bin/activate
```

Install the necessary dependencies by
```
python3.6 -m pip install -r requirements.txt
```

Download the pre-trained model weights and sample images.

```python
python download.py && cd examples && unzip examples.zip && cd ..
```

Running the above will download
* `./checkpoints/nyu.pt` (model trained on NYUv2)
* `./checkpoints/scannet.pt` (model trained on ScanNet)
* `./examples/*.png` (sample images)

## Run Demo

To test on your own images, please add them under `./examples/`. The images should be in `.png` or `.jpg`.

Test using the network trained on [NYUv2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html). We used the ground truth and data split provided by [GeoNet](https://github.com/xjqi/GeoNet).
>Please note that the ground truth for NYUv2 is only defined for the center crop of image. The prediction is therefore not accurate outside the center. When testing on your own images, we recommend using the network trained on ScanNet.
```python
python test.py --pretrained nyu --architecture GN
```

Test using the network trained on [ScanNet](http://www.scan-net.org/). We used the ground truth and data split provided by [FrameNet](https://github.com/hjwdzh/FrameNet).

```python
python test.py --pretrained scannet --architecture BN
```

Running the above will save the predicted surface normal and uncertainty under `./examples/results/`. If successful, you will obtain images like below.

<p align="center">
<img width=70% src="https://github.com/baegwangbin/surface_normal_uncertainty/blob/main/figs/readme_generalize.png?raw=true">
</p>

The predictions in the figure above are obtained by the network trained only on ScanNet. The network generalizes well to objects unseen during training (e.g., humans, cars, animals). The last row shows interesting examples where the input image only contains edges.

## Citation

Expand All @@ -22,3 +76,4 @@ If you find our work useful in your research please consider citing our paper:
year = {2021}
}
```

44 changes: 44 additions & 0 deletions data/dataloader_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import glob
import numpy as np
from PIL import Image

import torch
import torch.utils.data.distributed
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class CustomLoader(object):
def __init__(self, args, fldr_path):
self.testing_samples = CustomLoadPreprocess(args, fldr_path)
self.data = DataLoader(self.testing_samples, 1,
shuffle=False,
num_workers=1,
pin_memory=False)


class CustomLoadPreprocess(Dataset):
def __init__(self, args, fldr_path):
self.fldr_path = fldr_path
self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.filenames = glob.glob(self.fldr_path + '/*.png') + glob.glob(self.fldr_path + '/*.jpg')
self.input_height = args.input_height
self.input_width = args.input_width

def __len__(self):
return len(self.filenames)

def __getitem__(self, idx):
img_path = self.filenames[idx]
img = Image.open(img_path).convert("RGB").resize(size=(self.input_width, self.input_height), resample=Image.BILINEAR)
img = np.array(img).astype(np.float32) / 255.0
img = torch.from_numpy(img).permute(2, 0, 1)
img = self.normalize(img)

img_name = img_path.split('/')[-1]
img_name = img_name.split('.png')[0] if '.png' in img_name else img_name.split('.jpg')[0]

sample = {'img': img,
'img_name': img_name}

return sample
46 changes: 46 additions & 0 deletions download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Source: https://stackoverflow.com/a/39225039
import os
import requests


def download_file_from_google_drive(id, destination):
def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None

def save_response_content(response, destination):
CHUNK_SIZE = 32768
with open(destination, "wb") as f:
for chunk in response.iter_content(CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
f.write(chunk)

URL = "https://docs.google.com/uc?export=download"

session = requests.Session()
response = session.get(URL, params={'id': id}, stream=True)
token = get_confirm_token(response)

if token:
params = {'id': id,
'confirm': token}
response = session.get(URL, params=params, stream=True)

save_response_content(response, destination)


if __name__ == "__main__":

if not os.path.exists('./checkpoints/nyu.pt'):
print('downloading the model trained on NYUv2...')
download_file_from_google_drive('1RNiYw5rrqgBf3OkFSCSSQ67s0HMBpkAv', './checkpoints/nyu.pt')

if not os.path.exists('./checkpoints/scannet.pt'):
print('downloading the model trained on ScanNet...')
download_file_from_google_drive('1lOgY9sbMRW73qNdJze9bPkM2cmfA8Re-', './checkpoints/scannet.pt')

if not os.path.exists('./examples/examples.zip'):
print('downloading test images...')
download_file_from_google_drive('1bGZ4VFGkqrTLzQs0ELxEKo8xe_1Sfejg', './examples/examples.zip')
22 changes: 22 additions & 0 deletions models/NNET.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.submodules.encoder import Encoder
from models.submodules.decoder import Decoder


class NNET(nn.Module):
def __init__(self, args):
super(NNET, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder(args)

def get_1x_lr_params(self): # lr/10 learning rate
return self.encoder.parameters()

def get_10x_lr_params(self): # lr learning rate
return self.decoder.parameters()

def forward(self, img, **kwargs):
return self.decoder(self.encoder(img), **kwargs)
Loading

0 comments on commit b46483e

Please sign in to comment.