Batch kernels for backward pass of Preprocessing #3

Open

wants to merge 36 commits into base: mlsys/forward_preprocess_batch
Changes from 1 commit

Commits (36)
c408be1
preprocess batches for backward outline
sandeepnmenon Apr 27, 2024
6444aa7
solve syntax errors in backward
prapti19 Apr 27, 2024
8df1b63
Refactor GaussianRasterizationSettings class to handle raster_setting…
sandeepnmenon Apr 28, 2024
f59e67c
Merge branch 'mlsys/batched_preprocess' of github.com:TarzanZhao/diff…
sandeepnmenon Apr 28, 2024
529710b
added focal_x and focal_y calculation inside the kernel
sandeepnmenon Apr 28, 2024
3ac1ad3
Refactor rasterization_tests.py to use raster_settings_batch instead …
sandeepnmenon Apr 28, 2024
ce314e2
fixed namedtuple setting bug
sandeepnmenon Apr 28, 2024
fdd3b4f
Refactor GaussianRasterizationSettings class to handle raster_setting…
sandeepnmenon Apr 28, 2024
cdf3bc1
remove focal_x and focal_y calculations
sandeepnmenon Apr 28, 2024
591a5c1
Refactor CUDA rasterizer code to include width and height parameters …
sandeepnmenon Apr 28, 2024
1b54023
Renamed W and H to image_width and image_height parameters in prepr…
sandeepnmenon Apr 28, 2024
9294a07
reverted focal_x and focal_y removal in normal preprocessBackward
sandeepnmenon Apr 28, 2024
877677f
grad_means2D to handle more than 2 dimensions
sandeepnmenon Apr 28, 2024
444e8a5
add tests for backward
prapti19 Apr 28, 2024
46b83eb
ruff formatting and gradients for remaining inputs
sandeepnmenon Apr 28, 2024
710b56f
Add pyproject.toml file with ruff line-length set to 120
sandeepnmenon Apr 28, 2024
2ca5ae6
Refactor ruff.toml file to set line-length to 120 and indent-width to 4
sandeepnmenon Apr 28, 2024
7a4b6b4
Refactor compare_tensors function to handle None values in rasterizat…
sandeepnmenon Apr 28, 2024
14889e6
Update ruff.toml file to set line-length to 120
sandeepnmenon Apr 28, 2024
5682d26
Refactor rasterization_backward_tests.py to include gradient checks f…
sandeepnmenon Apr 28, 2024
945e8cf
gradients calculated for all the variables to check and cloning them
sandeepnmenon Apr 28, 2024
c84d7cd
converted to pytest testing
sandeepnmenon Apr 28, 2024
6f38446
fixed colon bug and ruff formatting
sandeepnmenon Apr 28, 2024
7b30782
Add __pycache__/ to .gitignore
sandeepnmenon Apr 28, 2024
13a5559
renamed to *_test.py
sandeepnmenon Apr 28, 2024
5b24881
Update .gitignore to include __pycache__/
sandeepnmenon Apr 28, 2024
9e6f4a9
moved test into tests folder
sandeepnmenon Apr 28, 2024
7be38fa
Add instructions for running tests in README.md
sandeepnmenon Apr 28, 2024
1887e14
Merge branch 'mlsys/forward_preprocess_batch' of github.com:TarzanZha…
sandeepnmenon Apr 28, 2024
307e156
deleted old test file
sandeepnmenon Apr 28, 2024
21ee225
renamed idx to point_idx and view_idx to result_idx in backward
sandeepnmenon Apr 29, 2024
363b4ee
moved from python time to torch record
sandeepnmenon May 8, 2024
91d1582
fixed num_points in preprocessForwardBatches
sandeepnmenon May 8, 2024
ee767da
Refactor test function names for clarity and consistency
sandeepnmenon May 11, 2024
2e7f032
fixed bug in printing only first 5 non-matching indices
sandeepnmenon May 11, 2024
e8edb86
fixed backward bug of backward kernel not getting executed
sandeepnmenon May 11, 2024
fixed backward bug of backward kernel not getting executed
sandeepnmenon committed May 11, 2024
commit e8edb865befa8fded6a85d5ce4aef55b540bf84f

cuda_rasterizer/backward.cu: 11 changes (5 additions, 6 deletions)

@@ -41,9 +41,9 @@ __device__ void computeColorFromSH(int point_idx, int result_idx, int deg, int m
 float z = dir.z;

 // Target location for this Gaussian to write SH gradients to
-glm::vec3* dL_dsh = dL_dshs + point_idx * max_coeffs;
+glm::vec3 *dL_dsh = dL_dshs + result_idx * max_coeffs;

-// No tricks here, just high school-level calculus.
+// No tricks here, just high school-level calculus.
 float dRGBdsh0 = SH_C0;
 dL_dsh[0] = dRGBdsh0 * dL_dRGB;
 if (deg > 0)
@@ -55,7 +55,7 @@ __device__ void computeColorFromSH(int point_idx, int result_idx, int deg, int m
 dL_dsh[2] = dRGBdsh2 * dL_dRGB;
 dL_dsh[3] = dRGBdsh3 * dL_dRGB;

-dRGBdx = -SH_C1 * sh[3];
+dRGBdx = -SH_C1 * sh[3];
 dRGBdy = -SH_C1 * sh[1];
 dRGBdz = SH_C1 * sh[2];

@@ -75,7 +75,7 @@ __device__ void computeColorFromSH(int point_idx, int result_idx, int deg, int m
 dL_dsh[7] = dRGBdsh7 * dL_dRGB;
 dL_dsh[8] = dRGBdsh8 * dL_dRGB;

-dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
+dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
 dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
 dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7];

@@ -96,7 +96,7 @@ __device__ void computeColorFromSH(int point_idx, int result_idx, int deg, int m
 dL_dsh[14] = dRGBdsh14 * dL_dRGB;
 dL_dsh[15] = dRGBdsh15 * dL_dRGB;

-dRGBdx += (
+dRGBdx += (
 SH_C3[0] * sh[9] * 3.f * 2.f * xy +
 SH_C3[1] * sh[10] * yz +
 SH_C3[2] * sh[11] * -2.f * xy +
@@ -563,7 +563,6 @@ __global__ void preprocessCUDABatched(
 auto point_idx = blockIdx.x * blockDim.x + threadIdx.x;
 auto viewpoint_idx = blockIdx.y;
 if (viewpoint_idx >= num_viewpoints || point_idx >= P) return;
-return;

 auto idx = viewpoint_idx * P + point_idx;
 if (!(radii[idx] > 0))
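The last hunk is the fix named in the commit title: a stray unconditional `return;` sat right after the bounds check in `preprocessCUDABatched`, so every thread exited before doing any work and the batched backward kernel never wrote gradients. Below is a minimal CPU-side sketch of the indexing that guard protects, assuming the flattened row-major (num_viewpoints, P) layout implied by `idx = viewpoint_idx * P + point_idx`; the helper name and example data are illustrative and not part of this repository.

```python
# Minimal CPU-side reference for the indexing used by the batched backward kernel.
# Assumes the flattened row-major (num_viewpoints, P) layout; the helper name and the
# example data are illustrative and not part of this repository.
import torch


def reference_batched_guard(radii: torch.Tensor, num_viewpoints: int, P: int) -> list:
    """Return the flat indices (viewpoint_idx * P + point_idx) a correct kernel processes."""
    processed = []
    for viewpoint_idx in range(num_viewpoints):      # plays the role of blockIdx.y
        for point_idx in range(P):                   # plays the role of blockIdx.x * blockDim.x + threadIdx.x
            idx = viewpoint_idx * P + point_idx      # flat index into per-(view, point) buffers
            if radii[idx] <= 0:                      # culled Gaussian: skip it, as the kernel does
                continue
            processed.append(idx)
    return processed


# With the stray `return;` this commit removes, every thread exited right after the
# bounds check, so the equivalent of `processed` stayed empty and no gradients were written.
radii = torch.tensor([2, 0, 5, 1, 0, 3])             # 2 viewpoints x 3 Gaussians, flattened
print(reference_batched_guard(radii, num_viewpoints=2, P=3))  # -> [0, 2, 3, 5]
```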

diff_gaussian_rasterization/__init__.py: 3 changes (1 addition, 2 deletions)

@@ -129,8 +129,7 @@ def backward(ctx, grad_means2D, grad_rgb, grad_conic_opacity, grad_radii, grad_d
 # grad_means2D is (P, 2) now. Need to pad it to (P, 3) because preprocess_gaussians_backward's cuda implementation.

 grad_means2D_pad = torch.zeros_like(grad_means2D[..., :1], dtype = grad_means2D.dtype, device=grad_means2D.device)
-grad_means2D = torch.cat((grad_means2D, grad_means2D_pad), dim = 1).contiguous()
-
+grad_means2D = torch.cat((grad_means2D, grad_means2D_pad), dim = -1).contiguous()
 # Restructure args as C++ method expects them
 args = (radii,
 cov3D,
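Concatenating along `dim=-1` instead of `dim=1` keeps the zero-padding of `grad_means2D` correct when the gradient carries extra leading dimensions (see the earlier commit "grad_means2D to handle more than 2 dimensions"). A small sketch of the operation in isolation, with hypothetical shapes; this is not the package's API, just the torch call the hunk changes.

```python
# Sketch of the grad_means2D padding with hypothetical shapes; not the package's API.
import torch


def pad_last_dim_to_3(grad_means2D: torch.Tensor) -> torch.Tensor:
    """Zero-pad the last dimension from 2 to 3, for any number of leading dimensions."""
    pad = torch.zeros_like(grad_means2D[..., :1])    # shape (..., 1), same dtype and device
    return torch.cat((grad_means2D, pad), dim=-1).contiguous()


print(pad_last_dim_to_3(torch.randn(5, 2)).shape)     # torch.Size([5, 3])
print(pad_last_dim_to_3(torch.randn(4, 5, 2)).shape)  # torch.Size([4, 5, 3]); dim=1 would raise here
```

Note that `torch.zeros_like` already preserves dtype and device, so the explicit keyword arguments on the padding line are redundant but harmless.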

tests/rasterization_preprocess_test.py: 8 changes (4 additions, 4 deletions)

@@ -67,8 +67,8 @@ def setup_data():
 )


-def compute_dummy_loss(means3D, scales, rotations, shs, opacity):
-losses = [(tensor - torch.ones_like(tensor)).pow(2).mean() for tensor in [means3D, scales, rotations, shs, opacity]]
+def compute_dummy_loss(batched_means2D, batched_rgb, batched_conic_opacity):
+losses = [(tensor - torch.ones_like(tensor)).pow(2).mean() for tensor in [batched_means2D, batched_conic_opacity, batched_rgb]]
 loss = sum(losses)
 return loss

@@ -186,7 +186,7 @@ def run_batched_gaussian_rasterizer(setup_data):
 torch.cuda.synchronize()
 start_backward_event.record()

-loss = compute_dummy_loss(means3D, scales, rotations, shs, opacity)
+loss = compute_dummy_loss(batched_means2D, batched_rgb, batched_conic_opacity)
 loss.backward()

 end_backward_event.record()
@@ -312,7 +312,7 @@ def run_batched_gaussian_rasterizer_batch_processing(setup_data):
 torch.cuda.synchronize()
 start_backward_event.record()

-loss = compute_dummy_loss(means3D, scales, rotations, shs, opacity)
+loss = compute_dummy_loss(batched_means2D, batched_rgb, batched_conic_opacity)
 loss.backward()

 end_backward_event.record()
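With these changes the dummy loss is built from the batched preprocessing outputs rather than the raw inputs, so `loss.backward()` actually exercises the batched backward kernels, and the backward pass is timed with CUDA events (per the earlier "moved from python time to torch record" commit). Below is a standalone sketch of that timing pattern with stand-in tensors in place of the rasterizer outputs; the shapes and names are illustrative only.

```python
# Standalone sketch of the CUDA-event timing pattern used in the tests; the random
# tensors stand in for the rasterizer's batched outputs and the shapes are made up.
import torch


def compute_dummy_loss(*tensors):
    # Same idea as the test helper: pull a gradient through every output tensor.
    return sum((t - torch.ones_like(t)).pow(2).mean() for t in tensors)


if torch.cuda.is_available():
    batched_means2D = torch.randn(4, 1000, 2, device="cuda", requires_grad=True)
    batched_rgb = torch.randn(4, 1000, 3, device="cuda", requires_grad=True)
    batched_conic_opacity = torch.randn(4, 1000, 4, device="cuda", requires_grad=True)

    start_backward_event = torch.cuda.Event(enable_timing=True)
    end_backward_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()            # finish any pending work before timing starts
    start_backward_event.record()

    loss = compute_dummy_loss(batched_means2D, batched_rgb, batched_conic_opacity)
    loss.backward()

    end_backward_event.record()
    torch.cuda.synchronize()            # wait for the events before reading the timing
    print(f"backward: {start_backward_event.elapsed_time(end_backward_event):.3f} ms")
```

The `torch.cuda.synchronize()` before reading `elapsed_time` matters because event recording is asynchronous with respect to the host.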