Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 61 additions & 89 deletions src/fvdb/GaussianSplat3d.cpp

Large diffs are not rendered by default.

55 changes: 22 additions & 33 deletions src/fvdb/detail/autograd/GaussianRasterize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
#include <fvdb/detail/utils/Nvtx.h>
#include <fvdb/detail/utils/Utils.h>

#include <nanovdb/math/Math.h>

namespace fvdb::detail::autograd {

RasterizeGaussiansToPixels::VariableList
Expand All @@ -18,20 +16,18 @@ RasterizeGaussiansToPixels::forward(
const RasterizeGaussiansToPixels::Variable &conics, // [C, N, 3]
const RasterizeGaussiansToPixels::Variable &colors, // [C, N, 3]
const RasterizeGaussiansToPixels::Variable &opacities, // [N]
const uint32_t imageWidth,
const uint32_t imageHeight,
const uint32_t imageOriginW,
const uint32_t imageOriginH,
const uint32_t tileSize,
const RasterizeGaussiansToPixels::Variable &tileOffsets, // [C, tile_height, tile_width]
const RasterizeGaussiansToPixels::Variable &tileGaussianIds, // [n_isects]
const ops::RenderWindow2D &renderWindow,
const ops::DenseTileIntersections &tileIntersections,
const bool absgrad,
std::optional<RasterizeGaussiansToPixels::Variable> backgrounds,
std::optional<RasterizeGaussiansToPixels::Variable> masks) {
FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixels::forward");

const auto &tileOffsets = tileIntersections.tileOffsets();
const auto &tileGaussianIds = tileIntersections.tileGaussianIds();
const uint32_t tileSize = tileIntersections.tileSize();

auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() {
const ops::RenderWindow2D renderWindow{imageWidth, imageHeight, imageOriginW, imageOriginH};
return ops::dispatchGaussianRasterizeForward<DeviceTag>(means2d,
conics,
colors,
Expand Down Expand Up @@ -63,11 +59,11 @@ RasterizeGaussiansToPixels::forward(
}
ctx->save_for_backward(toSave);

ctx->saved_data["imageWidth"] = (int64_t)imageWidth;
ctx->saved_data["imageHeight"] = (int64_t)imageHeight;
ctx->saved_data["imageWidth"] = (int64_t)renderWindow.width;
ctx->saved_data["imageHeight"] = (int64_t)renderWindow.height;
ctx->saved_data["tileSize"] = (int64_t)tileSize;
ctx->saved_data["imageOriginW"] = (int64_t)imageOriginW;
ctx->saved_data["imageOriginH"] = (int64_t)imageOriginH;
ctx->saved_data["imageOriginW"] = (int64_t)renderWindow.originW;
ctx->saved_data["imageOriginH"] = (int64_t)renderWindow.originH;
ctx->saved_data["absgrad"] = absgrad;

return {renderedColors, renderedAlphas};
Expand All @@ -80,7 +76,6 @@ RasterizeGaussiansToPixels::backward(RasterizeGaussiansToPixels::AutogradContext
Variable dLossDRenderedColors = gradOutput.at(0);
Variable dLossDRenderedAlphas = gradOutput.at(1);

// ensure the gradients are contiguous if they are not None
if (dLossDRenderedColors.defined()) {
dLossDRenderedColors = dLossDRenderedColors.contiguous();
}
Expand Down Expand Up @@ -110,18 +105,15 @@ RasterizeGaussiansToPixels::backward(RasterizeGaussiansToPixels::AutogradContext
masks = saved.at(optIdx++);
}

const int imageWidth = (int)ctx->saved_data["imageWidth"].toInt();
const int imageHeight = (int)ctx->saved_data["imageHeight"].toInt();
const int tileSize = (int)ctx->saved_data["tileSize"].toInt();
const int imageOriginW = (int)ctx->saved_data["imageOriginW"].toInt();
const int imageOriginH = (int)ctx->saved_data["imageOriginH"].toInt();
const bool absgrad = ctx->saved_data["absgrad"].toBool();
const ops::RenderWindow2D renderWindow{
static_cast<uint32_t>(ctx->saved_data["imageWidth"].toInt()),
static_cast<uint32_t>(ctx->saved_data["imageHeight"].toInt()),
static_cast<uint32_t>(ctx->saved_data["imageOriginW"].toInt()),
static_cast<uint32_t>(ctx->saved_data["imageOriginH"].toInt())};
const int tileSize = (int)ctx->saved_data["tileSize"].toInt();
const bool absgrad = ctx->saved_data["absgrad"].toBool();

auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() {
const ops::RenderWindow2D renderWindow{static_cast<uint32_t>(imageWidth),
static_cast<uint32_t>(imageHeight),
static_cast<uint32_t>(imageOriginW),
static_cast<uint32_t>(imageOriginH)};
return ops::dispatchGaussianRasterizeBackward<DeviceTag>(means2d,
conics,
colors,
Expand Down Expand Up @@ -150,19 +142,16 @@ RasterizeGaussiansToPixels::backward(RasterizeGaussiansToPixels::AutogradContext
Variable dLossDColors = std::get<3>(variables);
Variable dLossDOpacities = std::get<4>(variables);

// 9 forward params (excluding ctx): means2d, conics, colors, opacities,
// renderWindow, tileIntersections, absgrad, backgrounds, masks
return {
dLossDMeans2d,
dLossDConics,
dLossDColors,
dLossDOpacities,
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(), // renderWindow
Variable(), // tileIntersections
Variable(), // absgrad
Variable(), // backgrounds
Variable(), // masks
};
Expand Down
11 changes: 4 additions & 7 deletions src/fvdb/detail/autograd/GaussianRasterize.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#ifndef FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZE_H
#define FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZE_H

#include <fvdb/detail/ops/gsplat/GaussianTileIntersection.h>

#include <torch/autograd.h>

namespace fvdb::detail::autograd {
Expand All @@ -18,13 +20,8 @@ struct RasterizeGaussiansToPixels : public torch::autograd::Function<RasterizeGa
const Variable &conics, // [C, N, 3]
const Variable &colors, // [C, N, 3]
const Variable &opacities, // [N]
const uint32_t imageWidth,
const uint32_t imageHeight,
const uint32_t imageOriginW,
const uint32_t imageOriginH,
const uint32_t tileSize,
const Variable &tileOffsets, // [C, tile_height, tile_width]
const Variable &tileGaussianIds, // [n_isects]
const ops::RenderWindow2D &renderWindow,
const ops::DenseTileIntersections &tileIntersections,
const bool absgrad,
std::optional<Variable> backgrounds = std::nullopt, // [C, D]
std::optional<Variable> masks = std::nullopt); // [C, tileH, tileW]
Expand Down
68 changes: 40 additions & 28 deletions src/fvdb/detail/autograd/GaussianRasterizeFromWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,21 @@ RasterizeGaussiansToPixelsFromWorld3DGS::forward(
const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &distortionCoeffs,
const fvdb::detail::ops::RollingShutterType rollingShutterType,
const fvdb::detail::ops::DistortionModel distortionModel,
const uint32_t imageWidth,
const uint32_t imageHeight,
const uint32_t imageOriginW,
const uint32_t imageOriginH,
const uint32_t tileSize,
const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &tileOffsets,
const RasterizeGaussiansToPixelsFromWorld3DGS::Variable &tileGaussianIds,
const ops::RenderWindow2D &renderWindow,
const ops::DenseTileIntersections &tileIntersections,
std::optional<RasterizeGaussiansToPixelsFromWorld3DGS::Variable> backgrounds,
std::optional<RasterizeGaussiansToPixelsFromWorld3DGS::Variable> masks) {
FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixelsFromWorld3DGS::forward");

const auto &tileOffsets = tileIntersections.tileOffsets();
const auto &tileGaussianIds = tileIntersections.tileGaussianIds();
const uint32_t tileSize = tileIntersections.tileSize();

fvdb::detail::ops::RenderSettings settings;
settings.imageWidth = imageWidth;
settings.imageHeight = imageHeight;
settings.imageOriginW = imageOriginW;
settings.imageOriginH = imageOriginH;
settings.imageWidth = renderWindow.width;
settings.imageHeight = renderWindow.height;
settings.imageOriginW = renderWindow.originW;
settings.imageOriginH = renderWindow.originH;
settings.tileSize = tileSize;

auto outputs = FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() {
Expand Down Expand Up @@ -95,10 +94,10 @@ RasterizeGaussiansToPixelsFromWorld3DGS::forward(
}
ctx->save_for_backward(toSave);

ctx->saved_data["imageWidth"] = (int64_t)imageWidth;
ctx->saved_data["imageHeight"] = (int64_t)imageHeight;
ctx->saved_data["imageOriginW"] = (int64_t)imageOriginW;
ctx->saved_data["imageOriginH"] = (int64_t)imageOriginH;
ctx->saved_data["imageWidth"] = (int64_t)renderWindow.width;
ctx->saved_data["imageHeight"] = (int64_t)renderWindow.height;
ctx->saved_data["imageOriginW"] = (int64_t)renderWindow.originW;
ctx->saved_data["imageOriginH"] = (int64_t)renderWindow.originH;
ctx->saved_data["tileSize"] = (int64_t)tileSize;
ctx->saved_data["distortionModel"] = (int64_t)distortionModel;
ctx->saved_data["rollingShutterType"] = (int64_t)rollingShutterType;
Expand Down Expand Up @@ -148,21 +147,22 @@ RasterizeGaussiansToPixelsFromWorld3DGS::backward(
masks = saved.at(optIdx++);
}

const uint32_t imageWidth = (uint32_t)ctx->saved_data["imageWidth"].toInt();
const uint32_t imageHeight = (uint32_t)ctx->saved_data["imageHeight"].toInt();
const uint32_t imageOriginW = (uint32_t)ctx->saved_data["imageOriginW"].toInt();
const uint32_t imageOriginH = (uint32_t)ctx->saved_data["imageOriginH"].toInt();
const uint32_t tileSize = (uint32_t)ctx->saved_data["tileSize"].toInt();
const ops::RenderWindow2D renderWindow{
static_cast<uint32_t>(ctx->saved_data["imageWidth"].toInt()),
static_cast<uint32_t>(ctx->saved_data["imageHeight"].toInt()),
static_cast<uint32_t>(ctx->saved_data["imageOriginW"].toInt()),
static_cast<uint32_t>(ctx->saved_data["imageOriginH"].toInt())};
const uint32_t tileSize = (uint32_t)ctx->saved_data["tileSize"].toInt();
const auto distortionModel =
static_cast<fvdb::detail::ops::DistortionModel>(ctx->saved_data["distortionModel"].toInt());
const auto rollingShutterType = static_cast<fvdb::detail::ops::RollingShutterType>(
ctx->saved_data["rollingShutterType"].toInt());

fvdb::detail::ops::RenderSettings settings;
settings.imageWidth = imageWidth;
settings.imageHeight = imageHeight;
settings.imageOriginW = imageOriginW;
settings.imageOriginH = imageOriginH;
settings.imageWidth = renderWindow.width;
settings.imageHeight = renderWindow.height;
settings.imageOriginW = renderWindow.originW;
settings.imageOriginH = renderWindow.originH;
settings.tileSize = tileSize;

auto grads = FVDB_DISPATCH_KERNEL_DEVICE(means.device(), [&]() {
Expand Down Expand Up @@ -195,10 +195,22 @@ RasterizeGaussiansToPixelsFromWorld3DGS::backward(
Variable dFeatures = std::get<3>(grads);
Variable dOpacities = std::get<4>(grads);

// Return gradients in the same order as forward inputs.
return {dMeans, dQuats, dLogScales, dFeatures, dOpacities, Variable(), Variable(),
Variable(), Variable(), Variable(), Variable(), Variable(), Variable(), Variable(),
Variable(), Variable(), Variable(), Variable(), Variable(), Variable()};
// Return gradients in the same order as forward inputs (excluding ctx).
return {dMeans,
dQuats,
dLogScales,
dFeatures,
dOpacities,
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable(),
Variable()};
}

} // namespace fvdb::detail::autograd
10 changes: 3 additions & 7 deletions src/fvdb/detail/autograd/GaussianRasterizeFromWorld.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#define FVDB_DETAIL_AUTOGRAD_GAUSSIANRASTERIZEFROMWORLD_H

#include <fvdb/detail/ops/gsplat/GaussianCameras.cuh>
#include <fvdb/detail/ops/gsplat/GaussianTileIntersection.h>

#include <torch/autograd.h>

Expand Down Expand Up @@ -32,13 +33,8 @@ struct RasterizeGaussiansToPixelsFromWorld3DGS
const Variable &distortionCoeffs, // [C,K]
const fvdb::detail::ops::RollingShutterType rollingShutterType,
const fvdb::detail::ops::DistortionModel distortionModel,
const uint32_t imageWidth,
const uint32_t imageHeight,
const uint32_t imageOriginW,
const uint32_t imageOriginH,
const uint32_t tileSize,
const Variable &tileOffsets, // [C, tileH, tileW]
const Variable &tileGaussianIds, // [n_isects]
const ops::RenderWindow2D &renderWindow,
const ops::DenseTileIntersections &tileIntersections,
std::optional<Variable> backgrounds = std::nullopt, // [C,D]
std::optional<Variable> masks = std::nullopt); // [C,tileH,tileW] bool

Expand Down
63 changes: 23 additions & 40 deletions src/fvdb/detail/autograd/GaussianRasterizeSparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
#include <fvdb/detail/utils/Nvtx.h>
#include <fvdb/detail/utils/Utils.h>

#include <nanovdb/math/Math.h>

namespace fvdb::detail::autograd {

RasterizeGaussiansToPixelsSparse::VariableList
Expand All @@ -19,26 +17,22 @@ RasterizeGaussiansToPixelsSparse::forward(
const RasterizeGaussiansToPixelsSparse::Variable &conics, // [C, N, 3]
const RasterizeGaussiansToPixelsSparse::Variable &colors, // [C, N, 3]
const RasterizeGaussiansToPixelsSparse::Variable &opacities, // [N]
const uint32_t imageWidth,
const uint32_t imageHeight,
const uint32_t imageOriginW,
const uint32_t imageOriginH,
const uint32_t tileSize,
const RasterizeGaussiansToPixelsSparse::Variable
&tileOffsets, // [C, tile_height, tile_width] (dense) or [num_active_tiles + 1] (sparse)
const RasterizeGaussiansToPixelsSparse::Variable &tileGaussianIds, // [n_isects]
const RasterizeGaussiansToPixelsSparse::Variable &activeTiles, // [num_active_tiles]
const RasterizeGaussiansToPixelsSparse::Variable
&tilePixelMask, // [num_active_tiles, tileSize, tileSize]
const RasterizeGaussiansToPixelsSparse::Variable &tilePixelCumsum, // [num_active_tiles + 1]
const RasterizeGaussiansToPixelsSparse::Variable &pixelMap, // [num_pixels]
const ops::RenderWindow2D &renderWindow,
const ops::SparseTileIntersections &tileIntersections,
const bool absgrad,
std::optional<RasterizeGaussiansToPixelsSparse::Variable> backgrounds,
std::optional<RasterizeGaussiansToPixelsSparse::Variable> masks) {
FVDB_FUNC_RANGE_WITH_NAME("RasterizeGaussiansToPixelsSparse::forward");

const auto &tileOffsets = tileIntersections.tileOffsets();
const auto &tileGaussianIds = tileIntersections.tileGaussianIds();
const auto &activeTiles = tileIntersections.activeTiles();
const auto &tilePixelMask = tileIntersections.tilePixelMask();
const auto &tilePixelCumsum = tileIntersections.tilePixelCumsum();
const auto &pixelMap = tileIntersections.pixelMap();
const uint32_t tileSize = tileIntersections.tileSize();

auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() {
const ops::RenderWindow2D renderWindow{imageWidth, imageHeight, imageOriginW, imageOriginH};
return ops::dispatchGaussianSparseRasterizeForward<DeviceTag>(pixelsToRender,
means2d,
conics,
Expand Down Expand Up @@ -95,11 +89,11 @@ RasterizeGaussiansToPixelsSparse::forward(
}
ctx->save_for_backward(toSave);

ctx->saved_data["imageWidth"] = (int64_t)imageWidth;
ctx->saved_data["imageHeight"] = (int64_t)imageHeight;
ctx->saved_data["imageWidth"] = (int64_t)renderWindow.width;
ctx->saved_data["imageHeight"] = (int64_t)renderWindow.height;
ctx->saved_data["tileSize"] = (int64_t)tileSize;
ctx->saved_data["imageOriginW"] = (int64_t)imageOriginW;
ctx->saved_data["imageOriginH"] = (int64_t)imageOriginH;
ctx->saved_data["imageOriginW"] = (int64_t)renderWindow.originW;
ctx->saved_data["imageOriginH"] = (int64_t)renderWindow.originH;
ctx->saved_data["numOuterLists"] = (int64_t)numOuterLists;
ctx->saved_data["absgrad"] = absgrad;

Expand All @@ -114,7 +108,6 @@ RasterizeGaussiansToPixelsSparse::backward(
Variable dLossDRenderedFeaturesJData = gradOutput.at(0);
Variable dLossDRenderedAlphasJData = gradOutput.at(1);

// ensure the gradients are contiguous if they are not None
if (dLossDRenderedFeaturesJData.defined()) {
dLossDRenderedFeaturesJData = dLossDRenderedFeaturesJData.contiguous();
}
Expand Down Expand Up @@ -153,11 +146,12 @@ RasterizeGaussiansToPixelsSparse::backward(
masks = saved.at(optIdx++);
}

const int imageWidth = (int)ctx->saved_data["imageWidth"].toInt();
const int imageHeight = (int)ctx->saved_data["imageHeight"].toInt();
const ops::RenderWindow2D renderWindow{
static_cast<uint32_t>(ctx->saved_data["imageWidth"].toInt()),
static_cast<uint32_t>(ctx->saved_data["imageHeight"].toInt()),
static_cast<uint32_t>(ctx->saved_data["imageOriginW"].toInt()),
static_cast<uint32_t>(ctx->saved_data["imageOriginH"].toInt())};
const int tileSize = (int)ctx->saved_data["tileSize"].toInt();
const int imageOriginW = (int)ctx->saved_data["imageOriginW"].toInt();
const int imageOriginH = (int)ctx->saved_data["imageOriginH"].toInt();
const int64_t numOuterLists = ctx->saved_data["numOuterLists"].toInt();
const bool absgrad = ctx->saved_data["absgrad"].toBool();

Expand All @@ -170,10 +164,6 @@ RasterizeGaussiansToPixelsSparse::backward(
auto dLossDRenderedAlphas = pixelsToRender.jagged_like(dLossDRenderedAlphasJData);

auto variables = FVDB_DISPATCH_KERNEL(means2d.device(), [&]() {
const ops::RenderWindow2D renderWindow{static_cast<uint32_t>(imageWidth),
static_cast<uint32_t>(imageHeight),
static_cast<uint32_t>(imageOriginW),
static_cast<uint32_t>(imageOriginH)};
return ops::dispatchGaussianSparseRasterizeBackward<DeviceTag>(pixelsToRender,
means2d,
conics,
Expand Down Expand Up @@ -207,23 +197,16 @@ RasterizeGaussiansToPixelsSparse::backward(
Variable dLossDColors = std::get<3>(variables);
Variable dLossDOpacities = std::get<4>(variables);

// 10 forward params (excluding ctx): pixelsToRender, means2d, conics, features,
// opacities, renderWindow, tileIntersections, absgrad, backgrounds, masks
return {
Variable(), // pixelsToRender
dLossDMeans2d, // means2d
dLossDConics, // conics
dLossDColors, // features
dLossDOpacities, // opacities
Variable(), // imageWidth
Variable(), // imageHeight
Variable(), // imageOriginW
Variable(), // imageOriginH
Variable(), // tileSize
Variable(), // tileOffsets
Variable(), // tileGaussianIds
Variable(), // activeTiles
Variable(), // tilePixelMask
Variable(), // tilePixelCumsum
Variable(), // pixelMap
Variable(), // renderWindow
Variable(), // tileIntersections
Variable(), // absgrad
Variable(), // backgrounds
Variable(), // masks
Expand Down
Loading
Loading