Skip to content

Commit b0dfe34

Browse files
author
Hexu Zhao
committed
move get_local2j_ids to the python part code.
1 parent fcac049 commit b0dfe34

File tree

1 file changed

+0
-63
lines changed

1 file changed

+0
-63
lines changed

diff_gaussian_rasterization/__init__.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -351,69 +351,6 @@ def render_gaussians(self, means2D, conic_opacity, rgb, depths, radii, compute_l
351351
cuda_args
352352
)
353353

354-
def get_local2j_ids(self, means2D, radii, cuda_args):
355-
# For each 3dgs, calculate the set of GPUs that will use this 3dgs for rendering.
356-
357-
if isinstance(cuda_args["dist_global_strategy"], str):
358-
raster_settings = self.raster_settings
359-
mp_world_size = int(cuda_args["mp_world_size"])
360-
mp_rank = int(cuda_args["mp_rank"])
361-
362-
# TODO: make it more general.
363-
dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
364-
assert len(dist_global_strategy) == mp_world_size+1, "dist_global_strategy should have length WORLD_SIZE+1"
365-
assert dist_global_strategy[0] == 0, "dist_global_strategy[0] should be 0"
366-
dist_global_strategy = torch.tensor(dist_global_strategy, dtype=torch.int, device=means2D.device)
367-
368-
args = (
369-
raster_settings.image_height,
370-
raster_settings.image_width,
371-
mp_rank,
372-
mp_world_size,
373-
means2D,
374-
radii,
375-
dist_global_strategy,
376-
cuda_args
377-
)
378-
379-
local2j_ids_bool = _C.get_local2j_ids_bool(*args) # local2j_ids_bool is (P, world_size) bool tensor
380-
381-
else:
382-
raster_settings = self.raster_settings
383-
mp_world_size = int(cuda_args["mp_world_size"])
384-
mp_rank = int(cuda_args["mp_rank"])
385-
386-
division_pos = cuda_args["dist_global_strategy"]
387-
division_pos_xs, division_pos_ys = division_pos
388-
389-
rectangles = []
390-
for y_rank in range(len(division_pos_ys[0])-1):
391-
for x_rank in range(len(division_pos_ys)):
392-
local_tile_x_l, local_tile_x_r = division_pos_xs[x_rank], division_pos_xs[x_rank+1]
393-
local_tile_y_l, local_tile_y_r = division_pos_ys[x_rank][y_rank], division_pos_ys[x_rank][y_rank+1]
394-
rectangles.append([local_tile_y_l, local_tile_y_r, local_tile_x_l, local_tile_x_r])
395-
rectangles = torch.tensor(rectangles, dtype=torch.int, device=means2D.device)# (mp_world_size, 4)
396-
397-
args = (
398-
raster_settings.image_height,
399-
raster_settings.image_width,
400-
mp_rank,
401-
mp_world_size,
402-
means2D,
403-
radii,
404-
rectangles,
405-
cuda_args
406-
)
407-
408-
local2j_ids_bool = _C.get_local2j_ids_bool_adjust_mode6(*args) # local2j_ids_bool is (P, world_size) bool tensor
409-
410-
local2j_ids = []
411-
for rk in range(mp_world_size):
412-
local2j_ids.append(local2j_ids_bool[:, rk].nonzero())
413-
414-
return local2j_ids, local2j_ids_bool
415-
416-
417354
class _LoadImageTilesByPos(torch.autograd.Function):
418355

419356
@staticmethod

0 commit comments

Comments
 (0)