Skip to content

Commit

Permalink
Add --val-render
Browse files Browse the repository at this point in the history
  • Loading branch information
pierotofy committed Mar 20, 2024
1 parent ce3dff8 commit d36a80b
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion opensplat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ int main(int argc, char *argv[]){
("s,save-every", "Save output scene every these many steps (set to -1 to disable)", cxxopts::value<int>()->default_value("-1"))
("val", "Withhold a camera shot for validating the scene loss")
("val-image", "Filename of the image to withhold for validating scene loss", cxxopts::value<std::string>()->default_value("random"))
("val-render", "Path of the directory where to render validation images", cxxopts::value<std::string>()->default_value(""))
("cpu", "Force CPU execution")

("n,num-iters", "Number of iterations to run", cxxopts::value<int>()->default_value("30000"))
Expand Down Expand Up @@ -57,8 +58,10 @@ int main(int argc, char *argv[]){
const std::string projectRoot = result["input"].as<std::string>();
const std::string outputScene = result["output"].as<std::string>();
const int saveEvery = result["save-every"].as<int>();
const bool validate = result.count("val") > 0;
const bool validate = result.count("val") > 0 || result.count("val-render") > 0;
const std::string valImage = result["val-image"].as<std::string>();
const std::string valRender = result["val-render"].as<std::string>();
if (!valRender.empty() && !fs::exists(valRender)) fs::create_directories(valRender);

const float downScaleFactor = (std::max)(result["downscale-factor"].as<float>(), 1.0f);
const int numIters = result["num-iters"].as<int>();
Expand All @@ -79,6 +82,7 @@ int main(int argc, char *argv[]){
torch::Device device = torch::kCPU;
int displayStep = 1;


if (torch::cuda::is_available() && result.count("cpu") == 0) {
std::cout << "Using CUDA" << std::endl;
device = torch::kCUDA;
Expand Down Expand Up @@ -132,6 +136,13 @@ int main(int argc, char *argv[]){
fs::path p(outputScene);
model.savePlySplat((p.replace_filename(fs::path(p.stem().string() + "_" + std::to_string(step) + p.extension().string())).string()));
}

if (!valRender.empty() && step % displayStep == 0){
torch::Tensor rgb = model.forward(*valCam, step);
cv::Mat image = tensorToImage(rgb.detach().cpu());
cv::cvtColor(image, image, cv::COLOR_RGB2BGR);
cv::imwrite((fs::path(valRender) / (std::to_string(step) + ".png")).string(), image);
}
}

model.savePlySplat(outputScene);
Expand Down

0 comments on commit d36a80b

Please sign in to comment.