diff --git a/README.md b/README.md
index b702d83..85e8af9 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@ A free and open source implementation of 3D [gaussian splatting](https://www.you
-OpenSplat takes camera poses + sparse points in [COLMAP](https://colmap.github.io/), [OpenSfM](https://github.com/mapillary/OpenSfM), [ODM](https://github.com/OpenDroneMap/ODM) or [nerfstudio](https://docs.nerf.studio/quickstart/custom_dataset.html) project format and computes a [scene file](https://drive.google.com/file/d/12lmvVWpFlFPL6nxl2e2d-4u4a31RCSKT/view?usp=sharing) (.ply) that can be later imported for [viewing](https://antimatter15.com/splat/?url=https://splat.uav4geo.com/banana.splat), editing and rendering in other [software](https://github.com/MrNeRF/awesome-3D-gaussian-splatting?tab=readme-ov-file#open-source-implementations).
+OpenSplat takes camera poses + sparse points in [COLMAP](https://colmap.github.io/), [OpenSfM](https://github.com/mapillary/OpenSfM), [ODM](https://github.com/OpenDroneMap/ODM) or [nerfstudio](https://docs.nerf.studio/quickstart/custom_dataset.html) project format and computes a [scene file](https://drive.google.com/file/d/12lmvVWpFlFPL6nxl2e2d-4u4a31RCSKT/view?usp=sharing) (.ply or .splat) that can be later imported for [viewing](https://antimatter15.com/splat/?url=https://splat.uav4geo.com/banana.splat), editing and rendering in other [software](https://github.com/MrNeRF/awesome-3D-gaussian-splatting?tab=readme-ov-file#open-source-implementations).
Graphics card recommended, but not required! OpenSplat runs the fastest on NVIDIA, AMD and Apple (Metal) GPUs, but can also run entirely on the CPU (~100x slower).
@@ -223,6 +223,16 @@ There's several parameters you can tune. To view the full list:
./opensplat --help
```
+### Compression
+
+To generate compressed splats (.splat files), use the `-o` option:
+
+```bash
+./opensplat /path/to/banana -o banana.splat
+```
+
+### AMD GPU Notes
+
To train a model with AMD GPU using docker container, you can use the following command as a reference:
1. Launch the docker container with the following command:
```bash
@@ -243,7 +253,6 @@ We recently released OpenSplat, so there's lots of work to do.
* Improve speed / reduce memory usage
* Distributed computation using multiple machines
* Real-time training viewer output
- * Compressed scene outputs
* Automatic filtering
* Your ideas?
diff --git a/model.cpp b/model.cpp
index 10a2784..1f20642 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1,3 +1,4 @@
+#include
#include "model.hpp"
#include "constants.hpp"
#include "tile_bounds.hpp"
@@ -12,6 +13,8 @@
#include
#endif
+namespace fs = std::filesystem;
+
torch::Tensor randomQuatTensor(long long n){
torch::Tensor u = torch::rand(n);
torch::Tensor v = torch::rand(n);
@@ -458,7 +461,16 @@ void Model::afterTrain(int step){
}
}
-void Model::savePlySplat(const std::string &filename){
+void Model::save(const std::string &filename){
+ if (fs::path(filename).extension().string() == ".splat"){
+ saveSplat(filename);
+ }else{
+ savePly(filename);
+ }
+ std::cout << "Wrote " << filename << std::endl;
+}
+
+void Model::savePly(const std::string &filename){
std::ofstream o(filename, std::ios::binary);
int numPoints = means.size(0);
@@ -515,7 +527,42 @@ void Model::savePlySplat(const std::string &filename){
}
o.close();
- std::cout << "Wrote " << filename << std::endl;
+}
+
+void Model::saveSplat(const std::string &filename){
+ std::ofstream o(filename, std::ios::binary);
+ int numPoints = means.size(0);
+
+ torch::Tensor meansCpu = keepCrs ? (means.cpu() / scale) + translation : means.cpu();
+ torch::Tensor scalesCpu = keepCrs ? (torch::exp(scales.cpu()) / scale) : torch::exp(scales.cpu());
+ torch::Tensor rgbsCpu = (sh2rgb(featuresDc.cpu()) * 255.0f).toType(torch::kUInt8);
+ torch::Tensor opac = (1.0f + torch::exp(-opacities.cpu()));
+ torch::Tensor opacitiesCpu = torch::clamp(((1.0f / opac) * 255.0f), 0.0f, 255.0f).toType(torch::kUInt8);
+ torch::Tensor quatsCpu = torch::clamp(quats.cpu() * 128.0f + 128.0f, 0.0f, 255.0f).toType(torch::kUInt8);
+
+ std::vector< size_t > splatIndices( numPoints );
+ std::iota( splatIndices.begin(), splatIndices.end(), 0 );
+ torch::Tensor order = (scalesCpu.index({"...", 0}) +
+ scalesCpu.index({"...", 1}) +
+ scalesCpu.index({"...", 2})) /
+ opac.index({"...", 0});
+ float *orderPtr = reinterpret_cast(order.data_ptr());
+
+ std::sort(splatIndices.begin(), splatIndices.end(),
+ [&orderPtr](size_t const &a, size_t const &b) {
+ return orderPtr[a] > orderPtr[b];
+ });
+
+ for (int i = 0; i < numPoints; i++){
+ size_t idx = splatIndices[i];
+
+ o.write(reinterpret_cast(meansCpu[idx].data_ptr()), sizeof(float) * 3);
+ o.write(reinterpret_cast(scalesCpu[idx].data_ptr()), sizeof(float) * 3);
+ o.write(reinterpret_cast(rgbsCpu[idx].data_ptr()), sizeof(uint8_t) * 3);
+ o.write(reinterpret_cast(opacitiesCpu[idx].data_ptr()), sizeof(uint8_t) * 1);
+ o.write(reinterpret_cast(quatsCpu[idx].data_ptr()), sizeof(uint8_t) * 4);
+ }
+ o.close();
}
void Model::saveDebugPly(const std::string &filename){
diff --git a/model.hpp b/model.hpp
index 17b8179..c54858a 100644
--- a/model.hpp
+++ b/model.hpp
@@ -81,7 +81,9 @@ struct Model{
void schedulersStep(int step);
int getDownscaleFactor(int step);
void afterTrain(int step);
- void savePlySplat(const std::string &filename);
+ void save(const std::string &filename);
+ void savePly(const std::string &filename);
+ void saveSplat(const std::string &filename);
void saveDebugPly(const std::string &filename);
torch::Tensor mainLoss(torch::Tensor &rgb, torch::Tensor >, float ssimWeight);
diff --git a/opensplat.cpp b/opensplat.cpp
index b820ecc..ccf4336 100644
--- a/opensplat.cpp
+++ b/opensplat.cpp
@@ -135,7 +135,7 @@ int main(int argc, char *argv[]){
if (saveEvery > 0 && step % saveEvery == 0){
fs::path p(outputScene);
- model.savePlySplat((p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string())).string()));
+ model.save((p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string())).string()));
}
if (!valRender.empty() && step % 10 == 0){
@@ -146,7 +146,7 @@ int main(int argc, char *argv[]){
}
}
- model.savePlySplat(outputScene);
+ model.save(outputScene);
// model.saveDebugPly("debug.ply");
// Validate
diff --git a/spherical_harmonics.cpp b/spherical_harmonics.cpp
index b64a467..53bee29 100644
--- a/spherical_harmonics.cpp
+++ b/spherical_harmonics.cpp
@@ -24,7 +24,7 @@ torch::Tensor rgb2sh(const torch::Tensor &rgb){
torch::Tensor sh2rgb(const torch::Tensor &sh){
// Converts from 0th spherical harmonic coefficients to RGB values [0,1]
- return (sh * C0) + 0.5;
+ return torch::clamp((sh * C0) + 0.5, 0.0f, 1.0f);
}
#if defined(USE_HIP) || defined(USE_CUDA) || defined(USE_MPS)