Skip to content

Commit

Permalink
Merge pull request #15 from luost26/no-cuda-ext
Browse files Browse the repository at this point in the history
Remove CUDA extensions and fix dataloader bugs
  • Loading branch information
luost26 committed Oct 7, 2021
2 parents cde2e50 + 3c55e28 commit 0bfd688
Show file tree
Hide file tree
Showing 24 changed files with 81 additions and 981 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,3 @@ dmypy.json
.DS_Store
/playgrounds
/logs*
/results*
29 changes: 17 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The official code repository for our CVPR 2021 paper "Diffusion Probabilistic Mo

## Installation

**[Step 1]** Setup conda environment
**[Option 1]** Install via conda environment YAML file (CUDA 10.1).

```bash
# Create the environment
Expand All @@ -16,20 +16,25 @@ conda env create -f env.yml
conda activate dpm-pc-gen
```

**[Step 2]** Compile the evaluation module
**[Option 2]** Or you may setup the environment manually (If you are using GPUs that only work with CUDA 11 or greater).

⚠️ Please compile the module using **`nvcc` 10.0**. Errors might occur if you use other versions (for example 10.1).
Our model only depends on the following commonly used packages, all of which can be installed via conda.

💡 You might specify your `nvcc` path [here](https://github.com/luost26/diffusion-point-cloud/blob/9be449f80b1353e6d39010363d4e139e9e532a2c/evaluation/pytorch_structural_losses/Makefile#L9).
| Package | Version |
| ------------ | -------------------------------- |
| PyTorch | ≥ 1.6.0 |
| h5py | *not specified* (we used 4.61.1) |
| tqdm | *not specified* |
| tensorboard | *not specified* (we used 2.5.0) |
| numpy | *not specified* (we used 1.20.2) |
| scipy | *not specified* (we used 1.6.2) |
| scikit-learn | *not specified* (we used 0.24.2) |

```bash
# Please ensure the conda environment `dpm-pc-gen` is activated.
cd ./evaluation/pytorch_structural_losses
make clean
make
# Return to the project directory
cd ../../
```
## About the EMD Metric

We have removed the EMD module due to GPU compatability issues. The legacy code can be found on the `emd-cd` branch.

If you have to compute the EMD score or compare our model with others, we strongly advise you to use your own code to compute the metrics. The generation and decoding results will be saved to the `results` folder after each test run.

## Datasets and Pretrained Models

Expand Down
14 changes: 0 additions & 14 deletions evaluation/README.md

This file was deleted.

79 changes: 37 additions & 42 deletions evaluation/evaluation_metrics.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
"""
From https://github.com/stevenygd/PointFlow/tree/master/metrics
"""
import torch
import numpy as np
import warnings
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
from tqdm.auto import tqdm
# Import CUDA version of approximate EMD,
# from https://github.com/zekunhao1995/pcgan-pytorch/
from evaluation.StructuralLosses.match_cost import match_cost # noqa
from evaluation.StructuralLosses.nn_distance import nn_distance # noqa


def distChamferCUDA(x, y):
return nn_distance(x, y)


_EMD_NOT_IMPL_WARNED = False
def emd_approx(sample, ref):
B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
assert N == N_ref, "Not sure what would EMD do in this case"
emd = match_cost(sample, ref) # (B,)
emd_norm = emd / float(N) # (B,)
return emd_norm
global _EMD_NOT_IMPL_WARNED
emd = torch.zeros([sample.size(0)]).to(sample)
if not _EMD_NOT_IMPL_WARNED:
_EMD_NOT_IMPL_WARNED = True
print('\n\n[WARNING]')
print(' * EMD is not implemented due to GPU compatability issue.')
print(' * We will set all EMD to zero by default.')
print(' * You may implement your own EMD in the function `emd_approx` in ./evaluation/evaluation_metrics.py')
print('\n')
return emd


# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
Expand All @@ -37,7 +38,7 @@ def distChamfer(a, b):
return P.min(1)[0], P.min(2)[0]


def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)
Expand All @@ -51,10 +52,7 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
sample_batch = sample_pcs[b_start:b_end]
ref_batch = ref_pcs[b_start:b_end]

if accelerated_cd:
dl, dr = distChamferCUDA(sample_batch, ref_batch)
else:
dl, dr = distChamfer(sample_batch, ref_batch)
dl, dr = distChamfer(sample_batch, ref_batch)
cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))

emd_batch = emd_approx(sample_batch, ref_batch)
Expand All @@ -74,8 +72,7 @@ def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True):
return results


def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size,
accelerated_cd=True, verbose=True):
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, verbose=True):
N_sample = sample_pcs.shape[0]
N_ref = ref_pcs.shape[0]
all_cd = []
Expand All @@ -89,8 +86,8 @@ def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size,
cd_lst = []
emd_lst = []
sub_iterator = range(0, N_ref, batch_size)
if verbose:
sub_iterator = tqdm(sub_iterator, leave=False)
# if verbose:
# sub_iterator = tqdm(sub_iterator, leave=False)
for ref_b_start in sub_iterator:
ref_b_end = min(N_ref, ref_b_start + batch_size)
ref_batch = ref_pcs[ref_b_start:ref_b_end]
Expand All @@ -101,10 +98,7 @@ def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size,
batch_size_ref, -1, -1)
sample_batch_exp = sample_batch_exp.contiguous()

if accelerated_cd:
dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)
else:
dl, dr = distChamfer(sample_batch_exp, ref_batch)
dl, dr = distChamfer(sample_batch_exp, ref_batch)
cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))

emd_batch = emd_approx(sample_batch_exp, ref_batch)
Expand Down Expand Up @@ -188,40 +182,41 @@ def lgan_mmd_cov_match(all_dist):
}, min_idx.view(-1)


def compute_all_metrics(sample_pcs, ref_pcs, batch_size, accelerated_cd=False):
def compute_all_metrics(sample_pcs, ref_pcs, batch_size):
results = {}

print("Pairwise EMD CD")
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(
ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)
M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size)

## CD
res_cd = lgan_mmd_cov(M_rs_cd.t())
results.update({
"%s-CD" % k: v for k, v in res_cd.items()
})

res_emd = lgan_mmd_cov(M_rs_emd.t())
results.update({
"%s-EMD" % k: v for k, v in res_emd.items()
})

## EMD
# res_emd = lgan_mmd_cov(M_rs_emd.t())
# results.update({
# "%s-EMD" % k: v for k, v in res_emd.items()
# })

for k, v in results.items():
print('[%s] %.8f' % (k, v.item()))

M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(
ref_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd)
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(
sample_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)
M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size)
M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size)

# 1-NN results
## CD
one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
results.update({
"1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
})
one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
results.update({
"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
})
## EMD
# one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
# results.update({
# "1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
# })

return results

Expand Down
2 changes: 0 additions & 2 deletions evaluation/pytorch_structural_losses/.gitignore

This file was deleted.

103 changes: 0 additions & 103 deletions evaluation/pytorch_structural_losses/Makefile

This file was deleted.

6 changes: 0 additions & 6 deletions evaluation/pytorch_structural_losses/__init__.py

This file was deleted.

45 changes: 0 additions & 45 deletions evaluation/pytorch_structural_losses/match_cost.py

This file was deleted.

Loading

0 comments on commit 0bfd688

Please sign in to comment.