Support Multi-GPU training based on the paper "On Scaling Up 3D Gaussian Splatting Training" #253
Conversation
WIP on cleaning up

It is strange that ... For example, the Egypt dataset from ...

As I have noted, I'm NOT comparing the two, as they are clearly in different settings (the numbers of converged GSs are different).
# Distribute the GSs to different ranks (also works for single rank)
points = points[world_rank::world_size]
rgbs = rgbs[world_rank::world_size]
scales = scales[world_rank::world_size]
@TarzanZhao One question here. Any idea on the best way to split the GSs for multi-GPU initialization? I'm currently essentially splitting them randomly, which might not be ideal for minimizing data transfer?
In our Grendel code, we just assign each GPU a contiguous chunk of the point cloud from colmap. You can try our implementation here: https://github.com/nyu-systems/Grendel-GS/blob/0ea84e456d58946aa9708e1932d7c3466edd6a98/scene/gaussian_model.py#L181
Neither our current solution nor yours may be optimal. However, because Grendel implements load-balancing techniques for the Gaussian distribution during training, an uneven split at the beginning does not significantly impact the overall training speed.
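To make the two options concrete, here is a minimal sketch (not the actual gsplat or Grendel code) contrasting the strided split used in this PR with a contiguous-chunk split; the dummy points tensor and the chunking arithmetic are illustrative assumptions:

import torch

# Illustrative stand-ins: an [N, 3] point cloud and a 4-rank setup.
points = torch.rand(10_000, 3)
world_rank, world_size = 1, 4

# Strided split (this PR): rank r takes points r, r + world_size, r + 2 * world_size, ...
points_strided = points[world_rank::world_size]

# Contiguous-chunk split (Grendel-GS style): rank r takes one contiguous slice.
chunk = (points.shape[0] + world_size - 1) // world_size  # ceil division
points_chunk = points[world_rank * chunk : (world_rank + 1) * chunk]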
gsplat/rendering.py
Outdated
if distributed:
    world_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    N_world = [None] * world_size
    C_world = [None] * world_size
    torch.distributed.all_gather_object(N_world, N)
    torch.distributed.all_gather_object(C_world, C)

    # [TODO] `all_gather` is not differentiable w.r.t. viewmats and Ks
    out_tensor_list = [torch.empty((C_i, 4, 4), device=device) for C_i in C_world]
    torch.distributed.all_gather(out_tensor_list, viewmats.contiguous())
    viewmats = torch.cat(out_tensor_list, dim=0)
    out_tensor_list = [torch.empty((C_i, 3, 3), device=device) for C_i in C_world]
    torch.distributed.all_gather(out_tensor_list, Ks.contiguous())
    Ks = torch.cat(out_tensor_list, dim=0)
    C = len(viewmats)  # Silently change C from local #Cameras to global #Cameras.
@TarzanZhao The current design of this API is that it takes in the local cameras that need to be rendered by this rank. So before the projection stage, an all_gather needs to be done for each rank to get access to all cameras.
I guess another option for the API design is that this function takes in the "global" cameras (all cameras that need to be rendered by all ranks), which would avoid this all_gather operation. But that would mean users have to load the same global set of cameras on each rank and only supervise on a subset of them, which feels a bit counterintuitive from a user-experience perspective.
I personally prefer the first design choice, but that means two all_gather and sync operations here. Do you think the effect of this would be large, in your experience?
Using all_gather_object() is often slow in my experience because it transfers data from the CPU to the GPU, then to another GPU, and finally back to the CPU. It's advisable to avoid this function as much as possible. If you prefer the first design, then please at least try to avoid all_gather_object().
In Grendel's implementation, every GPU keeps the same dataset and uses the same random seed, so all GPUs generate the same batch every time. This is the other option you mentioned. Then there is no need to all-gather cameras.
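If the first design is kept, one way to drop all_gather_object() is to all-gather a small GPU tensor of per-rank sizes and copy it to the CPU once. A minimal sketch, assuming torch.distributed (NCCL) is already initialized and N, C, device are the local values from the snippet above:

import torch
import torch.distributed as dist

def gather_sizes(N: int, C: int, device: torch.device):
    # All-gather the local (N, C) as a small int64 tensor; this stays on the GPU.
    world_size = dist.get_world_size()
    local = torch.tensor([N, C], dtype=torch.int64, device=device)
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    # Single device-to-host copy instead of per-object CPU round trips.
    sizes = torch.stack(gathered).cpu()
    N_world = sizes[:, 0].tolist()
    C_world = sizes[:, 1].tolist()
    return N_world, C_world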
# on which cameras they are visible to, which we already figured out in the projection
# stage.
if distributed:
    if packed:
@TarzanZhao This is the sparse all-to-all logic, which we call packed mode (only visible GSs are returned from the projection function).
I'm not sure if I'm implementing this part in the most efficient way. Would love to get your thoughts.
- depth and radii do not have to use the functional version of all-to-all.
- collected_cnts does not need the functional version of all-to-all either.
- [cnt.item() for cnt in cnts] will invoke many separate GPU-to-CPU transfers, so the total latency will be large. You can instead transfer cnts to the CPU with a single call (see the sketch below).
In general, this code launches more kernels than ours: https://github.com/nyu-systems/Grendel-GS/blob/0ea84e456d58946aa9708e1932d7c3466edd6a98/gaussian_renderer/__init__.py#L176
More kernels add significant kernel-launch overhead, especially when images have small resolutions.
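For the third point, a minimal sketch of replacing the per-element .item() calls with one device-to-host transfer; the example cnts values are made up:

import torch

# Per-rank visible-GS counts living on the GPU (dummy values).
cnts = torch.tensor([120, 85, 200, 64], device="cuda")

# Before: one GPU-to-CPU sync per element.
# split_sizes = [cnt.item() for cnt in cnts]

# After: a single GPU-to-CPU transfer for the whole tensor.
split_sizes = cnts.cpu().tolist()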
@TarzanZhao My biggest problem with this PR in its current stage is that it's far from getting a 3.5x speedup on 4 GPUs. I'm not sure how much NVLink plays a role in this, which I don't have in my server setup. But I would love to know if anything can be improved on the implementation side, which I think is mostly just the update in the ...
I think I can live with a ~3x speedup on 4 GPUs, since I'm not implementing the tile-based balancing and GS rebalancing logic. Is that a reasonable expectation?
Using packed mode:
CUDA_VISIBLE_DEVICES=4,5,6,7 python simple_trainer.py --steps_scaler 0.25 --eval_steps -1 --disable_viewer --packed
>>> Step: 7499 {'mem': 2.944546699523926, 'ellipse_time': 1661.8717801570892, 'num_GS': 1356072}
>>> Step: 7499 {'mem': 3.056483268737793, 'ellipse_time': 1661.9940576553345, 'num_GS': 1432245}
>>> Step: 7499 {'mem': 3.076539993286133, 'ellipse_time': 1661.1862392425537, 'num_GS': 1430476}
>>> Step: 7499 {'mem': 3.133754253387451, 'ellipse_time': 1661.065690279007, 'num_GS': 1470615}
>>> PSNR: 27.431, SSIM: 0.8689, LPIPS: 0.075 Time: 0.087s/image Number of GS: 1356072
CUDA_VISIBLE_DEVICES=4,5 python simple_trainer.py --steps_scaler 0.25 --eval_steps -1 --disable_viewer --packed --batch_size 2
>>> Step: 7499 {'mem': 6.096761226654053, 'ellipse_time': 1792.5143909454346, 'num_GS': 2868191}
>>> Step: 7499 {'mem': 6.014559268951416, 'ellipse_time': 1791.8004696369171, 'num_GS': 2786602}
>>> PSNR: 27.403, SSIM: 0.8682, LPIPS: 0.075 Time: 0.075s/image Number of GS: 2786602
As a reference, running the official repo in the same GPU environment (without NVLink) gets:
Note that the LPIPS is not the same LPIPS implementation as in this repo.
TarzanZhao left a comment:
Cool!! I'm very excited that our paper's distributed strategy can be adopted by gsplat so quickly. Thanks so much! I've left some comments and reference code and hope they can help a little bit.
gsplat/rendering.py
Outdated
) # [C_i, N, :]

# collected contains:
radii = collected[..., 0].int()
You can use torch.split() to avoid many kernel launches here.
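A minimal sketch of the torch.split() suggestion; the field widths below are hypothetical and would have to match the actual packing order of collected:

import torch

# Hypothetical packed layout per GS: radii(1), means2d(2), depths(1), conics(3).
collected = torch.rand(4, 100, 7)  # [C_i, N, 7], dummy data

# One split call replaces several per-attribute slicing kernels.
radii, means2d, depths, conics = torch.split(collected, [1, 2, 1, 3], dim=-1)
radii = radii.squeeze(-1).int()
depths = depths.squeeze(-1)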
Overall, I think there are two major reasons:
After some optimization, the training time on 4 GPUs is brought down to 16m25s!
CUDA_VISIBLE_DEVICES=4,5,6,7 python simple_trainer.py --steps_scaler 0.25 --eval_steps -1 --disable_viewer --packed
>>> Step: 7499 {'mem': 2.8663759231567383, 'ellipse_time': 985.6521210670471, 'num_GS': 1352151}
>>> Step: 7499 {'mem': 2.971991539001465, 'ellipse_time': 984.3997764587402, 'num_GS': 1406792}
>>> Step: 7499 {'mem': 3.0060572624206543, 'ellipse_time': 985.4934339523315, 'num_GS': 1454676}
>>> Step: 7499 {'mem': 3.0007195472717285, 'ellipse_time': 985.5467941761017, 'num_GS': 1428084}
CUDA_VISIBLE_DEVICES=4,5 python simple_trainer.py --steps_scaler 0.25 --eval_steps -1 --disable_viewer --packed --batch_size 2
>>> Step: 7499 {'mem': 5.9516401290893555, 'ellipse_time': 1507.71635222435, 'num_GS': 2866948}
>>> Step: 7499 {'mem': 5.851601600646973, 'ellipse_time': 1506.6484203338623, 'num_GS': 2769727}
MCMC on 4 GPUs:
CUDA_VISIBLE_DEVICES=4,5,6,7 python simple_trainer_mcmc.py --steps_scaler 0.25 --eval_steps -1 --packed
>>> Step: 7499 {'mem': 2.4360504150390625, 'ellipse_time': 991.8759768009186, 'num_GS': 1000000}
>>> Step: 7499 {'mem': 2.403714179992676, 'ellipse_time': 992.4375550746918, 'num_GS': 1000000}
>>> Step: 7499 {'mem': 2.42301607131958, 'ellipse_time': 992.5172688961029, 'num_GS': 1000000}
>>> Step: 7499 {'mem': 2.411731243133545, 'ellipse_time': 991.4808518886566, 'num_GS': 1000000}
PSNR: 27.714, SSIM: 0.8744, LPIPS: 0.079 Time: 0.074s/image Number of GS: 1000000
This PR is great! I wonder whether it can be used in a way that lets 1 GPU process multiple batches?
@Ben-Mack For a single GPU, I guess there isn't much you can do other than looping over images?
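A rough, hypothetical sketch of what looping over images on one GPU could look like, accumulating gradients across the batch; rasterize_image and image_loss are placeholder names, not actual gsplat API calls:

def train_step(optimizer, viewmats, Ks, gt_images, width, height, **splats):
    # Process a multi-image batch on a single GPU by looping and accumulating gradients.
    optimizer.zero_grad()
    batch_size = len(gt_images)
    for viewmat, K, gt in zip(viewmats, Ks, gt_images):
        # rasterize_image / image_loss stand in for the real render and loss calls.
        render = rasterize_image(viewmat[None], K[None], width, height, **splats)
        loss = image_loss(render, gt) / batch_size  # average the loss over the batch
        loss.backward()  # gradients accumulate across iterations
    optimizer.step()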
Support Multi-GPU training based on the paper "On Scaling Up 3D Gaussian Splatting Training" (nerfstudio-project#253)
* checkin the code
* nicer API
* mcmc script now works with multigpu
* trainer supports multi gpu
* get rid of redundant code
* func doc
* support packed mode
* format
* more exp
* multi GPU viewer
* optim
* cleanup
* cleanup
* merge main
* MCMC
* doc
* scripts
* scripts and performance
Co-authored-by: Ruilong Li <397653553@qq.com>
Paper link: https://daohanlu.github.io/scaling-up-3dgs/
Latest results:
Scripts:
bash benchmarks/basic_4gpus.sh