From c02414ac7377e039b6ace322cfad124548c5d8bc Mon Sep 17 00:00:00 2001 From: luost Date: Sat, 3 Jul 2021 17:38:43 +0800 Subject: [PATCH 1/5] Add env.yml --- env.yml | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 env.yml diff --git a/env.yml b/env.yml new file mode 100644 index 0000000..a465fea --- /dev/null +++ b/env.yml @@ -0,0 +1,109 @@ +name: dpm-pc-gen +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - absl-py=0.13.0=pyhd8ed1ab_0 + - aiohttp=3.7.4.post0=py37h5e8e339_0 + - async-timeout=3.0.1=py_1000 + - attrs=21.2.0=pyhd8ed1ab_0 + - blas=1.0=mkl + - blinker=1.4=py_1 + - brotlipy=0.7.0=py37h5e8e339_1001 + - c-ares=1.17.1=h7f98852_1 + - ca-certificates=2021.5.30=ha878542_0 + - cached-property=1.5.2=hd8ed1ab_1 + - cached_property=1.5.2=pyha770c72_1 + - cachetools=4.2.2=pyhd8ed1ab_0 + - certifi=2021.5.30=py37h89c1867_0 + - cffi=1.14.5=py37hc58025e_0 + - chardet=4.0.0=py37h89c1867_1 + - click=8.0.1=py37h89c1867_0 + - cryptography=3.4.7=py37h5d9358c_0 + - cudatoolkit=10.1.243=h6bb024c_0 + - dataclasses=0.8=pyhc8e2a94_1 + - freetype=2.10.4=h5ab3b9f_0 + - google-auth=1.32.0=pyh6c4a22f_0 + - google-auth-oauthlib=0.4.1=py_2 + - grpcio=1.38.1=py37hb27c1af_0 + - h5py=3.3.0=nompi_py37ha3df211_100 + - hdf5=1.10.6=nompi_h7c3c948_1111 + - idna=2.10=pyh9f0ad1d_0 + - importlib-metadata=4.6.0=py37h89c1867_0 + - intel-openmp=2021.2.0=h06a4308_610 + - jpeg=9b=h024ee3a_2 + - krb5=1.19.1=hcc1bbae_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libcurl=7.77.0=h2574ce0_0 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=h516909a_1 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgfortran-ng=7.5.0=h14aa051_19 + - libgfortran4=7.5.0=h14aa051_19 + - libgomp=9.3.0=h5101ec6_17 + - libnghttp2=1.43.0=h812cca2_0 + - libpng=1.6.37=hbc83047_0 + - libprotobuf=3.17.2=h780b84a_0 + - libssh2=1.9.0=ha56f1ee_6 + - 
libstdcxx-ng=9.3.0=hd4cf53a_17 + - libtiff=4.2.0=h85742a9_0 + - libwebp-base=1.2.0=h27cfd23_0 + - llvm-openmp=8.0.1=hc9558a2_0 + - lz4-c=1.9.3=h2531618_0 + - markdown=3.3.4=pyhd8ed1ab_0 + - mkl=2021.2.0=h06a4308_296 + - mkl-service=2.3.0=py37h27cfd23_1 + - mkl_fft=1.3.0=py37h42c9631_2 + - mkl_random=1.2.1=py37ha9443f7_2 + - multidict=5.1.0=py37h5e8e339_1 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.2=hff7bd54_1 + - numpy=1.20.2=py37h2d18471_0 + - numpy-base=1.20.2=py37hfae3a4d_0 + - oauthlib=3.1.1=pyhd8ed1ab_0 + - olefile=0.46=py37_0 + - openmp=8.0.1=0 + - openssl=1.1.1k=h7f98852_0 + - pillow=8.2.0=py37he98fc37_0 + - pip=21.1.3=py37h06a4308_0 + - point_cloud_utils=0.18.0=py37h6dcda5c_1 + - protobuf=3.17.2=py37hcd2ae1e_0 + - pyasn1=0.4.8=py_0 + - pyasn1-modules=0.2.7=py_0 + - pycparser=2.20=pyh9f0ad1d_2 + - pyjwt=2.1.0=pyhd8ed1ab_0 + - pyopenssl=20.0.1=pyhd8ed1ab_0 + - pysocks=1.7.1=py37h89c1867_3 + - python=3.7.10=h12debd9_4 + - python_abi=3.7=2_cp37m + - pytorch=1.6.0=py3.7_cuda10.1.243_cudnn7.6.3_0 + - pyu2f=0.1.5=pyhd8ed1ab_0 + - readline=8.1=h27cfd23_0 + - requests=2.25.1=pyhd3deb0d_0 + - requests-oauthlib=1.3.0=pyh9f0ad1d_0 + - rsa=4.7.2=pyh44b312d_0 + - scipy=1.6.2=py37had2a1c9_1 + - setuptools=52.0.0=py37h06a4308_0 + - six=1.16.0=pyhd3eb1b0_0 + - sqlite=3.36.0=hc218d9a_0 + - tensorboard=2.5.0=pyhd8ed1ab_0 + - tensorboard-data-server=0.6.0=py37h7f0c10b_0 + - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 + - tk=8.6.10=hbc83047_0 + - torchvision=0.7.0=py37_cu101 + - tqdm=4.61.1=pyhd8ed1ab_0 + - typing-extensions=3.10.0.0=hd8ed1ab_0 + - typing_extensions=3.10.0.0=pyha770c72_0 + - urllib3=1.26.6=pyhd8ed1ab_0 + - werkzeug=2.0.1=pyhd8ed1ab_0 + - wheel=0.36.2=pyhd3eb1b0_0 + - xz=5.2.5=h7b6447c_0 + - yarl=1.6.3=py37h5e8e339_1 + - zipp=3.4.1=pyhd8ed1ab_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 From c9261fc8feea180043dfaf1aad81dda6e7a4b146 Mon Sep 17 00:00:00 2001 From: luost Date: Sat, 3 Jul 2021 20:58:27 +0800 Subject: [PATCH 2/5] Add evaluation module --- 
evaluation/.gitignore | 1 + evaluation/README.md | 14 + evaluation/__init__.py | 2 + evaluation/evaluation_metrics.py | 359 ++++++++++++++++++ .../pytorch_structural_losses/.gitignore | 2 + evaluation/pytorch_structural_losses/Makefile | 103 +++++ .../pytorch_structural_losses/__init__.py | 6 + .../pytorch_structural_losses/match_cost.py | 45 +++ .../pytorch_structural_losses/nn_distance.py | 42 ++ .../pytorch_structural_losses/pybind/bind.cpp | 15 + .../pybind/extern.hpp | 6 + evaluation/pytorch_structural_losses/setup.py | 30 ++ .../src/approxmatch.cu | 326 ++++++++++++++++ .../src/approxmatch.cuh | 8 + .../src/nndistance.cu | 155 ++++++++ .../src/nndistance.cuh | 2 + .../src/structural_loss.cpp | 125 ++++++ .../pytorch_structural_losses/src/utils.hpp | 26 ++ 18 files changed, 1267 insertions(+) create mode 100644 evaluation/.gitignore create mode 100644 evaluation/README.md create mode 100644 evaluation/__init__.py create mode 100644 evaluation/evaluation_metrics.py create mode 100644 evaluation/pytorch_structural_losses/.gitignore create mode 100644 evaluation/pytorch_structural_losses/Makefile create mode 100644 evaluation/pytorch_structural_losses/__init__.py create mode 100644 evaluation/pytorch_structural_losses/match_cost.py create mode 100644 evaluation/pytorch_structural_losses/nn_distance.py create mode 100644 evaluation/pytorch_structural_losses/pybind/bind.cpp create mode 100644 evaluation/pytorch_structural_losses/pybind/extern.hpp create mode 100644 evaluation/pytorch_structural_losses/setup.py create mode 100644 evaluation/pytorch_structural_losses/src/approxmatch.cu create mode 100644 evaluation/pytorch_structural_losses/src/approxmatch.cuh create mode 100644 evaluation/pytorch_structural_losses/src/nndistance.cu create mode 100644 evaluation/pytorch_structural_losses/src/nndistance.cuh create mode 100644 evaluation/pytorch_structural_losses/src/structural_loss.cpp create mode 100644 evaluation/pytorch_structural_losses/src/utils.hpp diff --git 
a/evaluation/.gitignore b/evaluation/.gitignore new file mode 100644 index 0000000..98c2a8b --- /dev/null +++ b/evaluation/.gitignore @@ -0,0 +1 @@ +StructuralLosses diff --git a/evaluation/README.md b/evaluation/README.md new file mode 100644 index 0000000..30b2c97 --- /dev/null +++ b/evaluation/README.md @@ -0,0 +1,14 @@ +# Evaluation + +From [https://github.com/stevenygd/PointFlow/tree/master/metrics](https://github.com/stevenygd/PointFlow/tree/master/metrics) + +Modifications: + +| Position | Original | Modified | +| ----------------- | ------------------------------ | ----------------------------------- | +| Makefile:9 | `/usr/local/cuda/bin/nvcc` | `/usr/local/cuda-10.0/bin/nvcc` | +| Makefile:69,70 | `c++11` | `c++14` | +| Makefile:74,75 | `lib.linux-x86_64-3.6` | `lib.linux-x86_64-3.7` | +| Pybind/bind.cpp:5 | `#include "pybind/extern.hpp"` | `#include "extern.hpp"` | +| \__init__.py | | `from .evaluation_metrics import *` | + diff --git a/evaluation/__init__.py b/evaluation/__init__.py new file mode 100644 index 0000000..195e455 --- /dev/null +++ b/evaluation/__init__.py @@ -0,0 +1,2 @@ +from .evaluation_metrics import * +from .evaluation_metrics import _pairwise_EMD_CD_, _jsdiv diff --git a/evaluation/evaluation_metrics.py b/evaluation/evaluation_metrics.py new file mode 100644 index 0000000..42d75ba --- /dev/null +++ b/evaluation/evaluation_metrics.py @@ -0,0 +1,359 @@ +import torch +import numpy as np +import warnings +from scipy.stats import entropy +from sklearn.neighbors import NearestNeighbors +from numpy.linalg import norm +from tqdm.auto import tqdm +# Import CUDA version of approximate EMD, +# from https://github.com/zekunhao1995/pcgan-pytorch/ +from evaluation.StructuralLosses.match_cost import match_cost # noqa +from evaluation.StructuralLosses.nn_distance import nn_distance # noqa + + +def distChamferCUDA(x, y): + return nn_distance(x, y) + + +def emd_approx(sample, ref): + B, N, N_ref = sample.size(0), sample.size(1), ref.size(1) + assert N 
== N_ref, "Not sure what would EMD do in this case" + emd = match_cost(sample, ref) # (B,) + emd_norm = emd / float(N) # (B,) + return emd_norm + + +# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet +def distChamfer(a, b): + x, y = a, b + bs, num_points, points_dim = x.size() + xx = torch.bmm(x, x.transpose(2, 1)) + yy = torch.bmm(y, y.transpose(2, 1)) + zz = torch.bmm(x, y.transpose(2, 1)) + diag_ind = torch.arange(0, num_points).to(a).long() + rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) + ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) + P = (rx.transpose(2, 1) + ry - 2 * zz) + return P.min(1)[0], P.min(2)[0] + + +def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True): + N_sample = sample_pcs.shape[0] + N_ref = ref_pcs.shape[0] + assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample) + + cd_lst = [] + emd_lst = [] + iterator = range(0, N_sample, batch_size) + + for b_start in tqdm(iterator, desc='EMD-CD'): + b_end = min(N_sample, b_start + batch_size) + sample_batch = sample_pcs[b_start:b_end] + ref_batch = ref_pcs[b_start:b_end] + + if accelerated_cd: + dl, dr = distChamferCUDA(sample_batch, ref_batch) + else: + dl, dr = distChamfer(sample_batch, ref_batch) + cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1)) + + emd_batch = emd_approx(sample_batch, ref_batch) + emd_lst.append(emd_batch) + + if reduced: + cd = torch.cat(cd_lst).mean() + emd = torch.cat(emd_lst).mean() + else: + cd = torch.cat(cd_lst) + emd = torch.cat(emd_lst) + + results = { + 'MMD-CD': cd, + 'MMD-EMD': emd, + } + return results + + +def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, + accelerated_cd=True, verbose=True): + N_sample = sample_pcs.shape[0] + N_ref = ref_pcs.shape[0] + all_cd = [] + all_emd = [] + iterator = range(N_sample) + if verbose: + iterator = tqdm(iterator, desc='Pairwise EMD-CD') + for sample_b_start in iterator: + sample_batch = sample_pcs[sample_b_start] + + cd_lst = [] + emd_lst = [] + sub_iterator = 
range(0, N_ref, batch_size) + if verbose: + sub_iterator = tqdm(sub_iterator, leave=False) + for ref_b_start in sub_iterator: + ref_b_end = min(N_ref, ref_b_start + batch_size) + ref_batch = ref_pcs[ref_b_start:ref_b_end] + + batch_size_ref = ref_batch.size(0) + point_dim = ref_batch.size(2) + sample_batch_exp = sample_batch.view(1, -1, point_dim).expand( + batch_size_ref, -1, -1) + sample_batch_exp = sample_batch_exp.contiguous() + + if accelerated_cd: + dl, dr = distChamferCUDA(sample_batch_exp, ref_batch) + else: + dl, dr = distChamfer(sample_batch_exp, ref_batch) + cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1)) + + emd_batch = emd_approx(sample_batch_exp, ref_batch) + emd_lst.append(emd_batch.view(1, -1)) + + cd_lst = torch.cat(cd_lst, dim=1) + emd_lst = torch.cat(emd_lst, dim=1) + all_cd.append(cd_lst) + all_emd.append(emd_lst) + + all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref + all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref + + return all_cd, all_emd + + +# Adapted from https://github.com/xuqiantong/ +# GAN-Metrics/blob/master/framework/metric.py +def knn(Mxx, Mxy, Myy, k, sqrt=False): + n0 = Mxx.size(0) + n1 = Myy.size(0) + label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx) + M = torch.cat([ + torch.cat((Mxx, Mxy), 1), + torch.cat((Mxy.transpose(0, 1), Myy), 1)], 0) + if sqrt: + M = M.abs().sqrt() + INFINITY = float('inf') + val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk( + k, 0, False) + + count = torch.zeros(n0 + n1).to(Mxx) + for i in range(0, k): + count = count + label.index_select(0, idx[i]) + pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float() + + s = { + 'tp': (pred * label).sum(), + 'fp': (pred * (1 - label)).sum(), + 'fn': ((1 - pred) * label).sum(), + 'tn': ((1 - pred) * (1 - label)).sum(), + } + + s.update({ + 'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10), + 'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10), + 'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10), 
+ 'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10), + 'acc': torch.eq(label, pred).float().mean(), + }) + return s + + +def lgan_mmd_cov(all_dist): + N_sample, N_ref = all_dist.size(0), all_dist.size(1) + min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) + min_val, _ = torch.min(all_dist, dim=0) + mmd = min_val.mean() + mmd_smp = min_val_fromsmp.mean() + cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) + cov = torch.tensor(cov).to(all_dist) + return { + 'lgan_mmd': mmd, + 'lgan_cov': cov, + 'lgan_mmd_smp': mmd_smp, + } + + +def lgan_mmd_cov_match(all_dist): + N_sample, N_ref = all_dist.size(0), all_dist.size(1) + min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) + min_val, _ = torch.min(all_dist, dim=0) + mmd = min_val.mean() + mmd_smp = min_val_fromsmp.mean() + cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) + cov = torch.tensor(cov).to(all_dist) + return { + 'lgan_mmd': mmd, + 'lgan_cov': cov, + 'lgan_mmd_smp': mmd_smp, + }, min_idx.view(-1) + + +def compute_all_metrics(sample_pcs, ref_pcs, batch_size, accelerated_cd=False): + results = {} + + print("Pairwise EMD CD") + M_rs_cd, M_rs_emd = _pairwise_EMD_CD_( + ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd) + + res_cd = lgan_mmd_cov(M_rs_cd.t()) + results.update({ + "%s-CD" % k: v for k, v in res_cd.items() + }) + + res_emd = lgan_mmd_cov(M_rs_emd.t()) + results.update({ + "%s-EMD" % k: v for k, v in res_emd.items() + }) + + for k, v in results.items(): + print('[%s] %.8f' % (k, v.item())) + + M_rr_cd, M_rr_emd = _pairwise_EMD_CD_( + ref_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd) + M_ss_cd, M_ss_emd = _pairwise_EMD_CD_( + sample_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd) + + # 1-NN results + one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False) + results.update({ + "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k + }) + one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False) + results.update({ + 
"1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k + }) + + return results + + +####################################################### +# JSD : from https://github.com/optas/latent_3d_points +####################################################### +def unit_cube_grid_point_cloud(resolution, clip_sphere=False): + """Returns the center coordinates of each cell of a 3D grid with + resolution^3 cells, that is placed in the unit-cube. If clip_sphere it True + it drops the "corner" cells that lie outside the unit-sphere. + """ + grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) + spacing = 1.0 / float(resolution - 1) + for i in range(resolution): + for j in range(resolution): + for k in range(resolution): + grid[i, j, k, 0] = i * spacing - 0.5 + grid[i, j, k, 1] = j * spacing - 0.5 + grid[i, j, k, 2] = k * spacing - 0.5 + + if clip_sphere: + grid = grid.reshape(-1, 3) + grid = grid[norm(grid, axis=1) <= 0.5] + + return grid, spacing + + +def jsd_between_point_cloud_sets( + sample_pcs, ref_pcs, resolution=28): + """Computes the JSD between two sets of point-clouds, + as introduced in the paper + ```Learning Representations And Generative Models For 3D Point Clouds```. + Args: + sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points. + ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points. + resolution: (int) grid-resolution. Affects granularity of measurements. + """ + in_unit_sphere = True + sample_grid_var = entropy_of_occupancy_grid( + sample_pcs, resolution, in_unit_sphere)[1] + ref_grid_var = entropy_of_occupancy_grid( + ref_pcs, resolution, in_unit_sphere)[1] + return jensen_shannon_divergence(sample_grid_var, ref_grid_var) + + +def entropy_of_occupancy_grid( + pclouds, grid_resolution, in_sphere=False, verbose=False): + """Given a collection of point-clouds, estimate the entropy of + the random variables corresponding to occupancy-grid activation patterns. 
+ Inputs: + pclouds: (numpy array) #point-clouds x points per point-cloud x 3 + grid_resolution (int) size of occupancy grid that will be used. + """ + epsilon = 10e-4 + bound = 0.5 + epsilon + if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: + if verbose: + warnings.warn('Point-clouds are not in unit cube.') + + if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound: + if verbose: + warnings.warn('Point-clouds are not in unit sphere.') + + grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere) + grid_coordinates = grid_coordinates.reshape(-1, 3) + grid_counters = np.zeros(len(grid_coordinates)) + grid_bernoulli_rvars = np.zeros(len(grid_coordinates)) + nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates) + + for pc in tqdm(pclouds, desc='JSD'): + _, indices = nn.kneighbors(pc) + indices = np.squeeze(indices) + for i in indices: + grid_counters[i] += 1 + indices = np.unique(indices) + for i in indices: + grid_bernoulli_rvars[i] += 1 + + acc_entropy = 0.0 + n = float(len(pclouds)) + for g in grid_bernoulli_rvars: + if g > 0: + p = float(g) / n + acc_entropy += entropy([p, 1.0 - p]) + + return acc_entropy / len(grid_counters), grid_counters + + +def jensen_shannon_divergence(P, Q): + if np.any(P < 0) or np.any(Q < 0): + raise ValueError('Negative values.') + if len(P) != len(Q): + raise ValueError('Non equal size.') + + P_ = P / np.sum(P) # Ensure probabilities. 
+ Q_ = Q / np.sum(Q) + + e1 = entropy(P_, base=2) + e2 = entropy(Q_, base=2) + e_sum = entropy((P_ + Q_) / 2.0, base=2) + res = e_sum - ((e1 + e2) / 2.0) + + res2 = _jsdiv(P_, Q_) + + if not np.allclose(res, res2, atol=10e-5, rtol=0): + warnings.warn('Numerical values of two JSD methods don\'t agree.') + + return res + + +def _jsdiv(P, Q): + """another way of computing JSD""" + + def _kldiv(A, B): + a = A.copy() + b = B.copy() + idx = np.logical_and(a > 0, b > 0) + a = a[idx] + b = b[idx] + return np.sum([v for v in a * np.log2(a / b)]) + + P_ = P / np.sum(P) + Q_ = Q / np.sum(Q) + + M = 0.5 * (P_ + Q_) + + return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M)) + + +if __name__ == '__main__': + a = torch.randn([16, 2048, 3]).cuda() + b = torch.randn([16, 2048, 3]).cuda() + print(EMD_CD(a, b, batch_size=8)) + \ No newline at end of file diff --git a/evaluation/pytorch_structural_losses/.gitignore b/evaluation/pytorch_structural_losses/.gitignore new file mode 100644 index 0000000..2c49087 --- /dev/null +++ b/evaluation/pytorch_structural_losses/.gitignore @@ -0,0 +1,2 @@ +PyTorchStructuralLosses.egg-info/ +objs/ \ No newline at end of file diff --git a/evaluation/pytorch_structural_losses/Makefile b/evaluation/pytorch_structural_losses/Makefile new file mode 100644 index 0000000..0064383 --- /dev/null +++ b/evaluation/pytorch_structural_losses/Makefile @@ -0,0 +1,103 @@ +############################################################################### +# Uncomment for debugging +# DEBUG := 1 +# Pretty build +# Q ?= @ + +CXX := g++ +PYTHON := python +NVCC := /usr/local/cuda-10.0/bin/nvcc # !!WARNING!! 
Errors occur when using version 10.1 + +# PYTHON Header path +PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())') +PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]') +PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]') + +# CUDA ROOT DIR that contains bin/ lib64/ and include/ +# CUDA_DIR := /usr/local/cuda +CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())') + +INCLUDE_DIRS := ./ $(CUDA_DIR)/include + +INCLUDE_DIRS += $(PYTHON_HEADER_DIR) +INCLUDE_DIRS += $(PYTORCH_INCLUDES) + +# Custom (MKL/ATLAS/OpenBLAS) include and lib directories. +# Leave commented to accept the defaults for your choice of BLAS +# (which should work)! +# BLAS_INCLUDE := /path/to/your/blas +# BLAS_LIB := /path/to/your/blas + +############################################################################### +SRC_DIR := ./src +OBJ_DIR := ./objs +CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp) +CU_SRCS := $(wildcard $(SRC_DIR)/*.cu) +OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS)) +CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS)) +STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a + +# CUDA architecture setting: going with all of them. +# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility. +# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility. +# CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \ +# -gencode arch=compute_61,code=compute_61 \ +# -gencode arch=compute_52,code=sm_52 +CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \ + -gencode arch=compute_61,code=compute_61 + + +# We will also explicitly add stdc++ to the link target. 
+LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu + +# Debugging +ifeq ($(DEBUG), 1) + COMMON_FLAGS += -DDEBUG -g -O0 + # https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/ + NVCCFLAGS += -g -G # -rdc true +else + COMMON_FLAGS += -DNDEBUG -O3 +endif + +WARNINGS := -Wall -Wno-sign-compare -Wcomment + +INCLUDE_DIRS += $(BLAS_INCLUDE) + +# Automatic dependency generation (nvcc is handled separately) +CXXFLAGS += -MMD -MP + +# Complete build flags. +COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \ + -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0 +CXXFLAGS += -pthread -fPIC -fwrapv -std=c++14 $(COMMON_FLAGS) $(WARNINGS) +NVCCFLAGS += -std=c++14 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS) + +all: $(STATIC_LIB) + $(PYTHON) setup.py build + @ mv build/lib.linux-x86_64-3.7/StructuralLosses .. + @ mv build/lib.linux-x86_64-3.7/*.so ../StructuralLosses/ + @- $(RM) -rf $(OBJ_DIR) build objs + +$(OBJ_DIR): + @ mkdir -p $@ + @ mkdir -p $@/cuda + +$(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR) + @ echo CXX $< + $(Q)$(CXX) $< $(CXXFLAGS) -c -o $@ + +$(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR) + @ echo NVCC $< + $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \ + -odir $(@D) + $(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ + +$(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR) + $(RM) -f $(STATIC_LIB) + $(RM) -rf build dist + @ echo LD -o $@ + ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS) + +clean: + @- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses + diff --git a/evaluation/pytorch_structural_losses/__init__.py b/evaluation/pytorch_structural_losses/__init__.py new file mode 100644 index 0000000..1656c10 --- /dev/null +++ b/evaluation/pytorch_structural_losses/__init__.py @@ -0,0 +1,6 @@ +#import torch + +#from MakePytorchBackend import AddGPU, Foo, ApproxMatch + +#from Add import add_gpu, approx_match + diff --git a/evaluation/pytorch_structural_losses/match_cost.py 
b/evaluation/pytorch_structural_losses/match_cost.py new file mode 100644 index 0000000..f2c1885 --- /dev/null +++ b/evaluation/pytorch_structural_losses/match_cost.py @@ -0,0 +1,45 @@ +import torch +from torch.autograd import Function +from evaluation.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad + +# Inherit from Function +class MatchCostFunction(Function): + # Note that both forward and backward are @staticmethods + @staticmethod + # bias is an optional argument + def forward(ctx, seta, setb): + #print("Match Cost Forward") + ctx.save_for_backward(seta, setb) + ''' + input: + set1 : batch_size * #dataset_points * 3 + set2 : batch_size * #query_points * 3 + returns: + match : batch_size * #query_points * #dataset_points + ''' + match, temp = ApproxMatch(seta, setb) + ctx.match = match + cost = MatchCost(seta, setb, match) + return cost + + """ + grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match) + return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None] + """ + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, grad_output): + #print("Match Cost Backward") + # This is a pattern that is very convenient - at the top of backward + # unpack saved_tensors and initialize all gradients w.r.t. inputs to + # None. Thanks to the fact that additional trailing Nones are + # ignored, the return statement is simple even when the function has + # optional inputs. 
+ seta, setb = ctx.saved_tensors + #grad_input = grad_weight = grad_bias = None + grada, gradb = MatchCostGrad(seta, setb, ctx.match) + grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2) + return grada*grad_output_expand, gradb*grad_output_expand + +match_cost = MatchCostFunction.apply + diff --git a/evaluation/pytorch_structural_losses/nn_distance.py b/evaluation/pytorch_structural_losses/nn_distance.py new file mode 100644 index 0000000..a100a26 --- /dev/null +++ b/evaluation/pytorch_structural_losses/nn_distance.py @@ -0,0 +1,42 @@ +import torch +from torch.autograd import Function +# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad +from evaluation.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad + +# Inherit from Function +class NNDistanceFunction(Function): + # Note that both forward and backward are @staticmethods + @staticmethod + # bias is an optional argument + def forward(ctx, seta, setb): + #print("Match Cost Forward") + ctx.save_for_backward(seta, setb) + ''' + input: + set1 : batch_size * #dataset_points * 3 + set2 : batch_size * #query_points * 3 + returns: + dist1, idx1, dist2, idx2 + ''' + dist1, idx1, dist2, idx2 = NNDistance(seta, setb) + ctx.idx1 = idx1 + ctx.idx2 = idx2 + return dist1, dist2 + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, grad_dist1, grad_dist2): + #print("Match Cost Backward") + # This is a pattern that is very convenient - at the top of backward + # unpack saved_tensors and initialize all gradients w.r.t. inputs to + # None. Thanks to the fact that additional trailing Nones are + # ignored, the return statement is simple even when the function has + # optional inputs. 
+ seta, setb = ctx.saved_tensors + idx1 = ctx.idx1 + idx2 = ctx.idx2 + grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2) + return grada, gradb + +nn_distance = NNDistanceFunction.apply + diff --git a/evaluation/pytorch_structural_losses/pybind/bind.cpp b/evaluation/pytorch_structural_losses/pybind/bind.cpp new file mode 100644 index 0000000..e80c835 --- /dev/null +++ b/evaluation/pytorch_structural_losses/pybind/bind.cpp @@ -0,0 +1,15 @@ +#include + +#include + +#include "extern.hpp" + +namespace py = pybind11; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ + m.def("ApproxMatch", &ApproxMatch); + m.def("MatchCost", &MatchCost); + m.def("MatchCostGrad", &MatchCostGrad); + m.def("NNDistance", &NNDistance); + m.def("NNDistanceGrad", &NNDistanceGrad); +} diff --git a/evaluation/pytorch_structural_losses/pybind/extern.hpp b/evaluation/pytorch_structural_losses/pybind/extern.hpp new file mode 100644 index 0000000..003877b --- /dev/null +++ b/evaluation/pytorch_structural_losses/pybind/extern.hpp @@ -0,0 +1,6 @@ +std::vector ApproxMatch(at::Tensor in_a, at::Tensor in_b); +at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match); +std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match); + +std::vector NNDistance(at::Tensor set_d, at::Tensor set_q); +std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2); diff --git a/evaluation/pytorch_structural_losses/setup.py b/evaluation/pytorch_structural_losses/setup.py new file mode 100644 index 0000000..67f0e8c --- /dev/null +++ b/evaluation/pytorch_structural_losses/setup.py @@ -0,0 +1,30 @@ +from setuptools import setup +from torch.utils.cpp_extension import CUDAExtension, BuildExtension + +# Python interface +setup( + name='PyTorchStructuralLosses', + version='0.1.0', + install_requires=['torch'], + packages=['StructuralLosses'], + package_dir={'StructuralLosses': './'}, + 
ext_modules=[ + CUDAExtension( + name='StructuralLossesBackend', + include_dirs=['./'], + sources=[ + 'pybind/bind.cpp', + ], + libraries=['make_pytorch'], + library_dirs=['objs'], + # extra_compile_args=['-g'] + ) + ], + cmdclass={'build_ext': BuildExtension}, + author='Christopher B. Choy', + author_email='chrischoy@ai.stanford.edu', + description='Tutorial for Pytorch C++ Extension with a Makefile', + keywords='Pytorch C++ Extension', + url='https://github.com/chrischoy/MakePytorchPlusPlus', + zip_safe=False, +) diff --git a/evaluation/pytorch_structural_losses/src/approxmatch.cu b/evaluation/pytorch_structural_losses/src/approxmatch.cu new file mode 100644 index 0000000..42058be --- /dev/null +++ b/evaluation/pytorch_structural_losses/src/approxmatch.cu @@ -0,0 +1,326 @@ +#include "utils.hpp" + +__global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){ + float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; + float multiL,multiR; + if (n>=m){ + multiL=1; + multiR=n/m; + }else{ + multiL=m/n; + multiR=1; + } + const int Block=1024; + __shared__ float buf[Block*4]; + for (int i=blockIdx.x;i=-2;j--){ + for (int j=7;j>-2;j--){ + float level=-powf(4.0f,j); + if (j==-2){ + level=0; + } + for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,out); +//} + +__global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){ + __shared__ float sum_grad[256*3]; + for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad2); +//} + +/*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, + cudaStream_t stream)*/ +// temp: TensorShape{b,(n+m)*2} +void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){ + 
approxmatchkernel + <<<32, 512, 0, stream>>>(b,n,m,xyz1,xyz2,match,temp); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error(Formatter() + << "CUDA kernel failed : " << std::to_string(err)); +} + +void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){ + matchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error(Formatter() + << "CUDA kernel failed : " << std::to_string(err)); +} + +void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){ + matchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1); + matchcostgrad2kernel<<>>(b,n,m,xyz1,xyz2,match,grad2); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error(Formatter() + << "CUDA kernel failed : " << std::to_string(err)); +} diff --git a/evaluation/pytorch_structural_losses/src/approxmatch.cuh b/evaluation/pytorch_structural_losses/src/approxmatch.cuh new file mode 100644 index 0000000..440d64d --- /dev/null +++ b/evaluation/pytorch_structural_losses/src/approxmatch.cuh @@ -0,0 +1,8 @@ +/* +template +void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N, + cudaStream_t stream); +*/ +void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream); +void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream); +void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream); diff --git a/evaluation/pytorch_structural_losses/src/nndistance.cu b/evaluation/pytorch_structural_losses/src/nndistance.cu new file mode 100644 index 0000000..bd13b8b --- /dev/null +++ 
b/evaluation/pytorch_structural_losses/src/nndistance.cu @@ -0,0 +1,155 @@ + +__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ + const int batch=512; + __shared__ float buf[batch*3]; + for (int i=blockIdx.x;ibest){ + result[(i*n+j)]=best; + result_i[(i*n+j)]=best_i; + } + } + __syncthreads(); + } + } +} +void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ + NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i); + NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); +} +__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ + for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); + NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); +} + diff --git a/evaluation/pytorch_structural_losses/src/nndistance.cuh b/evaluation/pytorch_structural_losses/src/nndistance.cuh new file mode 100644 index 0000000..e2b65c3 --- /dev/null +++ b/evaluation/pytorch_structural_losses/src/nndistance.cuh @@ -0,0 +1,2 @@ +void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); +void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); diff --git a/evaluation/pytorch_structural_losses/src/structural_loss.cpp b/evaluation/pytorch_structural_losses/src/structural_loss.cpp new file mode 100644 index 0000000..f58702c --- /dev/null +++ b/evaluation/pytorch_structural_losses/src/structural_loss.cpp @@ -0,0 +1,125 @@ +#include +#include + +#include "src/approxmatch.cuh" +#include "src/nndistance.cuh" + 
+#include +#include + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +/* +input: + set1 : batch_size * #dataset_points * 3 + set2 : batch_size * #query_points * 3 +returns: + match : batch_size * #query_points * #dataset_points +*/ +// temp: TensorShape{b,(n+m)*2} +std::vector ApproxMatch(at::Tensor set_d, at::Tensor set_q) { + //std::cout << "[ApproxMatch] Called." << std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl; + at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(match); + CHECK_INPUT(temp); + + approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),temp.data(), at::cuda::getCurrentCUDAStream()); + return {match, temp}; +} + +at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { + //std::cout << "[MatchCost] Called." 
<< std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[MatchCost] batch_size:" << batch_size << std::endl; + at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(match); + CHECK_INPUT(out); + matchcost(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),out.data(),at::cuda::getCurrentCUDAStream()); + return out; +} + +std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) { + //std::cout << "[MatchCostGrad] Called." << std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl; + at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(match); + CHECK_INPUT(grad1); + CHECK_INPUT(grad2); + matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),grad1.data(),grad2.data(),at::cuda::getCurrentCUDAStream()); + return {grad1, grad2}; +} + + +/* +input: + set_d : batch_size * #dataset_points * 3 + set_q : batch_size * #query_points * 3 +returns: + dist1, idx1 : batch_size * #dataset_points + dist2, idx2 : batch_size * #query_points +*/ +std::vector NNDistance(at::Tensor set_d, at::Tensor set_q) { + //std::cout << "[NNDistance] Called." 
<< std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[NNDistance] batch_size:" << batch_size << std::endl; + at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); + at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(dist1); + CHECK_INPUT(idx1); + CHECK_INPUT(dist2); + CHECK_INPUT(idx2); + // void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream); + nndistance(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(),dist1.data(),idx1.data(),dist2.data(),idx2.data(), at::cuda::getCurrentCUDAStream()); + return {dist1, idx1, dist2, idx2}; +} + +std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) { + //std::cout << "[NNDistanceGrad] Called." 
<< std::endl; + int64_t batch_size = set_d.size(0); + int64_t n_dataset_points = set_d.size(1); // n + int64_t n_query_points = set_q.size(1); // m + //std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl; + at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device())); + CHECK_INPUT(set_d); + CHECK_INPUT(set_q); + CHECK_INPUT(idx1); + CHECK_INPUT(idx2); + CHECK_INPUT(grad_dist1); + CHECK_INPUT(grad_dist2); + CHECK_INPUT(grad1); + CHECK_INPUT(grad2); + //void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream); + nndistancegrad(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(), + grad_dist1.data(),idx1.data(), + grad_dist2.data(),idx2.data(), + grad1.data(),grad2.data(), + at::cuda::getCurrentCUDAStream()); + return {grad1, grad2}; +} + diff --git a/evaluation/pytorch_structural_losses/src/utils.hpp b/evaluation/pytorch_structural_losses/src/utils.hpp new file mode 100644 index 0000000..d60fa2b --- /dev/null +++ b/evaluation/pytorch_structural_losses/src/utils.hpp @@ -0,0 +1,26 @@ +#include +#include +#include + +class Formatter { +public: + Formatter() {} + ~Formatter() {} + + template Formatter &operator<<(const Type &value) { + stream_ << value; + return *this; + } + + std::string str() const { return stream_.str(); } + operator std::string() const { return stream_.str(); } + + enum ConvertToString { to_str }; + + std::string operator>>(ConvertToString) { return stream_.str(); } + +private: + std::stringstream stream_; + Formatter(const Formatter &); + Formatter &operator=(Formatter &); +}; From a8dacfc83671acd8b91e4c7d4143aea74017bdc5 Mon Sep 17 00:00:00 2001 
From: luost Date: Sat, 3 Jul 2021 21:30:38 +0800 Subject: [PATCH 3/5] Update .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 7f6b817..ab23008 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,6 @@ dmypy.json .pyre/ .DS_Store +/playgrounds +/logs* +/results* From 5fd3f798a591d3bf5675dfe89cd53a2c0320e33f Mon Sep 17 00:00:00 2001 From: luost Date: Sat, 3 Jul 2021 21:30:50 +0800 Subject: [PATCH 4/5] Update env.yml --- env.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/env.yml b/env.yml index a465fea..de769f8 100644 --- a/env.yml +++ b/env.yml @@ -34,10 +34,13 @@ dependencies: - idna=2.10=pyh9f0ad1d_0 - importlib-metadata=4.6.0=py37h89c1867_0 - intel-openmp=2021.2.0=h06a4308_610 + - joblib=1.0.1=pyhd8ed1ab_0 - jpeg=9b=h024ee3a_2 - krb5=1.19.1=hcc1bbae_0 - lcms2=2.12=h3be6417_0 - ld_impl_linux-64=2.35.1=h7274673_9 + - libblas=3.9.0=9_mkl + - libcblas=3.9.0=9_mkl - libcurl=7.77.0=h2574ce0_0 - libedit=3.1.20191231=he28a2e2_2 - libev=4.33=h516909a_1 @@ -87,6 +90,7 @@ dependencies: - requests=2.25.1=pyhd3deb0d_0 - requests-oauthlib=1.3.0=pyh9f0ad1d_0 - rsa=4.7.2=pyh44b312d_0 + - scikit-learn=0.24.2=py37h18a542f_0 - scipy=1.6.2=py37had2a1c9_1 - setuptools=52.0.0=py37h06a4308_0 - six=1.16.0=pyhd3eb1b0_0 @@ -94,6 +98,7 @@ dependencies: - tensorboard=2.5.0=pyhd8ed1ab_0 - tensorboard-data-server=0.6.0=py37h7f0c10b_0 - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 + - threadpoolctl=2.1.0=pyh5ca1d4c_0 - tk=8.6.10=hbc83047_0 - torchvision=0.7.0=py37_cu101 - tqdm=4.61.1=pyhd8ed1ab_0 From b3013ef7ef3cfc96db350490944a949021420078 Mon Sep 17 00:00:00 2001 From: luost Date: Sat, 3 Jul 2021 21:31:08 +0800 Subject: [PATCH 5/5] Upload --- data/.gitignore | 2 + models/vae_flow.py | 2 +- pretrained/.gitignore | 2 + test_ae.py | 81 ++++++++++++ test_gen.py | 115 +++++++++++++++++ train_ae.py | 212 ++++++++++++++++++++++++++++++ train_gen.py | 222 ++++++++++++++++++++++++++++++++ utils/data.py | 34 +++++ utils/dataset.py 
| 145 +++++++++++++++++++++ utils/misc.py | 162 +++++++++++++++++++++++ utils/transform.py | 292 ++++++++++++++++++++++++++++++++++++++++++ 11 files changed, 1268 insertions(+), 1 deletion(-) create mode 100644 data/.gitignore create mode 100644 pretrained/.gitignore create mode 100644 test_ae.py create mode 100644 test_gen.py create mode 100644 train_ae.py create mode 100644 train_gen.py create mode 100644 utils/data.py create mode 100644 utils/dataset.py create mode 100644 utils/misc.py create mode 100644 utils/transform.py diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/models/vae_flow.py b/models/vae_flow.py index a632c3e..d3eddd7 100644 --- a/models/vae_flow.py +++ b/models/vae_flow.py @@ -24,7 +24,7 @@ def __init__(self, args): ) ) - def get_loss(self, x, writer=None, it=None, kl_weight=1.0): + def get_loss(self, x, kl_weight, writer=None, it=None): """ Args: x: Input point clouds, (B, N, d). 
diff --git a/pretrained/.gitignore b/pretrained/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/pretrained/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/test_ae.py b/test_ae.py new file mode 100644 index 0000000..a772c43 --- /dev/null +++ b/test_ae.py @@ -0,0 +1,81 @@ +import os +import time +import argparse +import torch +from tqdm.auto import tqdm + +from utils.dataset import * +from utils.misc import * +from utils.data import * +from models.autoencoder import * +from evaluation import EMD_CD + + +# Arguments +parser = argparse.ArgumentParser() +parser.add_argument('--ckpt', type=str, default='./pretrained/AE_airplane.pt') +parser.add_argument('--categories', type=str_list, default=['airplane']) +parser.add_argument('--save_dir', type=str, default='./results') +parser.add_argument('--device', type=str, default='cuda') +# Datasets and loaders +parser.add_argument('--dataset_path', type=str, default='./data/shapenet.hdf5') +parser.add_argument('--batch_size', type=int, default=128) +parser.add_argument('--num_workers', type=int, default=4) +args = parser.parse_args() + +# Logging +save_dir = os.path.join(args.save_dir, 'AE_Ours_%s_%d' % ('_'.join(args.categories), int(time.time())) ) +if not os.path.exists(save_dir): + os.makedirs(save_dir) +logger = get_logger('test', save_dir) +for k, v in vars(args).items(): + logger.info('[ARGS::%s] %s' % (k, repr(v))) + +# Checkpoint +ckpt = torch.load(args.ckpt) +seed_all(ckpt['args'].seed) + +# Datasets and loaders +logger.info('Loading datasets...') +test_dset = ShapeNetCore( + path=args.dataset_path, + cates=args.categories, + split='test', + scale_mode=ckpt['args'].scale_mode +) +test_loader = DataLoader(test_dset, batch_size=args.batch_size, num_workers=args.num_workers) + +# Model +logger.info('Loading model...') +model = AutoEncoder(ckpt['args']).to(args.device) +model.load_state_dict(ckpt['state_dict']) + +all_ref = [] +all_recons = [] +for i, batch in enumerate(tqdm(test_loader)): 
+ ref = batch['pointcloud'].to(args.device) + shift = batch['shift'].to(args.device) + scale = batch['scale'].to(args.device) + model.eval() + with torch.no_grad(): + code = model.encode(ref) + recons = model.decode(code, ref.size(1), flexibility=ckpt['args'].flexibility).detach() + + ref = ref * scale + shift + recons = recons * scale + shift + + all_ref.append(ref.detach().cpu()) + all_recons.append(recons.detach().cpu()) + +all_ref = torch.cat(all_ref, dim=0) +all_recons = torch.cat(all_recons, dim=0) + +logger.info('Saving point clouds...') +np.save(os.path.join(save_dir, 'ref.npy'), all_ref.numpy()) +np.save(os.path.join(save_dir, 'out.npy'), all_recons.numpy()) + +logger.info('Start computing metrics...') +metrics = EMD_CD(all_recons.to(args.device), all_ref.to(args.device), batch_size=args.batch_size, accelerated_cd=True) +cd, emd = metrics['MMD-CD'].item(), metrics['MMD-EMD'].item() +logger.info('CD: %.12f' % cd) +logger.info('EMD: %.12f' % emd) diff --git a/test_gen.py b/test_gen.py new file mode 100644 index 0000000..840c118 --- /dev/null +++ b/test_gen.py @@ -0,0 +1,115 @@ +import os +import time +import math +import argparse +import torch +from tqdm.auto import tqdm + +from utils.dataset import * +from utils.misc import * +from utils.data import * +from models.vae_gaussian import * +from models.vae_flow import * +from models.flow import add_spectral_norm, spectral_norm_power_iteration +from evaluation import * + +def normalize_point_clouds(pcs, mode, logger): + if mode is None: + logger.info('Will not normalize point clouds.') + return pcs + logger.info('Normalization mode: %s' % mode) + for i in tqdm(range(pcs.size(0)), desc='Normalize'): + pc = pcs[i] + if mode == 'shape_unit': + shift = pc.mean(dim=0).reshape(1, 3) + scale = pc.flatten().std().reshape(1, 1) + elif mode == 'shape_bbox': + pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3) + pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3) + shift = ((pc_min + pc_max) / 2).view(1, 3) + scale = (pc_max - 
pc_min).max().reshape(1, 1) / 2 + pc = (pc - shift) / scale + pcs[i] = pc + return pcs + + +# Arguments +parser = argparse.ArgumentParser() +parser.add_argument('--ckpt', type=str, default='./pretrained/GEN_airplane.pt') +parser.add_argument('--categories', type=str_list, default=['airplane']) +parser.add_argument('--save_dir', type=str, default='./results') +parser.add_argument('--device', type=str, default='cuda') +# Datasets and loaders +parser.add_argument('--dataset_path', type=str, default='./data/shapenet.hdf5') +parser.add_argument('--batch_size', type=int, default=128) +parser.add_argument('--num_workers', type=int, default=4) +# Sampling +parser.add_argument('--sample_num_points', type=int, default=2048) +parser.add_argument('--normalize', type=str, default='shape_bbox', choices=[None, 'shape_unit', 'shape_bbox']) +parser.add_argument('--seed', type=int, default=9988) +args = parser.parse_args() + + +# Logging +save_dir = os.path.join(args.save_dir, 'GEN_Ours_%s_%d' % ('_'.join(args.categories), int(time.time())) ) +if not os.path.exists(save_dir): + os.makedirs(save_dir) +logger = get_logger('test', save_dir) +for k, v in vars(args).items(): + logger.info('[ARGS::%s] %s' % (k, repr(v))) + +# Checkpoint +ckpt = torch.load(args.ckpt) +seed_all(args.seed) + +# Datasets and loaders +logger.info('Loading datasets...') +test_dset = ShapeNetCore( + path=args.dataset_path, + cates=args.categories, + split='test', + scale_mode=args.normalize, +) +test_loader = DataLoader(test_dset, batch_size=args.batch_size, num_workers=args.num_workers) + +# Model +logger.info('Loading model...') +if ckpt['args'].model == 'gaussian': + model = GaussianVAE(ckpt['args']).to(args.device) +elif ckpt['args'].model == 'flow': + model = FlowVAE(ckpt['args']).to(args.device) +logger.info(repr(model)) +# if ckpt['args'].spectral_norm: +# add_spectral_norm(model, logger=logger) +model.load_state_dict(ckpt['state_dict']) + +# Reference Point Clouds +ref_pcs = [] +for i, data in 
enumerate(test_dset): + ref_pcs.append(data['pointcloud'].unsqueeze(0)) +ref_pcs = torch.cat(ref_pcs, dim=0) + +# Generate Point Clouds +gen_pcs = [] +for i in tqdm(range(0, math.ceil(len(test_dset) / args.batch_size)), 'Generate'): + with torch.no_grad(): + z = torch.randn([args.batch_size, ckpt['args'].latent_dim]).to(args.device) + x = model.sample(z, args.sample_num_points, flexibility=ckpt['args'].flexibility) + gen_pcs.append(x.detach().cpu()) +gen_pcs = torch.cat(gen_pcs, dim=0)[:len(test_dset)] +if args.normalize is not None: + gen_pcs = normalize_point_clouds(gen_pcs, mode=args.normalize, logger=logger) + +# Save +logger.info('Saving point clouds...') +np.save(os.path.join(save_dir, 'out.npy'), gen_pcs.numpy()) + +# Compute metrics +with torch.no_grad(): + results = compute_all_metrics(gen_pcs.to(args.device), ref_pcs.to(args.device), args.batch_size, accelerated_cd=True) + results = {k:v.item() for k, v in results.items()} + jsd = jsd_between_point_cloud_sets(gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy()) + results['jsd'] = jsd + +for k, v in results.items(): + logger.info('%s: %.12f' % (k, v)) diff --git a/train_ae.py b/train_ae.py new file mode 100644 index 0000000..417786d --- /dev/null +++ b/train_ae.py @@ -0,0 +1,212 @@ +import os +import argparse +import torch +import torch.utils.tensorboard +from torch.nn.utils import clip_grad_norm_ +from tqdm.auto import tqdm + +from utils.dataset import * +from utils.misc import * +from utils.data import * +from utils.transform import * +from models.autoencoder import * +from evaluation import EMD_CD + + +# Arguments +parser = argparse.ArgumentParser() +# Model arguments +parser.add_argument('--latent_dim', type=int, default=256) +parser.add_argument('--num_steps', type=int, default=200) +parser.add_argument('--beta_1', type=float, default=1e-4) +parser.add_argument('--beta_T', type=float, default=0.05) +parser.add_argument('--sched_mode', type=str, default='linear') +parser.add_argument('--flexibility', 
type=float, default=0.0) +parser.add_argument('--residual', type=eval, default=True, choices=[True, False]) +parser.add_argument('--resume', type=str, default=None) + +# Datasets and loaders +parser.add_argument('--dataset_path', type=str, default='./data/shapenet.hdf5') +parser.add_argument('--categories', type=str_list, default=['airplane']) +parser.add_argument('--scale_mode', type=str, default='shape_unit') +parser.add_argument('--train_batch_size', type=int, default=128) +parser.add_argument('--val_batch_size', type=int, default=32) +parser.add_argument('--num_workers', type=int, default=1) +parser.add_argument('--rotate', type=eval, default=False, choices=[True, False]) + +# Optimizer and scheduler +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=0) +parser.add_argument('--max_grad_norm', type=float, default=10) +parser.add_argument('--end_lr', type=float, default=1e-4) +parser.add_argument('--sched_start_epoch', type=int, default=150*THOUSAND) +parser.add_argument('--sched_end_epoch', type=int, default=300*THOUSAND) + +# Training +parser.add_argument('--seed', type=int, default=2020) +parser.add_argument('--logging', type=eval, default=True, choices=[True, False]) +parser.add_argument('--log_root', type=str, default='./logs_ae') +parser.add_argument('--device', type=str, default='cuda') +parser.add_argument('--max_iters', type=int, default=float('inf')) +parser.add_argument('--val_freq', type=float, default=1000) +parser.add_argument('--tag', type=str, default=None) +parser.add_argument('--num_val_batches', type=int, default=-1) +parser.add_argument('--num_inspect_batches', type=int, default=1) +parser.add_argument('--num_inspect_pointclouds', type=int, default=4) +args = parser.parse_args() +seed_all(args.seed) + +# Logging +if args.logging: + log_dir = get_new_log_dir(args.log_root, prefix='AE_', postfix='_' + args.tag if args.tag is not None else '') + logger = get_logger('train', log_dir) + 
writer = torch.utils.tensorboard.SummaryWriter(log_dir) + ckpt_mgr = CheckpointManager(log_dir) +else: + logger = get_logger('train', None) + writer = BlackHole() + ckpt_mgr = BlackHole() +logger.info(args) + +# Datasets and loaders +transform = None +if args.rotate: + transform = RandomRotate(180, ['pointcloud'], axis=1) +logger.info('Transform: %s' % repr(transform)) +logger.info('Loading datasets...') +train_dset = ShapeNetCore( + path=args.dataset_path, + cates=args.categories, + split='train', + scale_mode=args.scale_mode, + transform=transform, +) +val_dset = ShapeNetCore( + path=args.dataset_path, + cates=args.categories, + split='val', + scale_mode=args.scale_mode, + transform=transform, +) +train_iter = get_data_iterator(DataLoader( + train_dset, + batch_size=args.train_batch_size, + num_workers=args.num_workers +)) +val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, num_workers=args.num_workers) + + +# Model +logger.info('Building model...') +if args.resume is not None: + logger.info('Resuming from checkpoint...') + ckpt = torch.load(args.resume) + model = AutoEncoder(ckpt['args']).to(args.device) + model.load_state_dict(ckpt['state_dict']) +else: + model = AutoEncoder(args).to(args.device) +logger.info(repr(model)) + + +# Optimizer and scheduler +optimizer = torch.optim.Adam(model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay +) +scheduler = get_linear_scheduler( + optimizer, + start_epoch=args.sched_start_epoch, + end_epoch=args.sched_end_epoch, + start_lr=args.lr, + end_lr=args.end_lr +) + +# Train, validate +def train(it): + # Load data + batch = next(train_iter) + x = batch['pointcloud'].to(args.device) + + # Reset grad and model state + optimizer.zero_grad() + model.train() + + # Forward + loss = model.get_loss(x) + + # Backward and optimize + loss.backward() + orig_grad_norm = clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() + scheduler.step() + + logger.info('[Train] Iter %04d | Loss %.6f | 
Grad %.4f ' % (it, loss.item(), orig_grad_norm)) + writer.add_scalar('train/loss', loss, it) + writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it) + writer.add_scalar('train/grad_norm', orig_grad_norm, it) + writer.flush() + +def validate_loss(it): + + all_refs = [] + all_recons = [] + for i, batch in enumerate(tqdm(val_loader, desc='Validate')): + if args.num_val_batches > 0 and i >= args.num_val_batches: + break + ref = batch['pointcloud'].to(args.device) + shift = batch['shift'].to(args.device) + scale = batch['scale'].to(args.device) + with torch.no_grad(): + model.eval() + code = model.encode(ref) + recons = model.decode(code, ref.size(1), flexibility=args.flexibility) + all_refs.append(ref * scale + shift) + all_recons.append(recons * scale + shift) + + all_refs = torch.cat(all_refs, dim=0) + all_recons = torch.cat(all_recons, dim=0) + metrics = EMD_CD(all_recons, all_refs, batch_size=args.val_batch_size, accelerated_cd=True) + cd, emd = metrics['MMD-CD'].item(), metrics['MMD-EMD'].item() + + logger.info('[Val] Iter %04d | CD %.6f | EMD %.6f ' % (it, cd, emd)) + writer.add_scalar('val/cd', cd, it) + writer.add_scalar('val/emd', emd, it) + writer.flush() + + return cd + +def validate_inspect(it): + sum_n = 0 + sum_chamfer = 0 + for i, batch in enumerate(tqdm(val_loader, desc='Inspect')): + x = batch['pointcloud'].to(args.device) + model.eval() + code = model.encode(x) + recons = model.decode(code, x.size(1), flexibility=args.flexibility).detach() + + sum_n += x.size(0) + if i >= args.num_inspect_batches: + break # Inspect only 5 batch + + writer.add_mesh('val/pointcloud', recons[:args.num_inspect_pointclouds], global_step=it) + writer.flush() + +# Main loop +logger.info('Start training...') +try: + it = 1 + while it <= args.max_iters: + train(it) + if it % args.val_freq == 0 or it == args.max_iters: + with torch.no_grad(): + cd_loss = validate_loss(it) + validate_inspect(it) + opt_states = { + 'optimizer': optimizer.state_dict(), + 'scheduler': 
scheduler.state_dict(), + } + ckpt_mgr.save(model, args, cd_loss, opt_states, step=it) + it += 1 + +except KeyboardInterrupt: + logger.info('Terminating...') diff --git a/train_gen.py b/train_gen.py new file mode 100644 index 0000000..38ad7f0 --- /dev/null +++ b/train_gen.py @@ -0,0 +1,222 @@ +import os +import math +import argparse +import torch +import torch.utils.tensorboard +from torch.utils.data import DataLoader +from torch.nn.utils import clip_grad_norm_ +from tqdm.auto import tqdm +import point_cloud_utils as pcu + +from utils.dataset import * +from utils.misc import * +from utils.data import * +from models.vae_gaussian import * +from models.vae_flow import * +from models.flow import add_spectral_norm, spectral_norm_power_iteration +from evaluation import * + + +# Arguments +parser = argparse.ArgumentParser() +# Model arguments +parser.add_argument('--model', type=str, default='flow', choices=['flow', 'gaussian']) +parser.add_argument('--latent_dim', type=int, default=256) +parser.add_argument('--num_steps', type=int, default=100) +parser.add_argument('--beta_1', type=float, default=1e-4) +parser.add_argument('--beta_T', type=float, default=0.02) +parser.add_argument('--sched_mode', type=str, default='linear') +parser.add_argument('--flexibility', type=float, default=0.0) +parser.add_argument('--truncate_std', type=float, default=2.0) +parser.add_argument('--latent_flow_depth', type=int, default=14) +parser.add_argument('--latent_flow_hidden_dim', type=int, default=256) +parser.add_argument('--num_samples', type=int, default=4) +parser.add_argument('--sample_num_points', type=int, default=2048) +parser.add_argument('--kl_weight', type=float, default=0.001) +parser.add_argument('--residual', type=eval, default=True, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) + +# Datasets and loaders +parser.add_argument('--dataset_path', type=str, default='./data/shapenet.hdf5') 
+parser.add_argument('--categories', type=str_list, default=['airplane']) +parser.add_argument('--scale_mode', type=str, default='shape_unit') +parser.add_argument('--train_batch_size', type=int, default=128) +parser.add_argument('--val_batch_size', type=int, default=64) +parser.add_argument('--num_workers', type=int, default=4) + +# Optimizer and scheduler +parser.add_argument('--lr', type=float, default=2e-3) +parser.add_argument('--weight_decay', type=float, default=0) +parser.add_argument('--max_grad_norm', type=float, default=10) +parser.add_argument('--end_lr', type=float, default=1e-4) +parser.add_argument('--sched_start_epoch', type=int, default=200*THOUSAND) +parser.add_argument('--sched_end_epoch', type=int, default=400*THOUSAND) + +# Training +parser.add_argument('--seed', type=int, default=2020) +parser.add_argument('--logging', type=eval, default=True, choices=[True, False]) +parser.add_argument('--log_root', type=str, default='./logs_gen') +parser.add_argument('--device', type=str, default='cuda') +parser.add_argument('--max_iters', type=int, default=float('inf')) +parser.add_argument('--val_freq', type=int, default=1000) +parser.add_argument('--test_freq', type=int, default=30*THOUSAND) +parser.add_argument('--test_size', type=int, default=400) +parser.add_argument('--tag', type=str, default=None) +args = parser.parse_args() +seed_all(args.seed) + +# Logging +if args.logging: + log_dir = get_new_log_dir(args.log_root, prefix='GEN_', postfix='_' + args.tag if args.tag is not None else '') + logger = get_logger('train', log_dir) + writer = torch.utils.tensorboard.SummaryWriter(log_dir) + ckpt_mgr = CheckpointManager(log_dir) + log_hyperparams(writer, args) +else: + logger = get_logger('train', None) + writer = BlackHole() + ckpt_mgr = BlackHole() +logger.info(args) + +# Datasets and loaders +logger.info('Loading datasets...') + +train_dset = ShapeNetCore( + path=args.dataset_path, + cates=args.categories, + split='train', + scale_mode=args.scale_mode, 
+) +val_dset = ShapeNetCore( + path=args.dataset_path, + cates=args.categories, + split='val', + scale_mode=args.scale_mode, +) +train_iter = get_data_iterator(DataLoader( + train_dset, + batch_size=args.train_batch_size, + num_workers=args.num_workers +)) + +# Model +logger.info('Building model...') +if args.model == 'gaussian': + model = GaussianVAE(args).to(args.device) +elif args.model == 'flow': + model = FlowVAE(args).to(args.device) +logger.info(repr(model)) +if args.spectral_norm: + add_spectral_norm(model, logger=logger) + +# Optimizer and scheduler +optimizer = torch.optim.Adam(model.parameters(), + lr=args.lr, + weight_decay=args.weight_decay +) +scheduler = get_linear_scheduler( + optimizer, + start_epoch=args.sched_start_epoch, + end_epoch=args.sched_end_epoch, + start_lr=args.lr, + end_lr=args.end_lr +) + +# Train, validate and test +def train(it): + # Load data + batch = next(train_iter) + x = batch['pointcloud'].to(args.device) + + # Reset grad and model state + optimizer.zero_grad() + model.train() + if args.spectral_norm: + spectral_norm_power_iteration(model, n_power_iterations=1) + + # Forward + kl_weight = args.kl_weight + loss = model.get_loss(x, kl_weight=kl_weight, writer=writer, it=it) + + # Backward and optimize + loss.backward() + orig_grad_norm = clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() + scheduler.step() + + logger.info('[Train] Iter %04d | Loss %.6f | Grad %.4f | KLWeight %.4f' % ( + it, loss.item(), orig_grad_norm, kl_weight + )) + writer.add_scalar('train/loss', loss, it) + writer.add_scalar('train/kl_weight', kl_weight, it) + writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it) + writer.add_scalar('train/grad_norm', orig_grad_norm, it) + writer.flush() + +def validate_inspect(it): + z = torch.randn([args.num_samples, args.latent_dim]).to(args.device) + x = model.sample(z, args.sample_num_points, flexibility=args.flexibility) #, truncate_std=args.truncate_std) + 
writer.add_mesh('val/pointcloud', x, global_step=it) + writer.flush() + logger.info('[Inspect] Generating samples...') + +def test(it): + ref_pcs = [] + for i, data in enumerate(val_dset): + if i >= args.test_size: + break + ref_pcs.append(data['pointcloud'].unsqueeze(0)) + ref_pcs = torch.cat(ref_pcs, dim=0) + + gen_pcs = [] + for i in tqdm(range(0, math.ceil(args.test_size / args.val_batch_size)), 'Generate'): + with torch.no_grad(): + z = torch.randn([args.val_batch_size, args.latent_dim]).to(args.device) + x = model.sample(z, args.sample_num_points, flexibility=args.flexibility) + gen_pcs.append(x.detach().cpu()) + gen_pcs = torch.cat(gen_pcs, dim=0)[:args.test_size] + + # Denormalize point clouds, all shapes have zero mean. + # [WARNING]: Do NOT denormalize! + # ref_pcs *= val_dset.stats['std'] + # gen_pcs *= val_dset.stats['std'] + + with torch.no_grad(): + results = compute_all_metrics(gen_pcs.to(args.device), ref_pcs.to(args.device), args.val_batch_size, accelerated_cd=True) + results = {k:v.item() for k, v in results.items()} + jsd = jsd_between_point_cloud_sets(gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy()) + results['jsd'] = jsd + + writer.add_scalar('test/Coverage_CD', results['lgan_cov-CD'], global_step=it) + writer.add_scalar('test/Coverage_EMD', results['lgan_cov-EMD'], global_step=it) + writer.add_scalar('test/MMD_CD', results['lgan_mmd-CD'], global_step=it) + writer.add_scalar('test/MMD_EMD', results['lgan_mmd-EMD'], global_step=it) + writer.add_scalar('test/1NN_CD', results['1-NN-CD-acc'], global_step=it) + writer.add_scalar('test/1NN_EMD', results['1-NN-EMD-acc'], global_step=it) + writer.add_scalar('test/JSD', results['jsd'], global_step=it) + + logger.info('[Test] Coverage | CD %.6f | EMD %.6f' % (results['lgan_cov-CD'], results['lgan_cov-EMD'])) + logger.info('[Test] MinMatDis | CD %.6f | EMD %.6f' % (results['lgan_mmd-CD'], results['lgan_mmd-EMD'])) + logger.info('[Test] 1NN-Accur | CD %.6f | EMD %.6f' % (results['1-NN-CD-acc'], 
results['1-NN-EMD-acc'])) + logger.info('[Test] JsnShnDis | %.6f ' % (results['jsd'])) + + +# Main loop +logger.info('Start training...') +try: + it = 1 + while it <= args.max_iters: + train(it) + if it % args.val_freq == 0 or it == args.max_iters: + validate_inspect(it) + opt_states = { + 'optimizer': optimizer.state_dict(), + 'scheduler': scheduler.state_dict(), + } + ckpt_mgr.save(model, args, 0, others=opt_states, step=it) + if it % args.test_freq == 0 or it == args.max_iters: + test(it) + it += 1 + +except KeyboardInterrupt: + logger.info('Terminating...') diff --git a/utils/data.py b/utils/data.py new file mode 100644 index 0000000..ad4aecf --- /dev/null +++ b/utils/data.py @@ -0,0 +1,34 @@ +import torch +from torch.utils.data import DataLoader, random_split + + +def get_train_val_test_datasets(dataset, train_ratio, val_ratio): + assert (train_ratio + val_ratio) <= 1 + train_size = int(len(dataset) * train_ratio) + val_size = int(len(dataset) * val_ratio) + test_size = len(dataset) - train_size - val_size + + train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size]) + return train_set, val_set, test_set + + +def get_train_val_test_loaders(dataset, train_ratio, val_ratio, train_batch_size, val_test_batch_size, num_workers): + train_set, val_set, test_set = get_train_val_test_datasets(dataset, train_ratio, val_ratio) + + train_loader = DataLoader(train_set, train_batch_size, shuffle=True, num_workers=num_workers) + val_loader = DataLoader(val_set, val_test_batch_size, shuffle=False, num_workers=num_workers) + test_loader = DataLoader(test_set, val_test_batch_size, shuffle=False, num_workers=num_workers) + + return train_loader, val_loader, test_loader + + +def get_data_iterator(iterable): + """Allows training with DataLoaders in a single infinite loop: + for i, data in enumerate(inf_generator(train_loader)): + """ + iterator = iterable.__iter__() + while True: + try: + yield iterator.__next__() + except StopIteration: + iterator = 
iterable.__iter__() diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000..814b996 --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,145 @@ +import os +import random +from copy import copy +import torch +from torch.utils.data import Dataset +import numpy as np +import h5py +from tqdm.auto import tqdm + + +synsetid_to_cate = { + '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', + '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', + '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', + '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', + '02954340': 'cap', '02958343': 'car', '03001627': 'chair', + '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', + '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', + '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', + '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', + '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', + '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', + '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', + '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', + '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', + '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', + '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', + '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', + '04554684': 'washer', '02992529': 'cellphone', + '02843684': 'birdhouse', '02871439': 'bookshelf', + # '02858304': 'boat', no boat in our dataset, merged into vessels + # '02834778': 'bicycle', not in our taxonomy +} +cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()} + + +class ShapeNetCore(Dataset): + + GRAVITATIONAL_AXIS = 1 + + def __init__(self, path, cates, split, scale_mode, transform=None): + super().__init__() + assert isinstance(cates, list), '`cates` must be a 
list of cate names.' + assert split in ('train', 'val', 'test') + assert scale_mode is None or scale_mode in ('global_unit', 'shape_unit', 'shape_bbox', 'shape_half', 'shape_34') + self.path = path + if 'all' in cates: + cates = cate_to_synsetid.keys() + self.cate_synsetids = [cate_to_synsetid[s] for s in cates] + self.cate_synsetids.sort() + self.split = split + self.scale_mode = scale_mode + self.transform = transform + + self.pointclouds = [] + self.stats = None + + self.get_statistics() + self.load() + + def get_statistics(self): + + basename = os.path.basename(self.path) + dsetname = basename[:basename.rfind('.')] + stats_dir = os.path.join(os.path.dirname(self.path), dsetname + '_stats') + os.makedirs(stats_dir, exist_ok=True) + + if len(self.cate_synsetids) == len(cate_to_synsetid): + stats_save_path = os.path.join(stats_dir, 'stats_all.pt') + else: + stats_save_path = os.path.join(stats_dir, 'stats_' + '_'.join(self.cate_synsetids) + '.pt') + if os.path.exists(stats_save_path): + self.stats = torch.load(stats_save_path) + return self.stats + + with h5py.File(self.path, 'r') as f: + pointclouds = [] + for synsetid in self.cate_synsetids: + for split in ('train', 'val', 'test'): + pointclouds.append(torch.from_numpy(f[synsetid][split][...])) + + all_points = torch.cat(pointclouds, dim=0) # (B, N, 3) + B, N, _ = all_points.size() + mean = all_points.view(B*N, -1).mean(dim=0) # (1, 3) + std = all_points.view(-1).std(dim=0) # (1, ) + + self.stats = {'mean': mean, 'std': std} + torch.save(self.stats, stats_save_path) + return self.stats + + def load(self): + + def _enumerate_pointclouds(f): + for synsetid in self.cate_synsetids: + cate_name = synsetid_to_cate[synsetid] + for j, pc in enumerate(f[synsetid][self.split]): + yield torch.from_numpy(pc), j, cate_name + + with h5py.File(self.path, mode='r') as f: + for pc, pc_id, cate_name in _enumerate_pointclouds(f): + + if self.scale_mode == 'global_unit': + shift = pc.mean(dim=0).reshape(1, 3) + scale = 
self.stats['std'].reshape(1, 1) + elif self.scale_mode == 'shape_unit': + shift = pc.mean(dim=0).reshape(1, 3) + scale = pc.flatten().std().reshape(1, 1) + elif self.scale_mode == 'shape_half': + shift = pc.mean(dim=0).reshape(1, 3) + scale = pc.flatten().std().reshape(1, 1) / (0.5) + elif self.scale_mode == 'shape_34': + shift = pc.mean(dim=0).reshape(1, 3) + scale = pc.flatten().std().reshape(1, 1) / (0.75) + elif self.scale_mode == 'shape_bbox': + pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3) + pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3) + shift = ((pc_min + pc_max) / 2).view(1, 3) + scale = (pc_max - pc_min).max().reshape(1, 1) / 2 + else: + shift = torch.zeros([1, 3]) + scale = torch.ones([1, 1]) + + pc = (pc - shift) / scale + + self.pointclouds.append({ + 'pointcloud': pc, + 'cate': cate_name, + 'id': pc_id, + 'shift': shift, + 'scale': scale + }) + + # Deterministically shuffle the dataset + self.pointclouds.sort(key=lambda data: data['id'], reverse=False) + random.Random(2020).shuffle(self.pointclouds) + + def __len__(self): + return len(self.pointclouds) + + def __getitem__(self, idx): + data = {k:v.clone() if isinstance(v, torch.Tensor) else copy(v) for k, v in self.pointclouds[idx].items()} + if self.transform is not None: + data = self.transform(data) + return data + diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..6e5468a --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,162 @@ +import os +import torch +import numpy as np +import random +import time +import logging +import logging.handlers + +THOUSAND = 1000 +MILLION = 1000000 + + +class BlackHole(object): + def __setattr__(self, name, value): + pass + def __call__(self, *args, **kwargs): + return self + def __getattr__(self, name): + return self + + +class CheckpointManager(object): + + def __init__(self, save_dir, logger=BlackHole()): + super().__init__() + os.makedirs(save_dir, exist_ok=True) + self.save_dir = save_dir + self.ckpts = [] + self.logger = logger + + 
class BlackHole(object):
    """No-op sink: silently absorbs attribute writes and calls, and returns
    itself for any attribute lookup. Used as a do-nothing default logger."""
    def __setattr__(self, name, value):
        pass
    def __call__(self, *args, **kwargs):
        return self
    def __getattr__(self, name):
        return self


class CheckpointManager(object):
    """Manages model checkpoints stored as ``ckpt_<score>_<iteration>.pt``.

    Lower score is treated as better (see get_best_ckpt_idx). Checkpoints
    already present in ``save_dir`` are indexed on construction.
    """

    def __init__(self, save_dir, logger=BlackHole()):
        super().__init__()
        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.ckpts = []
        self.logger = logger

        # Index checkpoints already on disk. Filenames look like
        # 'ckpt_<score>_<iteration>.pt' or 'ckpt_<score>_.pt' (no step).
        for f in os.listdir(self.save_dir):
            if f[:4] != 'ckpt':
                continue
            _, score, it = f.split('_')
            it = it.split('.')[0]
            self.ckpts.append({
                'score': float(score),
                'file': f,
                # BUGFIX: filenames written by save(step=None) have an empty
                # iteration field; int('') would raise ValueError here.
                'iteration': int(it) if it else -1,
            })

    def get_worst_ckpt_idx(self):
        """Return the index of the highest-score (worst) checkpoint, or None."""
        idx = -1
        worst = float('-inf')
        for i, ckpt in enumerate(self.ckpts):
            if ckpt['score'] >= worst:
                idx = i
                worst = ckpt['score']
        return idx if idx >= 0 else None

    def get_best_ckpt_idx(self):
        """Return the index of the lowest-score (best) checkpoint, or None."""
        idx = -1
        best = float('inf')
        for i, ckpt in enumerate(self.ckpts):
            if ckpt['score'] <= best:
                idx = i
                best = ckpt['score']
        return idx if idx >= 0 else None

    def get_latest_ckpt_idx(self):
        """Return the index of the checkpoint with the largest iteration, or None."""
        idx = -1
        latest_it = -1
        for i, ckpt in enumerate(self.ckpts):
            if ckpt['iteration'] > latest_it:
                idx = i
                latest_it = ckpt['iteration']
        return idx if idx >= 0 else None

    def save(self, model, args, score, others=None, step=None):
        """Serialize ``model`` (plus ``args`` and ``others``) to disk and
        register the checkpoint in ``self.ckpts``.

        Args:
            model: module whose ``state_dict()`` is saved.
            args: run configuration stored alongside the weights.
            score: validation score used for best/worst ranking.
            others: optional extra state (e.g. optimizer/scheduler dicts).
            step: optional training iteration number.
        Returns:
            True on success.
        """
        if step is None:
            fname = 'ckpt_%.6f_.pt' % float(score)
        else:
            fname = 'ckpt_%.6f_%d.pt' % (float(score), int(step))
        path = os.path.join(self.save_dir, fname)

        torch.save({
            'args': args,
            'state_dict': model.state_dict(),
            'others': others
        }, path)

        self.ckpts.append({
            'score': float(score),
            'file': fname,
            # BUGFIX: the original record omitted 'iteration', so
            # get_latest_ckpt_idx()/load_latest() raised KeyError for any
            # checkpoint saved during the current session.
            'iteration': int(step) if step is not None else -1,
        })

        return True

    def load_best(self):
        """Load the checkpoint with the lowest score. Raises IOError if none."""
        idx = self.get_best_ckpt_idx()
        if idx is None:
            raise IOError('No checkpoints found.')
        ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file']))
        return ckpt

    def load_latest(self):
        """Load the checkpoint with the largest iteration. Raises IOError if none."""
        idx = self.get_latest_ckpt_idx()
        if idx is None:
            raise IOError('No checkpoints found.')
        ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file']))
        return ckpt

    def load_selected(self, file):
        """Load an explicitly named checkpoint file from ``save_dir``."""
        ckpt = torch.load(os.path.join(self.save_dir, file))
        return ckpt


def seed_all(seed):
    """Seed torch, numpy and the stdlib RNG for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s') + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.DEBUG) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + + if log_dir is not None: + file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt')) + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +def get_new_log_dir(root='./logs', postfix='', prefix=''): + log_dir = os.path.join(root, prefix + time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime()) + postfix) + os.makedirs(log_dir) + return log_dir + + +def int_tuple(argstr): + return tuple(map(int, argstr.split(','))) + + +def str_tuple(argstr): + return tuple(argstr.split(',')) + + +def int_list(argstr): + return list(map(int, argstr.split(','))) + + +def str_list(argstr): + return list(argstr.split(',')) + + +def log_hyperparams(writer, args): + from torch.utils.tensorboard.summary import hparams + vars_args = {k:v if isinstance(v, str) else repr(v) for k, v in vars(args).items()} + exp, ssi, sei = hparams(vars_args, {}) + writer.file_writer.add_summary(exp) + writer.file_writer.add_summary(ssi) + writer.file_writer.add_summary(sei) diff --git a/utils/transform.py b/utils/transform.py new file mode 100644 index 0000000..b7d7ab2 --- /dev/null +++ b/utils/transform.py @@ -0,0 +1,292 @@ +import torch +import numpy as np +import math +import random +import numbers +import random +from itertools import repeat + + +class Center(object): + r"""Centers node positions around the origin.""" + + def __init__(self, attr): + self.attr = attr + + def __call__(self, data): + for key in self.attr: + data[key] = data[key] - data[key].mean(dim=-2, keepdim=True) + return data + + def __repr__(self): + return '{}()'.format(self.__class__.__name__) + + +class NormalizeScale(object): + r"""Centers and normalizes node positions to the interval :math:`(-1, 1)`. 
+ """ + + def __init__(self, attr): + self.center = Center(attr=attr) + self.attr = attr + + def __call__(self, data): + data = self.center(data) + + for key in self.attr: + scale = (1 / data[key].abs().max()) * 0.999999 + data[key] = data[key] * scale + + return data + + +class FixedPoints(object): + r"""Samples a fixed number of :obj:`num` points and features from a point + cloud. + Args: + num (int): The number of points to sample. + replace (bool, optional): If set to :obj:`False`, samples fixed + points without replacement. In case :obj:`num` is greater than + the number of points, duplicated points are kept to a + minimum. (default: :obj:`True`) + """ + + def __init__(self, num, replace=True): + self.num = num + self.replace = replace + # warnings.warn('FixedPoints is not deterministic') + + def __call__(self, data): + num_nodes = data['pos'].size(0) + data['dense'] = data['pos'] + + if self.replace: + choice = np.random.choice(num_nodes, self.num, replace=True) + else: + choice = torch.cat([ + torch.randperm(num_nodes) + for _ in range(math.ceil(self.num / num_nodes)) + ], dim=0)[:self.num] + + for key, item in data.items(): + if torch.is_tensor(item) and item.size(0) == num_nodes and key != 'dense': + data[key] = item[choice] + + return data + + def __repr__(self): + return '{}({}, replace={})'.format(self.__class__.__name__, self.num, + self.replace) + + +class LinearTransformation(object): + r"""Transforms node positions with a square transformation matrix computed + offline. + Args: + matrix (Tensor): tensor with shape :math:`[D, D]` where :math:`D` + corresponds to the dimensionality of node positions. + """ + + def __init__(self, matrix, attr): + assert matrix.dim() == 2, ( + 'Transformation matrix should be two-dimensional.') + assert matrix.size(0) == matrix.size(1), ( + 'Transformation matrix should be square. 
Got [{} x {}] rectangular' + 'matrix.'.format(*matrix.size())) + + self.matrix = matrix + self.attr = attr + + def __call__(self, data): + for key in self.attr: + pos = data[key].view(-1, 1) if data[key].dim() == 1 else data[key] + + assert pos.size(-1) == self.matrix.size(-2), ( + 'Node position matrix and transformation matrix have incompatible ' + 'shape.') + + data[key] = torch.matmul(pos, self.matrix.to(pos.dtype).to(pos.device)) + + return data + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, self.matrix.tolist()) + + +class RandomRotate(object): + r"""Rotates node positions around a specific axis by a randomly sampled + factor within a given interval. + Args: + degrees (tuple or float): Rotation interval from which the rotation + angle is sampled. If :obj:`degrees` is a number instead of a + tuple, the interval is given by :math:`[-\mathrm{degrees}, + \mathrm{degrees}]`. + axis (int, optional): The rotation axis. (default: :obj:`0`) + """ + + def __init__(self, degrees, attr, axis=0): + if isinstance(degrees, numbers.Number): + degrees = (-abs(degrees), abs(degrees)) + assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 + self.degrees = degrees + self.axis = axis + self.attr = attr + + def __call__(self, data): + degree = math.pi * random.uniform(*self.degrees) / 180.0 + sin, cos = math.sin(degree), math.cos(degree) + + if self.axis == 0: + matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]] + elif self.axis == 1: + matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]] + else: + matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]] + return LinearTransformation(torch.tensor(matrix), attr=self.attr)(data) + + def __repr__(self): + return '{}({}, axis={})'.format(self.__class__.__name__, self.degrees, + self.axis) + + +class AddNoise(object): + + def __init__(self, std=0.01, noiseless_item_key='clean'): + self.std = std + self.key = noiseless_item_key + + def __call__(self, data): + data[self.key] = data['pos'] + data['pos'] = 
data['pos'] + torch.normal(mean=0, std=self.std, size=data['pos'].size()) + return data + + +class AddRandomNoise(object): + + def __init__(self, std_range=[0, 0.10], noiseless_item_key='clean'): + self.std_range = std_range + self.key = noiseless_item_key + + def __call__(self, data): + noise_std = random.uniform(*self.std_range) + data[self.key] = data['pos'] + data['pos'] = data['pos'] + torch.normal(mean=0, std=noise_std, size=data['pos'].size()) + return data + + +class AddNoiseForEval(object): + + def __init__(self, stds=[0.0, 0.01, 0.02, 0.03, 0.05, 0.10, 0.15]): + self.stds = stds + self.keys = ['noisy_%.2f' % s for s in stds] + + def __call__(self, data): + data['clean'] = data['pos'] + for noise_std in self.stds: + data['noisy_%.2f' % noise_std] = data['pos'] + torch.normal(mean=0, std=noise_std, size=data['pos'].size()) + return data + + +class IdentityTransform(object): + + def __call__(self, data): + return data + + +class RandomScale(object): + r"""Scales node positions by a randomly sampled factor :math:`s` within a + given interval, *e.g.*, resulting in the transformation matrix + .. math:: + \begin{bmatrix} + s & 0 & 0 \\ + 0 & s & 0 \\ + 0 & 0 & s \\ + \end{bmatrix} + for three-dimensional positions. + Args: + scales (tuple): scaling factor interval, e.g. :obj:`(a, b)`, then scale + is randomly sampled from the range + :math:`a \leq \mathrm{scale} \leq b`. + """ + + def __init__(self, scales, attr): + assert isinstance(scales, (tuple, list)) and len(scales) == 2 + self.scales = scales + self.attr = attr + + def __call__(self, data): + scale = random.uniform(*self.scales) + for key in self.attr: + data[key] = data[key] * scale + return data + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, self.scales) + + +class RandomTranslate(object): + r"""Translates node positions by randomly sampled translation values + within a given interval. 
In contrast to other random transformations, + translation is applied separately at each position. + Args: + translate (sequence or float or int): Maximum translation in each + dimension, defining the range + :math:`(-\mathrm{translate}, +\mathrm{translate})` to sample from. + If :obj:`translate` is a number instead of a sequence, the same + range is used for each dimension. + """ + + def __init__(self, translate, attr): + self.translate = translate + self.attr = attr + + def __call__(self, data): + (n, dim), t = data['pos'].size(), self.translate + if isinstance(t, numbers.Number): + t = list(repeat(t, times=dim)) + assert len(t) == dim + + ts = [] + for d in range(dim): + ts.append(data['pos'].new_empty(n).uniform_(-abs(t[d]), abs(t[d]))) + + for key in self.attr: + data[key] = data[key] + torch.stack(ts, dim=-1) + + return data + + def __repr__(self): + return '{}({})'.format(self.__class__.__name__, self.translate) + + +class Rotate(object): + r"""Rotates node positions around a specific axis by a randomly sampled + factor within a given interval. + Args: + degrees (tuple or float): Rotation interval from which the rotation + angle is sampled. If :obj:`degrees` is a number instead of a + tuple, the interval is given by :math:`[-\mathrm{degrees}, + \mathrm{degrees}]`. + axis (int, optional): The rotation axis. (default: :obj:`0`) + """ + + def __init__(self, degree, attr, axis=0): + self.degree = degree + self.axis = axis + self.attr = attr + + def __call__(self, data): + degree = math.pi * self.degree / 180.0 + sin, cos = math.sin(degree), math.cos(degree) + + if self.axis == 0: + matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]] + elif self.axis == 1: + matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]] + else: + matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]] + return LinearTransformation(torch.tensor(matrix), attr=self.attr)(data) + + def __repr__(self): + return '{}({}, axis={})'.format(self.__class__.__name__, self.degrees, + self.axis)