diff --git a/cv_utils.cpp b/cv_utils.cpp
index 4a06cad..800453b 100644
--- a/cv_utils.cpp
+++ b/cv_utils.cpp
@@ -29,7 +29,8 @@ cv::Mat tensorToImage(const torch::Tensor &t){
     if (c != 3) throw std::runtime_error("Only images with 3 channels are supported");
 
     cv::Mat image(h, w, type);
-    uint8_t *dataPtr = static_cast<uint8_t *>((t * 255.0).toType(torch::kU8).data_ptr());
+    torch::Tensor scaledTensor = (t * 255.0).toType(torch::kU8);
+    uint8_t* dataPtr = static_cast<uint8_t*>(scaledTensor.data_ptr());
     std::copy(dataPtr, dataPtr + (w * h * c), image.data);
 
     return image;
diff --git a/input_data.cpp b/input_data.cpp
index bf8e045..e1c33e0 100644
--- a/input_data.cpp
+++ b/input_data.cpp
@@ -111,15 +111,6 @@ std::vector<float> Camera::undistortionParameters(){
     return p;
 }
 
-void Camera::scaleOutputResolution(float scaleFactor){
-    fx = fx * scaleFactor;
-    fy = fy * scaleFactor;
-    cx = cx * scaleFactor;
-    cy = cy * scaleFactor;
-    height = static_cast<int>(static_cast<float>(height) * scaleFactor);
-    width = static_cast<int>(static_cast<float>(width) * scaleFactor);
-}
-
 std::tuple<std::vector<Camera>, Camera *> InputData::getCameras(bool validate, const std::string &valImage){
     if (!validate) return std::make_tuple(cameras, nullptr);
     else{
diff --git a/input_data.hpp b/input_data.hpp
index 2505c39..ac2fc74 100644
--- a/input_data.hpp
+++ b/input_data.hpp
@@ -37,7 +37,6 @@ struct Camera{
     torch::Tensor getIntrinsicsMatrix();
     bool hasDistortionParameters();
     std::vector<float> undistortionParameters();
-    void scaleOutputResolution(float scaleFactor);
 
     torch::Tensor getImage(int downscaleFactor);
     void loadImage(float downscaleFactor);
diff --git a/model.cpp b/model.cpp
index 915c334..e8a7c58 100644
--- a/model.cpp
+++ b/model.cpp
@@ -44,8 +44,13 @@ torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt){
 
 torch::Tensor Model::forward(Camera& cam, int step){
 
-    float scaleFactor = 1.0f / static_cast<float>(getDownscaleFactor(step));
-    cam.scaleOutputResolution(scaleFactor);
+    const float scaleFactor = getDownscaleFactor(step);
+    const float fx = cam.fx / scaleFactor;
+    const float fy = cam.fy / scaleFactor;
+    const float cx = cam.cx / scaleFactor;
+    const float cy = cam.cy / scaleFactor;
+    const int height = static_cast<int>(static_cast<float>(cam.height) / scaleFactor);
+    const int width = static_cast<int>(static_cast<float>(cam.width) / scaleFactor);
 
     // TODO: these can be moved to Camera and computed only once?
     torch::Tensor R = cam.camToWorld.index({Slice(None, 3), Slice(None, 3)});
@@ -58,20 +63,20 @@ torch::Tensor Model::forward(Camera& cam, int step){
     torch::Tensor Rinv = R.transpose(0, 1);
     torch::Tensor Tinv = torch::matmul(-Rinv, T);
 
-    lastHeight = cam.height;
-    lastWidth = cam.width;
+    lastHeight = height;
+    lastWidth = width;
 
     torch::Tensor viewMat = torch::eye(4, device);
     viewMat.index_put_({Slice(None, 3), Slice(None, 3)}, Rinv);
     viewMat.index_put_({Slice(None, 3), Slice(3, 4)}, Tinv);
-    float fovX = 2.0f * std::atan(cam.width / (2.0f * cam.fx));
-    float fovY = 2.0f * std::atan(cam.height / (2.0f * cam.fy));
+    float fovX = 2.0f * std::atan(width / (2.0f * fx));
+    float fovY = 2.0f * std::atan(height / (2.0f * fy));
 
     torch::Tensor projMat = projectionMatrix(0.001f, 1000.0f, fovX, fovY, device);
 
-    TileBounds tileBounds = std::make_tuple((cam.width + BLOCK_X - 1) / BLOCK_X,
-                                            (cam.height + BLOCK_Y - 1) / BLOCK_Y,
+    TileBounds tileBounds = std::make_tuple((width + BLOCK_X - 1) / BLOCK_X,
+                                            (height + BLOCK_Y - 1) / BLOCK_Y,
                                             1);
 
     torch::Tensor colors = torch::cat({featuresDc.index({Slice(), None, Slice()}), featuresRest}, 1);
 
@@ -82,12 +87,12 @@ torch::Tensor Model::forward(Camera& cam, int step){
             quats / quats.norm(2, {-1}, true),
             viewMat,
             torch::matmul(projMat, viewMat),
-            cam.fx,
-            cam.fy,
-            cam.cx,
-            cam.cy,
-            cam.height,
-            cam.width,
+            fx,
+            fy,
+            cx,
+            cy,
+            height,
+            width,
             tileBounds);
     xys = p[0];
     torch::Tensor depths = p[1];
@@ -96,11 +101,8 @@ torch::Tensor Model::forward(Camera& cam, int step){
     torch::Tensor numTilesHit = p[4];
 
-    if (radii.sum().item<float>() == 0.0f){
-        // Rescale resolution back
-        cam.scaleOutputResolution(1.0f / scaleFactor);
-        return backgroundColor.repeat({cam.height, cam.width, 1});
-    }
+    if (radii.sum().item<float>() == 0.0f)
+        return backgroundColor.repeat({height, width, 1});
 
     // TODO: is this needed?
     xys.retain_grad();
 
@@ -120,15 +122,12 @@ torch::Tensor Model::forward(Camera& cam, int step){
             numTilesHit,
             rgbs, // TODO: why not sigmod?
             torch::sigmoid(opacities),
-            cam.height,
-            cam.width,
+            height,
+            width,
             backgroundColor);
 
     rgb = torch::clamp_max(rgb, 1.0f);
 
-    // Rescale resolution back
-    cam.scaleOutputResolution(1.0f / scaleFactor);
-
     return rgb;
 }
 
diff --git a/opensplat.cpp b/opensplat.cpp
index 234038a..9dc7daa 100644
--- a/opensplat.cpp
+++ b/opensplat.cpp
@@ -104,11 +104,13 @@ int main(int argc, char *argv[]){
                 numIters, device);
 
-    InfiniteRandomIterator<Camera> camsIter(cams);
+    std::vector< size_t > camIndices( cams.size() );
+    std::iota( camIndices.begin(), camIndices.end(), 0 );
+    InfiniteRandomIterator<size_t> camsIter( camIndices );
 
     int imageSize = -1;
 
     for (size_t step = 1; step <= numIters; step++){
-        Camera cam = camsIter.next();
+        Camera& cam = cams[ camsIter.next() ];
 
         model.optimizersZeroGrad();
 