
[Core] avoid too many cuda context by caching p2p test #4021

Merged: 15 commits, Apr 14, 2024

Conversation

youkaichao (Member) commented Apr 11, 2024

It is observed in #3821 that every worker allocates memory on every GPU in `_can_p2p`, because the test makes every process create a CUDA context on every GPU, leading to $n * (n-1)$ CUDA contexts in total.

To avoid this, we can cache the result of the p2p test.
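
For illustration, here is a minimal sketch of the caching idea (the function, cache path, and use of torch.cuda.can_device_access_peer below are hypothetical, not the exact vLLM implementation): run the context-allocating check once, persist the result to disk, and let later processes read the cached answer.

```python
import json
import os

import torch

# Hypothetical cache location; the actual implementation keys the file by
# CUDA_VISIBLE_DEVICES so that separate instances do not clash (see discussion below).
_CACHE_FILE = os.path.expanduser("~/.cache/vllm/gpu_p2p_access_cache.json")


def can_p2p_cached(src: int, dst: int) -> bool:
    """Return whether GPU `src` can access GPU `dst` peer-to-peer,
    consulting an on-disk cache so the expensive check runs only once."""
    key = f"{src}->{dst}"
    cache = {}
    if os.path.exists(_CACHE_FILE):
        with open(_CACHE_FILE) as f:
            cache = json.load(f)
    if key not in cache:
        # This query is the part that allocates CUDA contexts,
        # so we only want to pay that cost on the first run.
        cache[key] = torch.cuda.can_device_access_peer(src, dst)
        os.makedirs(os.path.dirname(_CACHE_FILE), exist_ok=True)
        with open(_CACHE_FILE, "w") as f:
            json.dump(cache, f)
    return cache[key]
```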

Before this PR (tp=4):

nvidia-smi
Thu Apr 11 15:43:50 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla V100-SXM2-32GB-LS        On  |   00000000:06:00.0 Off |                    0 |
| N/A   35C    P0             96W /  250W |   29767MiB /  32768MiB |     61%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-32GB-LS        On  |   00000000:07:00.0 Off |                    0 |
| N/A   38C    P0             96W /  250W |   29609MiB /  32768MiB |     62%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  Tesla V100-SXM2-32GB-LS        On  |   00000000:0A:00.0 Off |                    0 |
| N/A   37C    P0             78W /  250W |   29717MiB /  32768MiB |     56%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  Tesla V100-SXM2-32GB-LS        On  |   00000000:0B:00.0 Off |                    0 |
| N/A   35C    P0             75W /  250W |   29657MiB /  32768MiB |     33%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  Tesla V100-SXM2-32GB-LS        On  |   00000000:85:00.0 Off |                    0 |
| N/A   28C    P0             40W /  250W |       0MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  Tesla V100-SXM2-32GB-LS        On  |   00000000:86:00.0 Off |                    0 |
| N/A   29C    P0             41W /  250W |       0MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  Tesla V100-SXM2-32GB-LS        On  |   00000000:89:00.0 Off |                    0 |
| N/A   32C    P0             41W /  250W |       0MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   7  Tesla V100-SXM2-32GB-LS        On  |   00000000:8A:00.0 Off |                    0 |
| N/A   31C    P0             47W /  250W |     343MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   1448586      C   python                                      28810MiB |
|    0   N/A  N/A   1456481      C   ray::RayWorkerVllm                            306MiB |
|    0   N/A  N/A   1458504      C   ray::RayWorkerVllm                            306MiB |
|    0   N/A  N/A   1460718      C   ray::RayWorkerVllm                            306MiB |
|    1   N/A  N/A   1448586      C   python                                        306MiB |
|    1   N/A  N/A   1456481      C   ray::RayWorkerVllm                          28652MiB |
|    1   N/A  N/A   1458504      C   ray::RayWorkerVllm                            306MiB |
|    1   N/A  N/A   1460718      C   ray::RayWorkerVllm                            306MiB |
|    2   N/A  N/A   1448586      C   python                                        306MiB |
|    2   N/A  N/A   1456481      C   ray::RayWorkerVllm                            306MiB |
|    2   N/A  N/A   1458504      C   ray::RayWorkerVllm                          28760MiB |
|    2   N/A  N/A   1460718      C   ray::RayWorkerVllm                            306MiB |
|    3   N/A  N/A   1448586      C   python                                        306MiB |
|    3   N/A  N/A   1456481      C   ray::RayWorkerVllm                            306MiB |
|    3   N/A  N/A   1458504      C   ray::RayWorkerVllm                            306MiB |
|    3   N/A  N/A   1460718      C   ray::RayWorkerVllm                          28700MiB |
+-----------------------------------------------------------------------------------------+

GPU blocks: 11762, CPU blocks: 2048
Throughput: 8.17 requests/s, 3932.45 tokens/s

After this PR (tp=4):

nvidia-smi
Thu Apr 11 15:52:19 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla V100-SXM2-32GB-LS        On  |   00000000:06:00.0 Off |                    0 |
| N/A   36C    P0             87W /  250W |   29717MiB /  32768MiB |     72%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-32GB-LS        On  |   00000000:07:00.0 Off |                    0 |
| N/A   38C    P0             87W /  250W |   29559MiB /  32768MiB |     66%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  Tesla V100-SXM2-32GB-LS        On  |   00000000:0A:00.0 Off |                    0 |
| N/A   37C    P0             85W /  250W |   29667MiB /  32768MiB |     61%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  Tesla V100-SXM2-32GB-LS        On  |   00000000:0B:00.0 Off |                    0 |
| N/A   35C    P0             79W /  250W |   29607MiB /  32768MiB |     64%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  Tesla V100-SXM2-32GB-LS        On  |   00000000:85:00.0 Off |                    0 |
| N/A   28C    P0             40W /  250W |       0MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  Tesla V100-SXM2-32GB-LS        On  |   00000000:86:00.0 Off |                    0 |
| N/A   29C    P0             41W /  250W |       0MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  Tesla V100-SXM2-32GB-LS        On  |   00000000:89:00.0 Off |                    0 |
| N/A   30C    P0             41W /  250W |       0MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   7  Tesla V100-SXM2-32GB-LS        On  |   00000000:8A:00.0 Off |                    0 |
| N/A   29C    P0             42W /  250W |       0MiB /  32768MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   1491851      C   python                                      29706MiB |
|    1   N/A  N/A   1496915      C   ray::RayWorkerVllm                          29548MiB |
|    2   N/A  N/A   1497165      C   ray::RayWorkerVllm                          29656MiB |
|    3   N/A  N/A   1497328      C   ray::RayWorkerVllm                          29596MiB |
+-----------------------------------------------------------------------------------------+

GPU blocks: 12222, CPU blocks: 2048
Throughput: 8.21 requests/s, 3953.50 tokens/s

Conclusion

Fought really hard to save 306 MiB * tp * tp of memory.

rkooo567 (Collaborator) left a comment

QQ: why do we write the cache to a file? Is it so that all the workers can access the cache info?

In that case, what happens if there are 2 vllm instances running at the same time? (Wouldn't the cache files overwrite each other, or is it just checking GPU-to-GPU accessibility, so that they can share the same result?)

youkaichao (Member, Author)

> why do we write the cache to a file? Is it so that all the workers can access the cache info?

If we don't cache the result in a file, every master process still needs to initialize a CUDA context on every GPU, which leads to a tp * 306 MiB memory cost. On a single machine, the p2p accessibility of a fixed set of GPUs should not change, so we can safely cache it in a file; afterwards, even the master process does not need to pay the cost.

> what happens if there are 2 vllm instances running at the same time?

Note that the cache file name is suffixed with CUDA_VISIBLE_DEVICES, so two vllm instances will not conflict.
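
For illustration only (the exact file name and directory in vLLM may differ), keying the cache path by the visible devices could look like this:

```python
import os

# Two instances launched with different CUDA_VISIBLE_DEVICES resolve to
# different cache files, so their results never overwrite each other.
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "all")
cache_path = os.path.expanduser(
    f"~/.config/vllm/gpu_p2p_access_cache_for_{visible_devices}.json"
)
```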

WoosukKwon (Collaborator)

QQ: Can we directly use cudaIpcOpenMemHandle instead?

youkaichao (Member, Author)

> QQ: Can we directly use cudaIpcOpenMemHandle instead?

We should seek help from @hanzhi713; I'm not familiar with this :(

cadedaniel (Collaborator) commented Apr 12, 2024

> QQ: Can we directly use cudaIpcOpenMemHandle instead?

> We should seek help from @hanzhi713; I'm not familiar with this :(

We can use cudaDeviceCanAccessPeer, accessible with cupy: https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.runtime.deviceCanAccessPeer.html
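
A minimal sketch of that query with cupy (assuming at least two visible GPUs; note the caveat about driver reliability raised later in the thread):

```python
import cupy

# Ask the CUDA runtime whether device 0 can access device 1 peer-to-peer.
# Returns 1 if access is possible, 0 otherwise.
can_access = cupy.cuda.runtime.deviceCanAccessPeer(0, 1)
print(bool(can_access))
```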

rkooo567 (Collaborator)

Makes sense! It would be great to have a brief comment in the function explaining the motivation for writing the result to a file!

youkaichao (Member, Author)

Technically, many functions can be used to detect p2p access, but which of them allocate additional CUDA contexts is extremely vague. cudaIpcOpenMemHandle or cudaDeviceCanAccessPeer might work, but it is hard to say whether they would still create tp * tp CUDA contexts.

That said, I think caching the result is a universal solution, regardless of how we detect p2p access.

youkaichao (Member, Author)

Can anyone stamp my PR? Or is any additional modification required?

esmeetu (Collaborator) commented Apr 12, 2024

  1. Multi-node seems not to work?
  2. Should we delete this cache after the p2p test is done?

rkooo567 (Collaborator)

Multi-node is not working with vllm anyway, so I'm not sure if we should handle it in this PR.

youkaichao (Member, Author)

> Multi-node is not working with vllm anyway

This is not correct; vllm does support multi-node.

@esmeetu I added handling for the multi-node case by letting one process per node create the cache (i.e. local_rank == 0). Could you please test whether this works in a multi-node setting?

hanzhi713 (Contributor)

Technically, either cudaDeviceCanAccessPeer or cudaIpcOpenMemHandle would suffice. cudaDeviceCanAccessPeer is what torch.cuda.can_device_access_peer calls. However, the CUDA driver is sometimes buggy and may report p2p as supported even though it is not; this can occur on 3090s and 4090s. Thus, we need to perform actual p2p copies and check that the result is correct to work around the driver bug.
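
A rough sketch of such a copy-and-verify check, assuming PyTorch and at least two visible GPUs; whether this copy actually exercises p2p depends on the driver and setup, and the real vLLM check may enable peer access explicitly and handle more corner cases.

```python
import torch


def p2p_copy_works(src: int, dst: int) -> bool:
    """Copy a tensor from GPU `src` to GPU `dst` and verify the data,
    instead of trusting the driver's capability report alone."""
    a = torch.randn(1024, device=f"cuda:{src}")
    b = a.to(f"cuda:{dst}")  # cross-device copy (p2p when the driver enables it)
    torch.cuda.synchronize()
    return torch.equal(a.cpu(), b.cpu())
```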

youkaichao (Member, Author)

I can confirm that cudaDeviceCanAccessPeer costs about tp * tp * 300 MB of memory. I'm not sure whether cudaIpcOpenMemHandle behaves the same way.

Either way, the caching in this PR is reasonable: the p2p access pattern between GPUs seldom changes.

esmeetu (Collaborator) commented Apr 14, 2024

> Multi-node is not working with vllm anyway

> This is not correct; vllm does support multi-node.

> @esmeetu I added handling for the multi-node case by letting one process per node create the cache (i.e. local_rank == 0). Could you please test whether this works in a multi-node setting?

Multi-node was not supported when using custom all-reduce, so this PR looks good to me.

youkaichao merged commit 2cd6b4f into vllm-project:main on Apr 14, 2024 (45 checks passed).
youkaichao deleted the cache_p2p branch on Apr 14, 2024 at 06:40.

Successfully merging this pull request may close the following issue:

[Bug]: tp>1 every worker takes memory in every GPU after upgrade to 0.4.0