Commit
Merge pull request #37 from salovision/main
General fixes that came up with Windows build
pierotofy authored Mar 12, 2024
2 parents e5952b9 + 2b4ada5 commit bb8dbd6
Showing 5 changed files with 29 additions and 37 deletions.
3 changes: 2 additions & 1 deletion cv_utils.cpp
@@ -29,7 +29,8 @@ cv::Mat tensorToImage(const torch::Tensor &t){
if (c != 3) throw std::runtime_error("Only images with 3 channels are supported");

cv::Mat image(h, w, type);
- uint8_t *dataPtr = static_cast<uint8_t *>((t * 255.0).toType(torch::kU8).data_ptr());
+ torch::Tensor scaledTensor = (t * 255.0).toType(torch::kU8);
+ uint8_t* dataPtr = static_cast<uint8_t*>(scaledTensor.data_ptr());
std::copy(dataPtr, dataPtr + (w * h * c), image.data);

return image;
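Why this matters: in the old one-liner the converted tensor `(t * 255.0).toType(torch::kU8)` was a temporary, so only the raw pointer survived past the end of that statement and `std::copy` on the next line could read freed storage. Binding the tensor to a name keeps its storage alive for the whole copy. A minimal sketch of the pattern (hypothetical helper name and an added `contiguous()` call, not the repository code):

```cpp
// Illustrative sketch of the lifetime fix (hypothetical helper, assumes a 3-channel HxWxC input).
#include <algorithm>
#include <cstdint>
#include <opencv2/core.hpp>
#include <torch/torch.h>

cv::Mat tensorToImageSketch(const torch::Tensor &t){
    const int h = static_cast<int>(t.size(0));
    const int w = static_cast<int>(t.size(1));
    const int c = static_cast<int>(t.size(2));
    cv::Mat image(h, w, CV_8UC3);

    // Bad: the converted tensor is a temporary; its storage can be freed at the end
    // of the statement, leaving the raw pointer dangling before std::copy runs:
    //   uint8_t *dataPtr = static_cast<uint8_t *>((t * 255.0).toType(torch::kU8).data_ptr());

    // Good: a named tensor owns the storage for as long as the pointer is used.
    torch::Tensor scaled = (t * 255.0).toType(torch::kU8).contiguous();
    const uint8_t *src = scaled.data_ptr<uint8_t>();
    std::copy(src, src + static_cast<size_t>(w) * h * c, image.data);
    return image;
}
```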
9 changes: 0 additions & 9 deletions input_data.cpp
@@ -111,15 +111,6 @@ std::vector<float> Camera::undistortionParameters(){
return p;
}

- void Camera::scaleOutputResolution(float scaleFactor){
- fx = fx * scaleFactor;
- fy = fy * scaleFactor;
- cx = cx * scaleFactor;
- cy = cy * scaleFactor;
- height = static_cast<int>(static_cast<float>(height) * scaleFactor);
- width = static_cast<int>(static_cast<float>(width) * scaleFactor);
- }

std::tuple<std::vector<Camera>, Camera *> InputData::getCameras(bool validate, const std::string &valImage){
if (!validate) return std::make_tuple(cameras, nullptr);
else{
1 change: 0 additions & 1 deletion input_data.hpp
@@ -37,7 +37,6 @@ struct Camera{
torch::Tensor getIntrinsicsMatrix();
bool hasDistortionParameters();
std::vector<float> undistortionParameters();
- void scaleOutputResolution(float scaleFactor);
torch::Tensor getImage(int downscaleFactor);

void loadImage(float downscaleFactor);
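For context, the removed method rescaled the shared `Camera` in place, so every caller had to undo the scaling on each exit path, and the integer truncation of width/height means a scale-down followed by a scale-up does not restore the original values exactly. The model.cpp diff below replaces this with per-call locals. A simplified contrast of the two patterns (hypothetical `Cam` struct, not the repository types):

```cpp
// Contrast sketch: the old pattern mutated shared state and had to restore it on
// every exit path; the new pattern derives per-call values and leaves the camera const.
struct Cam { float fx, fy; int width, height; };

void oldPattern(Cam &cam, float s){
    cam.fx /= s;                                   // mutate the shared camera...
    cam.width = static_cast<int>(cam.width / s);
    // ... render ...
    cam.fx *= s;                                   // ...and remember to undo it
    cam.width = static_cast<int>(cam.width * s);   // truncation can drift over many calls
}

void newPattern(const Cam &cam, float s){
    const float fx = cam.fx / s;                   // local, read-only
    const int width = static_cast<int>(cam.width / s);
    // ... render with fx / width ...
    (void)fx; (void)width;
}
```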
47 changes: 23 additions & 24 deletions model.cpp
@@ -44,8 +44,13 @@ torch::Tensor l1(const torch::Tensor& rendered, const torch::Tensor& gt){

torch::Tensor Model::forward(Camera& cam, int step){

- float scaleFactor = 1.0f / static_cast<float>(getDownscaleFactor(step));
- cam.scaleOutputResolution(scaleFactor);
+ const float scaleFactor = getDownscaleFactor(step);
+ const float fx = cam.fx / scaleFactor;
+ const float fy = cam.fy / scaleFactor;
+ const float cx = cam.cx / scaleFactor;
+ const float cy = cam.cy / scaleFactor;
+ const int height = static_cast<int>(static_cast<float>(cam.height) / scaleFactor);
+ const int width = static_cast<int>(static_cast<float>(cam.width) / scaleFactor);

// TODO: these can be moved to Camera and computed only once?
torch::Tensor R = cam.camToWorld.index({Slice(None, 3), Slice(None, 3)});
@@ -58,20 +63,20 @@ torch::Tensor Model::forward(Camera& cam, int step){
torch::Tensor Rinv = R.transpose(0, 1);
torch::Tensor Tinv = torch::matmul(-Rinv, T);

- lastHeight = cam.height;
- lastWidth = cam.width;
+ lastHeight = height;
+ lastWidth = width;

torch::Tensor viewMat = torch::eye(4, device);
viewMat.index_put_({Slice(None, 3), Slice(None, 3)}, Rinv);
viewMat.index_put_({Slice(None, 3), Slice(3, 4)}, Tinv);

- float fovX = 2.0f * std::atan(cam.width / (2.0f * cam.fx));
- float fovY = 2.0f * std::atan(cam.height / (2.0f * cam.fy));
+ float fovX = 2.0f * std::atan(width / (2.0f * fx));
+ float fovY = 2.0f * std::atan(height / (2.0f * fy));

torch::Tensor projMat = projectionMatrix(0.001f, 1000.0f, fovX, fovY, device);

- TileBounds tileBounds = std::make_tuple((cam.width + BLOCK_X - 1) / BLOCK_X,
- (cam.height + BLOCK_Y - 1) / BLOCK_Y,
+ TileBounds tileBounds = std::make_tuple((width + BLOCK_X - 1) / BLOCK_X,
+ (height + BLOCK_Y - 1) / BLOCK_Y,
1);

torch::Tensor colors = torch::cat({featuresDc.index({Slice(), None, Slice()}), featuresRest}, 1);
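A small aside on the field-of-view lines above: dividing resolution and focal length by the same factor leaves the FOV essentially unchanged (only the integer truncation of width/height introduces a tiny deviation), so rendering at a lower resolution keeps the same frustum. A quick standalone check with illustrative numbers (not repository code):

```cpp
// fovX = 2 * atan(width / (2 * fx)) is invariant when width and fx are scaled together.
#include <cassert>
#include <cmath>

int main(){
    const float fx = 1200.0f, width = 1920.0f, s = 4.0f;
    const float fovFull  = 2.0f * std::atan(width / (2.0f * fx));
    const float fovSmall = 2.0f * std::atan((width / s) / (2.0f * (fx / s)));
    assert(std::fabs(fovFull - fovSmall) < 1e-6f);  // identical up to float rounding
    return 0;
}
```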
@@ -82,12 +87,12 @@ torch::Tensor Model::forward(Camera& cam, int step){
quats / quats.norm(2, {-1}, true),
viewMat,
torch::matmul(projMat, viewMat),
- cam.fx,
- cam.fy,
- cam.cx,
- cam.cy,
- cam.height,
- cam.width,
+ fx,
+ fy,
+ cx,
+ cy,
+ height,
+ width,
tileBounds);
xys = p[0];
torch::Tensor depths = p[1];
@@ -96,11 +101,8 @@ torch::Tensor Model::forward(Camera& cam, int step){
torch::Tensor numTilesHit = p[4];


- if (radii.sum().item<float>() == 0.0f){
- // Rescale resolution back
- cam.scaleOutputResolution(1.0f / scaleFactor);
- return backgroundColor.repeat({cam.height, cam.width, 1});
- }
+ if (radii.sum().item<float>() == 0.0f)
+ return backgroundColor.repeat({height, width, 1});

// TODO: is this needed?
xys.retain_grad();
@@ -120,15 +122,12 @@ torch::Tensor Model::forward(Camera& cam, int step){
numTilesHit,
rgbs, // TODO: why not sigmod?
torch::sigmoid(opacities),
- cam.height,
- cam.width,
+ height,
+ width,
backgroundColor);

rgb = torch::clamp_max(rgb, 1.0f);

- // Rescale resolution back
- cam.scaleOutputResolution(1.0f / scaleFactor);

return rgb;
}

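Taken together, the model.cpp changes compute the downscaled intrinsics and resolution once at the top of `forward()` and use those locals everywhere (projection, tile bounds, rasterization, and the early-exit background), so no exit path has to remember to rescale the camera afterwards. A condensed sketch of the resulting flow, with the projection/rasterization calls stubbed out (hypothetical `Cam` struct and helper name, not the real gsplat bindings):

```cpp
// Condensed sketch of the new forward() flow (hypothetical types, stubbed rendering).
#include <cmath>
#include <torch/torch.h>

struct Cam { float fx, fy, cx, cy; int width, height; };

torch::Tensor forwardSketch(const Cam &cam, int downscaleFactor,
                            const torch::Tensor &backgroundColor){
    const float s = static_cast<float>(downscaleFactor);

    // Per-call locals instead of mutating the shared camera.
    const float fx = cam.fx / s, fy = cam.fy / s;
    const float cx = cam.cx / s, cy = cam.cy / s;
    const int height = static_cast<int>(cam.height / s);
    const int width  = static_cast<int>(cam.width / s);

    const float fovX = 2.0f * std::atan(width / (2.0f * fx));
    const float fovY = 2.0f * std::atan(height / (2.0f * fy));
    (void)cx; (void)cy; (void)fovX; (void)fovY;

    // ... project gaussians with (fx, fy, cx, cy, height, width) ...
    const bool nothingVisible = backgroundColor.numel() == 0;  // stand-in for radii.sum() == 0
    if (nothingVisible)
        return backgroundColor.repeat({height, width, 1});     // no restore step needed

    // ... rasterize at (height, width) and return the clamped image ...
    return torch::zeros({height, width, 3});
}
```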
6 changes: 4 additions & 2 deletions opensplat.cpp
@@ -104,11 +104,13 @@ int main(int argc, char *argv[]){
numIters,
device);

- InfiniteRandomIterator<Camera> camsIter(cams);
+ std::vector< size_t > camIndices( cams.size() );
+ std::iota( camIndices.begin(), camIndices.end(), 0 );
+ InfiniteRandomIterator<size_t> camsIter( camIndices );

int imageSize = -1;
for (size_t step = 1; step <= numIters; step++){
- Camera cam = camsIter.next();
+ Camera& cam = cams[ camsIter.next() ];

model.optimizersZeroGrad();

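The training-loop change in opensplat.cpp has the same flavor: the iterator now shuffles indices rather than storing `Camera` objects, and the loop takes a reference into the original `cams` vector, so each step works on the real camera instead of a per-iteration copy. A minimal sketch of the index-based pattern (hypothetical shuffling code, not the project's `InfiniteRandomIterator`):

```cpp
// Index-based sampling sketch: shuffle indices, keep the heavyweight objects in place.
#include <algorithm>
#include <numeric>
#include <random>
#include <string>
#include <vector>

struct Camera { std::string name; /* images, tensors, ... */ };

int main(){
    std::vector<Camera> cams = {{"a"}, {"b"}, {"c"}};

    // Shuffle indices instead of copying cameras around.
    std::vector<size_t> camIndices(cams.size());
    std::iota(camIndices.begin(), camIndices.end(), 0);
    std::mt19937 rng(42);
    std::shuffle(camIndices.begin(), camIndices.end(), rng);

    for (size_t step = 0; step < 6; step++){
        // Reference into the original vector: no per-step Camera copy.
        Camera &cam = cams[camIndices[step % camIndices.size()]];
        (void)cam; // ... train on cam ...
    }
    return 0;
}
```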
