@@ -351,69 +351,6 @@ def render_gaussians(self, means2D, conic_opacity, rgb, depths, radii, compute_l
351
351
cuda_args
352
352
)
353
353
354
- def get_local2j_ids (self , means2D , radii , cuda_args ):
355
- # For each 3dgs, calculate the set of GPUs that will use this 3dgs for rendering.
356
-
357
- if isinstance (cuda_args ["dist_global_strategy" ], str ):
358
- raster_settings = self .raster_settings
359
- mp_world_size = int (cuda_args ["mp_world_size" ])
360
- mp_rank = int (cuda_args ["mp_rank" ])
361
-
362
- # TODO: make it more general.
363
- dist_global_strategy = [int (x ) for x in cuda_args ["dist_global_strategy" ].split ("," )]
364
- assert len (dist_global_strategy ) == mp_world_size + 1 , "dist_global_strategy should have length WORLD_SIZE+1"
365
- assert dist_global_strategy [0 ] == 0 , "dist_global_strategy[0] should be 0"
366
- dist_global_strategy = torch .tensor (dist_global_strategy , dtype = torch .int , device = means2D .device )
367
-
368
- args = (
369
- raster_settings .image_height ,
370
- raster_settings .image_width ,
371
- mp_rank ,
372
- mp_world_size ,
373
- means2D ,
374
- radii ,
375
- dist_global_strategy ,
376
- cuda_args
377
- )
378
-
379
- local2j_ids_bool = _C .get_local2j_ids_bool (* args ) # local2j_ids_bool is (P, world_size) bool tensor
380
-
381
- else :
382
- raster_settings = self .raster_settings
383
- mp_world_size = int (cuda_args ["mp_world_size" ])
384
- mp_rank = int (cuda_args ["mp_rank" ])
385
-
386
- division_pos = cuda_args ["dist_global_strategy" ]
387
- division_pos_xs , division_pos_ys = division_pos
388
-
389
- rectangles = []
390
- for y_rank in range (len (division_pos_ys [0 ])- 1 ):
391
- for x_rank in range (len (division_pos_ys )):
392
- local_tile_x_l , local_tile_x_r = division_pos_xs [x_rank ], division_pos_xs [x_rank + 1 ]
393
- local_tile_y_l , local_tile_y_r = division_pos_ys [x_rank ][y_rank ], division_pos_ys [x_rank ][y_rank + 1 ]
394
- rectangles .append ([local_tile_y_l , local_tile_y_r , local_tile_x_l , local_tile_x_r ])
395
- rectangles = torch .tensor (rectangles , dtype = torch .int , device = means2D .device )# (mp_world_size, 4)
396
-
397
- args = (
398
- raster_settings .image_height ,
399
- raster_settings .image_width ,
400
- mp_rank ,
401
- mp_world_size ,
402
- means2D ,
403
- radii ,
404
- rectangles ,
405
- cuda_args
406
- )
407
-
408
- local2j_ids_bool = _C .get_local2j_ids_bool_adjust_mode6 (* args ) # local2j_ids_bool is (P, world_size) bool tensor
409
-
410
- local2j_ids = []
411
- for rk in range (mp_world_size ):
412
- local2j_ids .append (local2j_ids_bool [:, rk ].nonzero ())
413
-
414
- return local2j_ids , local2j_ids_bool
415
-
416
-
417
354
class _LoadImageTilesByPos (torch .autograd .Function ):
418
355
419
356
@staticmethod
0 commit comments