From 12e8169c50941557c6a3950658dece472356ebd1 Mon Sep 17 00:00:00 2001
From: Piero Toffanin
Date: Fri, 19 Apr 2024 15:47:39 -0400
Subject: [PATCH] Add .splat compression output support

---
 README.md               | 13 +++++++++--
 model.cpp               | 51 +++++++++++++++++++++++++++++++++++++++--
 model.hpp               |  4 +++-
 opensplat.cpp           |  4 ++--
 spherical_harmonics.cpp |  2 +-
 5 files changed, 66 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index b702d83..85e8af9 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@ A free and open source implementation of 3D [gaussian splatting](https://www.you
 
-OpenSplat takes camera poses + sparse points in [COLMAP](https://colmap.github.io/), [OpenSfM](https://github.com/mapillary/OpenSfM), [ODM](https://github.com/OpenDroneMap/ODM) or [nerfstudio](https://docs.nerf.studio/quickstart/custom_dataset.html) project format and computes a [scene file](https://drive.google.com/file/d/12lmvVWpFlFPL6nxl2e2d-4u4a31RCSKT/view?usp=sharing) (.ply) that can be later imported for [viewing](https://antimatter15.com/splat/?url=https://splat.uav4geo.com/banana.splat), editing and rendering in other [software](https://github.com/MrNeRF/awesome-3D-gaussian-splatting?tab=readme-ov-file#open-source-implementations).
+OpenSplat takes camera poses + sparse points in [COLMAP](https://colmap.github.io/), [OpenSfM](https://github.com/mapillary/OpenSfM), [ODM](https://github.com/OpenDroneMap/ODM) or [nerfstudio](https://docs.nerf.studio/quickstart/custom_dataset.html) project format and computes a [scene file](https://drive.google.com/file/d/12lmvVWpFlFPL6nxl2e2d-4u4a31RCSKT/view?usp=sharing) (.ply or .splat) that can be later imported for [viewing](https://antimatter15.com/splat/?url=https://splat.uav4geo.com/banana.splat), editing and rendering in other [software](https://github.com/MrNeRF/awesome-3D-gaussian-splatting?tab=readme-ov-file#open-source-implementations).
 
 Graphics card recommended, but not required! OpenSplat runs the fastest on NVIDIA, AMD and Apple (Metal) GPUs, but can also run entirely on the CPU (~100x slower).
 
@@ -223,6 +223,16 @@ There's several parameters you can tune. To view the full list:
 ./opensplat --help
 ```
 
+### Compression
+
+To generate compressed splats (.splat files), use the `-o` option:
+
+```bash
+./opensplat /path/to/banana -o banana.splat
+```
+
+### AMD GPU Notes
+
 To train a model with AMD GPU using docker container, you can use the following command as a reference:
 1. Launch the docker container with the following command:
 ```bash
@@ -243,7 +253,6 @@ We recently released OpenSplat, so there's lots of work to do.
 
  * Improve speed / reduce memory usage
  * Distributed computation using multiple machines
  * Real-time training viewer output
- * Compressed scene outputs
  * Automatic filtering
  * Your ideas?
diff --git a/model.cpp b/model.cpp
index 10a2784..1f20642 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1,3 +1,4 @@
+#include <filesystem>
 #include "model.hpp"
 #include "constants.hpp"
 #include "tile_bounds.hpp"
@@ -12,6 +13,8 @@
 #include
 #endif
 
+namespace fs = std::filesystem;
+
 torch::Tensor randomQuatTensor(long long n){
     torch::Tensor u = torch::rand(n);
     torch::Tensor v = torch::rand(n);
@@ -458,7 +461,16 @@ void Model::afterTrain(int step){
     }
 }
 
-void Model::savePlySplat(const std::string &filename){
+void Model::save(const std::string &filename){
+    if (fs::path(filename).extension().string() == ".splat"){
+        saveSplat(filename);
+    }else{
+        savePly(filename);
+    }
+    std::cout << "Wrote " << filename << std::endl;
+}
+
+void Model::savePly(const std::string &filename){
     std::ofstream o(filename, std::ios::binary);
     int numPoints = means.size(0);
 
@@ -515,7 +527,42 @@ void Model::savePlySplat(const std::string &filename){
     }
     o.close();
-    std::cout << "Wrote " << filename << std::endl;
+}
+
+void Model::saveSplat(const std::string &filename){
+    std::ofstream o(filename, std::ios::binary);
+    int numPoints = means.size(0);
+
+    torch::Tensor meansCpu = keepCrs ? (means.cpu() / scale) + translation : means.cpu();
+    torch::Tensor scalesCpu = keepCrs ? (torch::exp(scales.cpu()) / scale) : torch::exp(scales.cpu());
+    torch::Tensor rgbsCpu = (sh2rgb(featuresDc.cpu()) * 255.0f).toType(torch::kUInt8);
+    torch::Tensor opac = (1.0f + torch::exp(-opacities.cpu()));
+    torch::Tensor opacitiesCpu = torch::clamp(((1.0f / opac) * 255.0f), 0.0f, 255.0f).toType(torch::kUInt8);
+    torch::Tensor quatsCpu = torch::clamp(quats.cpu() * 128.0f + 128.0f, 0.0f, 255.0f).toType(torch::kUInt8);
+
+    std::vector< size_t > splatIndices( numPoints );
+    std::iota( splatIndices.begin(), splatIndices.end(), 0 );
+    torch::Tensor order = (scalesCpu.index({"...", 0}) +
+                           scalesCpu.index({"...", 1}) +
+                           scalesCpu.index({"...", 2})) /
+                          opac.index({"...", 0});
+    float *orderPtr = reinterpret_cast<float *>(order.data_ptr());
+
+    std::sort(splatIndices.begin(), splatIndices.end(),
+        [&orderPtr](size_t const &a, size_t const &b) {
+            return orderPtr[a] > orderPtr[b];
+        });
+
+    for (int i = 0; i < numPoints; i++){
+        size_t idx = splatIndices[i];
+
+        o.write(reinterpret_cast<const char *>(meansCpu[idx].data_ptr()), sizeof(float) * 3);
+        o.write(reinterpret_cast<const char *>(scalesCpu[idx].data_ptr()), sizeof(float) * 3);
+        o.write(reinterpret_cast<const char *>(rgbsCpu[idx].data_ptr()), sizeof(uint8_t) * 3);
+        o.write(reinterpret_cast<const char *>(opacitiesCpu[idx].data_ptr()), sizeof(uint8_t) * 1);
+        o.write(reinterpret_cast<const char *>(quatsCpu[idx].data_ptr()), sizeof(uint8_t) * 4);
+    }
+    o.close();
 }
 
 void Model::saveDebugPly(const std::string &filename){
diff --git a/model.hpp b/model.hpp
index 17b8179..c54858a 100644
--- a/model.hpp
+++ b/model.hpp
@@ -81,7 +81,9 @@ struct Model{
     void schedulersStep(int step);
     int getDownscaleFactor(int step);
     void afterTrain(int step);
-    void savePlySplat(const std::string &filename);
+    void save(const std::string &filename);
+    void savePly(const std::string &filename);
+    void saveSplat(const std::string &filename);
     void saveDebugPly(const std::string &filename);
     torch::Tensor mainLoss(torch::Tensor &rgb, torch::Tensor &gt, float ssimWeight);
 
diff --git a/opensplat.cpp b/opensplat.cpp
index b820ecc..ccf4336 100644
--- a/opensplat.cpp
+++ b/opensplat.cpp
@@ -135,7 +135,7 @@ int main(int argc, char *argv[]){
 
         if (saveEvery > 0 && step % saveEvery == 0){
            fs::path p(outputScene);
-            model.savePlySplat((p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string())).string()));
+            model.save((p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string())).string()));
         }
 
         if (!valRender.empty() && step % 10 == 0){
@@ -146,7 +146,7 @@
         }
     }
 
-    model.savePlySplat(outputScene);
+    model.save(outputScene);
     // model.saveDebugPly("debug.ply");
 
     // Validate
diff --git a/spherical_harmonics.cpp b/spherical_harmonics.cpp
index b64a467..53bee29 100644
--- a/spherical_harmonics.cpp
+++ b/spherical_harmonics.cpp
@@ -24,7 +24,7 @@ torch::Tensor rgb2sh(const torch::Tensor &rgb){
 
 torch::Tensor sh2rgb(const torch::Tensor &sh){
     // Converts from 0th spherical harmonic coefficients to RGB values [0,1]
-    return (sh * C0) + 0.5;
+    return torch::clamp((sh * C0) + 0.5, 0.0f, 1.0f);
 }
 
 #if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS)
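
For reference, Model::saveSplat above writes a headerless binary stream of fixed 32-byte records, one per gaussian: 3 float32 for position, 3 float32 for scale, 4 uint8 for RGBA (alpha = sigmoid(opacity) * 255) and 4 uint8 for the rotation quaternion (each component stored as c * 128 + 128). Records are sorted by descending size times opacity, which appears to match the .splat layout consumed by the antimatter15 viewer linked in the README. Below is a minimal reader sketch for that layout; the file name and the SplatRecord struct are illustrative assumptions, not part of OpenSplat or this patch.

// read_splat.cpp -- minimal sketch of a reader for the records written by
// Model::saveSplat. Struct and program names are assumptions for illustration.
#include <cstdint>
#include <fstream>
#include <iostream>
#include <vector>

#pragma pack(push, 1)
struct SplatRecord {
    float position[3];   // gaussian mean (x, y, z)
    float scale[3];      // per-axis scale, exp() of the trained log-scale
    uint8_t color[4];    // RGB from the 0th SH band, A = sigmoid(opacity) * 255
    uint8_t rotation[4]; // quaternion components quantized as c * 128 + 128
};
#pragma pack(pop)

int main(int argc, char *argv[]){
    if (argc < 2){
        std::cerr << "Usage: read_splat <file.splat>" << std::endl;
        return 1;
    }

    std::ifstream f(argv[1], std::ios::binary);
    std::vector<SplatRecord> records;
    SplatRecord r;

    // The file has no header: read consecutive 32-byte records until EOF.
    while (f.read(reinterpret_cast<char *>(&r), sizeof(SplatRecord))){
        records.push_back(r);
    }

    std::cout << "Read " << records.size() << " splats" << std::endl;
    return 0;
}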