-
Couldn't load subscription status.
- Fork 6
Merge DeltaFlow into codebase #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Kin-Zhang
wants to merge
35
commits into
KTH-RPL:main
Choose a base branch
from
Kin-Zhang:feature/deltaflow
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+2,193
−163
Open
Changes from all commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
b29f50b
cuda(histlib): cuda library from icp-flow project.
Kin-Zhang 4cf1168
docs: update README with icp-flow in the official implementation.
Kin-Zhang 21ca00c
docs: fix small typo, and starting updating model file.
Kin-Zhang 6302275
Merge branch 'main' into feature/icpflow
Kin-Zhang 345db96
feat(icp): core icp files, tested successfully.
Kin-Zhang bec8ddd
feat(deflowpp): update deflowpp model.
Kin-Zhang 9208a7e
feat(autolabel): update more autolabel through himo paper.
Kin-Zhang 222e065
feat(seflowpp): the full seflowpp process and train scripts.
Kin-Zhang aa642ef
fix(trainer): add seflowpp into trainer for folder name.
Kin-Zhang ff1df28
data(zod): update zod extraction scripts.
Kin-Zhang 39b6a18
!feat(lr): update optimizer to new structure.
Kin-Zhang 0ab8559
dcos(slurm): update slurm script, for reader to easy check how is the…
Kin-Zhang 3f6f18e
fix(env): fix some potiential env issue later
Kin-Zhang cf7b20a
fix(av2): instance label typo.
Kin-Zhang d0c7c59
fix(process): update key name in new version for seflow-variant process.
Kin-Zhang fbf0fe7
hotfix(eval): updating num_frames into eval.
Kin-Zhang 3c82d48
hotfix(eval/test): for history frames, we update keys' name and it ne…
Kin-Zhang 876cfd4
small fix on ssl_label to None if under supervise training.
Kin-Zhang f6744f1
feat(aug): merge data aug strategy from DeltaFlow project.
Kin-Zhang 6f18c75
docs(README): update readme.
Kin-Zhang 042c8f8
docs(README): update arxiv link
Kin-Zhang 5046a00
Merge branch 'main' into feature/icpflow
Kin-Zhang 5d89648
Merge remote-tracking branch 'kth-rpl/main'
Kin-Zhang 99b1103
feat(deltaflow): update deltaflow model file.
Kin-Zhang c0ebc92
conf: update deltaflow conf files
Kin-Zhang 6eb85cb
hotfix(eval): not assert but return if no gt class pt etc.
Kin-Zhang 477c580
Merge remote-tracking branch 'kth-rpl/main'
Kin-Zhang a888ee2
loss(deltaflow): add deltaflow loss.
Kin-Zhang 722ac2d
docs(README): update readme.
Kin-Zhang 0af9617
revert to OpenSceneFlow readme for the codebase.
Kin-Zhang e227dfc
docs: update README.
Kin-Zhang beeb97d
Merge branch 'main' into feature/deltaflow
Kin-Zhang 494b916
docs: update for diff opt methods.
Kin-Zhang af1dab0
hotfix(dataset): fix eval_mask in dataset.
Kin-Zhang 508ee10
docs(bib): add deltaflow bib back.
Kin-Zhang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| from torch import nn | ||
| from torch.autograd import Function | ||
| import torch | ||
| import importlib | ||
|
|
||
| import os, time | ||
| import hist | ||
|
|
||
| def histf(X, Y, min_x, min_y, min_z, max_x, max_y, max_z, len_x, len_y, len_z, mini_batch=8): | ||
| # print('hist cuda params: ', X.shape, Y.shape, | ||
| # min_x, min_y, min_z, | ||
| # max_x, max_y, max_z, | ||
| # len_x, len_y, len_z, | ||
| # ) | ||
| histogram = hist.hist(X.contiguous(), Y.contiguous(), | ||
| min_x, min_y, min_z, | ||
| max_x, max_y, max_z, | ||
| len_x, len_y, len_z, | ||
| mini_batch | ||
| ) | ||
| return histogram | ||
|
|
||
|
|
||
| torch.manual_seed(2022) | ||
|
|
||
| ######################## | ||
| def run_test(): | ||
| pts = torch.randn(3, 1000, 3) | ||
| indicators = torch.randint(0, 2, size=(3, 1000, 1)) | ||
| pts1 = torch.cat([pts, indicators], dim=-1) | ||
| pts2 = pts1.clone() | ||
| pts2[:, :,0] += 5. | ||
| pts2[:, :,1] += -3. | ||
| pts2[:, :,2] += -0.2 | ||
|
|
||
| range_x = 10. | ||
| range_y = 10. | ||
| range_z = 0.5 | ||
| thres =0.1 | ||
| # bins_x = torch.linspace(-range_x, range_x, int(2*range_x/thres)+1) | ||
| # bins_y = torch.linspace(-range_y, range_y, int(2*range_y/thres)+1) | ||
| # bins_z = torch.linspace(-range_z, range_z, int(2*range_z/thres)+1) | ||
| bins_x = torch.arange(-range_x, range_x+thres, thres) | ||
| bins_y = torch.arange(-range_y, range_y+thres, thres) | ||
| bins_z = torch.arange(-range_z, range_z+thres, thres) | ||
| print('bins_x: ', bins_x) | ||
| print('bins_z: ', bins_z) | ||
| pts1 = pts1.cuda() | ||
| pts2 = pts2.cuda() | ||
| bins_x = bins_x.cuda() | ||
| bins_y = bins_y.cuda() | ||
| bins_z = bins_z.cuda() | ||
|
|
||
| t_hists = histf(pts1, pts2, | ||
| -range_x, -range_y, -range_z, | ||
| range_x, range_y, range_z, | ||
| len(bins_x), len(bins_y), len(bins_z), | ||
| ) | ||
| print('output shape: ', t_hists.shape) | ||
| b, h, w, d = t_hists.shape | ||
| for t_hist in t_hists: | ||
| t_argmax = torch.argmax(t_hist) | ||
| print(f't_argmax: {t_argmax}, {t_hist.max()} {h}, {w}, {d}, {t_argmax//d//w%h}, {t_argmax//d%w}, {t_argmax%d}') | ||
| print('t_argmax', t_argmax//d//w%h, t_argmax//d%w, t_argmax%d, bins_x[t_argmax//d//w%h], bins_y[t_argmax//d%w], bins_z[t_argmax%d]) | ||
|
|
||
| if __name__ == '__main__': | ||
|
|
||
| print("Pytorch version: ", torch.__version__) | ||
| print("GPU version: ", torch.cuda.get_device_name()) | ||
|
|
||
| run_test() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| #include <vector> | ||
| #include "hist_cuda_core.cuh" | ||
|
|
||
| #include <ATen/ATen.h> | ||
| #include <ATen/cuda/CUDAContext.h> | ||
| #include <cuda.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| // #include <THC/THC.h> | ||
| // #include <THC/THCAtomics.cuh> | ||
| // #include <THC/THCDeviceUtils.cuh> | ||
|
|
||
| // extern THCState *state; | ||
|
|
||
| // author: Charles Shang | ||
| // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu | ||
|
|
||
|
|
||
| at::Tensor | ||
| hist_cuda(const at::Tensor &X, const at::Tensor &Y, | ||
| const float min_x, const float min_y, const float min_z, | ||
| const float max_x, const float max_y, const float max_z, | ||
| const int len_x, const int len_y, const int len_z, | ||
| const int mini_batch | ||
| ) | ||
| { | ||
| // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask)); | ||
|
|
||
| AT_ASSERTM(X.is_contiguous(), "input tensor has to be contiguous"); | ||
| AT_ASSERTM(Y.is_contiguous(), "input tensor has to be contiguous"); | ||
|
|
||
| AT_ASSERTM(X.type().is_cuda(), "input must be a CUDA tensor"); | ||
| AT_ASSERTM(Y.type().is_cuda(), "input must be a CUDA tensor"); | ||
|
|
||
| const int batch = X.size(0); | ||
| const int num_X = X.size(1); | ||
| const int dim = X.size(2); | ||
| const int num_Y = Y.size(1); | ||
|
|
||
| AT_ASSERTM((X.size(0) == Y.size(0)), "batch_X (%d) != batch_Y (%d).", X.size(0), Y.size(0)); | ||
| AT_ASSERTM((X.size(2) == Y.size(2)), "dim_X (%d) != dim_Y (%d).", X.size(2), Y.size(2)); | ||
|
|
||
| AT_ASSERTM((dim == 4), "dim (%d) != 4; 3 for (x, y, z); 1 for indicator,padded or not.", dim); | ||
|
|
||
| // printf("len: %d %d %f \n", len_x, len_y, len_z); | ||
| // printf("hist cuda coord: %f, %f, %f; %f, %f, %f; %f, %f, %f. \n", val_x, val_y, val_z, p_x, p_y, p_z, len_x, len_y, len_z); | ||
|
|
||
| // auto bins = at::zeros({batch, len_x, len_y, len_z}, X.options()); | ||
| // AT_DISPATCH_FLOATING_TYPES(X.type(), "hist_cuda_core", ([&] { | ||
| // hist_cuda_core(at::cuda::getCurrentCUDAStream(), | ||
| // X.data<scalar_t>(), Y.data<scalar_t>(), | ||
| // batch, dim, num_X, num_Y, | ||
| // min_x, min_y, min_z, | ||
| // max_x, max_y, max_z, | ||
| // len_x, len_y, len_z, | ||
| // bins.data<scalar_t>()); | ||
| // })); | ||
|
|
||
| auto bins = at::zeros({batch, len_x, len_y, len_z}, X.options()); | ||
|
|
||
| int iters = batch / mini_batch; | ||
| if (batch % mini_batch != 0) | ||
| { | ||
| iters += 1; | ||
| } | ||
|
|
||
| for (int i=0; i<iters; ++i) | ||
| { | ||
| int mini_batch_ = mini_batch; | ||
| if ((i+1) * mini_batch > batch) | ||
| { | ||
| mini_batch_ = batch - i * mini_batch; | ||
| } | ||
| // printf("iter: %d %d %d %d %d \n", i, iters, mini_batch_, mini_batch, batch); | ||
| AT_DISPATCH_FLOATING_TYPES(X.type(), "hist_cuda_core", ([&] { | ||
| hist_cuda_core(at::cuda::getCurrentCUDAStream(), | ||
| X.data<scalar_t>() + i*mini_batch*num_X*dim, | ||
| Y.data<scalar_t>() + i*mini_batch*num_Y*dim, | ||
| mini_batch_, dim, num_X, num_Y, | ||
| min_x, min_y, min_z, | ||
| max_x, max_y, max_z, | ||
| len_x, len_y, len_z, | ||
| bins.data<scalar_t>()+i*mini_batch*len_x*len_y*len_z); | ||
| })); | ||
| } | ||
|
|
||
|
|
||
|
|
||
| return bins; | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| #pragma once | ||
| #include <torch/extension.h> | ||
|
|
||
| at::Tensor | ||
| hist(const at::Tensor &X, const at::Tensor &Y, | ||
| const float min_x, const float min_y, const float min_z, | ||
| const float max_x, const float max_y, const float max_z, | ||
| const int len_x, const int len_y, const int len_z, | ||
| const int mini_batch | ||
| ); | ||
|
|
||
|
|
||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| #include "hist.h" | ||
| #include "hist_cuda.h" | ||
|
|
||
| at::Tensor | ||
| hist(const at::Tensor &X, const at::Tensor &Y, | ||
| const float min_x, const float min_y, const float min_z, | ||
| const float max_x, const float max_y, const float max_z, | ||
| const int len_x, const int len_y, const int len_z, | ||
| const int mini_batch | ||
| ) | ||
| { | ||
|
|
||
| if (X.type().is_cuda() && Y.type().is_cuda()) | ||
| { | ||
| return hist_cuda(X, Y, | ||
| min_x, min_y, min_z, | ||
| max_x, max_y, max_z, | ||
| len_x, len_y, len_z, | ||
| mini_batch | ||
| ); | ||
| } | ||
| AT_ERROR("Not implemented on the CPU"); | ||
| } | ||
|
|
||
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
| m.def("hist", &hist, "hist"); | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| #pragma once | ||
| #include <torch/extension.h> | ||
|
|
||
| at::Tensor | ||
| hist_cuda(const at::Tensor &X, const at::Tensor &Y, | ||
| const float min_x, const float min_y, const float min_z, | ||
| const float max_x, const float max_y, const float max_z, | ||
| const int len_x, const int len_y, const int len_z, | ||
| const int mini_batch | ||
| ); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| #include <cstdio> | ||
| #include <algorithm> | ||
| #include <cstring> | ||
|
|
||
| #include <ATen/ATen.h> | ||
| #include <ATen/cuda/CUDAContext.h> | ||
|
|
||
| // #include <THC/THC.h> | ||
| #include <THC/THCAtomics.cuh> | ||
| // #include <THC/THCDeviceUtils.cuh> | ||
|
|
||
| #define CUDA_KERNEL_LOOP(i, n) \ | ||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ | ||
| i < (n); \ | ||
| i += blockDim.x * gridDim.x) | ||
|
|
||
| const int CUDA_NUM_THREADS = 1024; | ||
| inline int GET_BLOCKS(const int N) | ||
| { | ||
| return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; | ||
| } | ||
|
|
||
| template <typename scalar_t> | ||
| __global__ void hist_cuda_kernel(const int n, | ||
| const scalar_t* X, | ||
| const scalar_t* Y, | ||
| const int batch, const int dim, | ||
| const int num_X, const int num_Y, | ||
| const float min_x, const float min_y, const float min_z, | ||
| const float max_x, const float max_y, const float max_z, | ||
| const int len_x, const int len_y, const int len_z, | ||
| scalar_t* bins | ||
| ) | ||
| { | ||
| CUDA_KERNEL_LOOP(index, n) | ||
| { | ||
| // index index of output matrix | ||
| // launch in parallel: batch * numX * numY; | ||
| // printf("hist cuda bin size: %d, %d, %d, %d. \n", batch, len_x, len_y, len_z); | ||
| const int b = index / num_X / num_Y % batch; | ||
| const int i = index / num_Y % num_X; | ||
| const int j = index % num_Y; | ||
|
|
||
| scalar_t flag_x = X[b*num_X*dim+i*dim+3]; | ||
| scalar_t flag_y = Y[b*num_Y*dim+j*dim+3]; | ||
| if (flag_x>0.0 && flag_y>0.0) | ||
| { | ||
| scalar_t val_x = X[b*num_X*dim+i*dim+0] - Y[b*num_Y*dim+j*dim+0]; | ||
| scalar_t val_y = X[b*num_X*dim+i*dim+1] - Y[b*num_Y*dim+j*dim+1]; | ||
| scalar_t val_z = X[b*num_X*dim+i*dim+2] - Y[b*num_Y*dim+j*dim+2]; | ||
| if (val_x >= min_x && val_x < max_x && val_y >= min_y && val_y < max_y && val_z >= min_z && val_z < max_z) | ||
| { | ||
| // [): left included; right excluded. | ||
| int p_x = __float2int_rd( (val_x-min_x) / (max_x-min_x) * __int2float_rd(len_x)); | ||
| int p_y = __float2int_rd( (val_y-min_y) / (max_y-min_y) * __int2float_rd(len_y)); | ||
| int p_z = __float2int_rd( (val_z-min_z) / (max_z-min_z) * __int2float_rd(len_z)); | ||
|
|
||
| // printf("hist cuda coord: %d, %d, %d, %d; %d, %d, %d, %d. \n", batch, len_x, len_y, len_z, b, p_x, p_y, p_z); | ||
| int bin_id = b*len_x*len_y*len_z + p_x*len_y*len_z + p_y*len_z + p_z; | ||
| atomicAdd(bins + bin_id, 1); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| template <typename scalar_t> | ||
| void hist_cuda_core(cudaStream_t stream, | ||
| const scalar_t* X, const scalar_t* Y, | ||
| const int batch, const int dim, | ||
| const int num_X, const int num_Y, | ||
| const float min_x, const float min_y, const float min_z, | ||
| const float max_x, const float max_y, const float max_z, | ||
| const int len_x, const int len_y, const int len_z, | ||
| scalar_t* bins | ||
| ) | ||
| { | ||
| const int num_kernels = batch * num_X * num_Y; | ||
| // printf("num kernels: %d\n", num_kernels); | ||
|
|
||
| // printf("hist cuda core: %f, %f, %f; %f, %f, %f; %f, %f, %f. \n", min_x, min_y, min_z, max_x, max_y, max_z, len_x, len_y, len_z); | ||
| // printf("hist cuda core: ", min_x, min_y, min_z, max_x, max_y, max_z, len_x, len_y, len_z, " \n"); | ||
| hist_cuda_kernel<scalar_t> | ||
| <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>( | ||
| num_kernels, | ||
| X, Y, | ||
| batch, dim, | ||
| num_X, num_Y, | ||
| min_x, min_y, min_z, | ||
| max_x, max_y, max_z, | ||
| len_x, len_y, len_z, | ||
| bins | ||
| ); | ||
|
|
||
| cudaError_t err = cudaGetLastError(); | ||
| if (err != cudaSuccess) | ||
| { | ||
| printf("error in hist_cuda_core: %s\n", cudaGetErrorString(err)); | ||
| } | ||
| } | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| from setuptools import setup | ||
| from torch.utils.cpp_extension import BuildExtension, CUDAExtension | ||
|
|
||
| setup( | ||
| name='hist', | ||
| ext_modules=[ | ||
| CUDAExtension('hist', [ | ||
| "/".join(__file__.split('/')[:-1] + ['hist_cuda.cpp']), # must named as xxx_cuda.cpp | ||
| "/".join(__file__.split('/')[:-1] + ['hist.cu']), | ||
| ]), | ||
| ], | ||
| cmdclass={ | ||
| 'build_ext': BuildExtension | ||
| }, | ||
| version='1.0.1') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update the correct weight ckpt link here.