
Fix num_gpus when TP > 1 #1852

Merged: 3 commits merged into main from fix-tp-mem on Dec 3, 2023
Conversation

WoosukKwon (Collaborator)

Fixes #1851

WoosukKwon (Collaborator, Author)

@Yard1 Should we also fix this line?

engine_class = ray.remote(num_gpus=1)(self._engine_class).remote

WoosukKwon requested a review from Yard1 on November 30, 2023 at 07:15
Yard1 (Collaborator) commented Nov 30, 2023

Yes, that line in AsyncEngine should be set in the same way. @WoosukKwon

WoosukKwon (Collaborator, Author)

I see. It seems tricky then...

Yard1 (Collaborator) commented Nov 30, 2023

Perhaps make it a method of ParallelConfig to get the number of GPUs for Ray?

WoosukKwon (Collaborator, Author)

> Perhaps make it a method of ParallelConfig to get the number of GPUs for Ray?

@Yard1 Can we somehow keep ParallelConfig independent of Ray?

Yard1 (Collaborator) commented Dec 1, 2023

Have an independent function taking in the parallel config, then?
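
For illustration, a minimal sketch of the kind of free-standing helper being discussed here; the name and signature are assumptions, not the code merged in this PR:

# Hypothetical helper kept outside ParallelConfig so that the config
# class stays independent of Ray, as discussed above.
def get_num_gpus_for_ray(parallel_config, cache_config) -> float:
    # With tensor parallelism, each Ray worker must own a whole GPU so
    # that NCCL sees distinct devices; otherwise a fraction is enough.
    if parallel_config.tensor_parallel_size > 1:
        return 1.0
    return cache_config.gpu_memory_utilization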

FlorianJoncour (Contributor) commented Dec 2, 2023

In my opinion this is a regression compared to #1821.

After further investigation, it seems that Ray makes CUDA fail with tensor parallelism if the total amount of GPU reserved is less than 1 (which indeed makes no sense).

Should we not raise an exception in that case?

# Ensure that we reserve at least 1 GPU when using tensor parallelism.
if self.parallel_config.tensor_parallel_size > 1:
    gpu_res = (self.cache_config.gpu_memory_utilization
               * self.parallel_config.tensor_parallel_size)
    if gpu_res < (self.parallel_config.tensor_parallel_size - 1):
        raise ValueError(
            f"`tensor_parallel_size` is set to "
            f"{self.parallel_config.tensor_parallel_size} but the current "
            f"placement group bundles reserve less than "
            f"{self.parallel_config.tensor_parallel_size} GPUs ({gpu_res}). "
            f"Try increasing `gpu_memory_utilization` to at least 0.5. "
            f"If you are using Ray Serve, you may also need to set a "
            f"correct `placement_group_bundles`.")

And keep

num_gpus=self.cache_config.gpu_memory_utilization

so it helps the user set a correct configuration and still allows using fractions of a GPU with tensor parallelism.
I've tested with and without Ray Serve; we just have to set a correct gpu_memory_utilization, and also placement_group_bundles with Ray Serve.

Yard1 (Collaborator) commented Dec 3, 2023

So the problem here is that NCCL requires separate devices to operate, but Ray has no insight into that and will try to pack the placement group into as few GPUs as possible. If the number of GPUs Ray chooses is smaller than the tensor parallelism factor, NCCL will not work. There is no easy way to prevent that aside from just not using fractional GPUs if tensor parallelism factor is greater than 1.
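
For illustration of the packing behavior described above, a standalone Ray sketch (not vLLM code):

import ray
from ray.util.placement_group import placement_group

ray.init()
# Two bundles of 0.5 GPU each. With the default "PACK" strategy, Ray may
# satisfy both bundles on the same physical GPU, so two tensor-parallel
# workers placed in this group could end up sharing one device, and NCCL
# initialization would fail.
pg = placement_group([{"GPU": 0.5}, {"GPU": 0.5}])
ray.get(pg.ready())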

WoosukKwon (Collaborator, Author) commented Dec 3, 2023

@Yard1 I've fixed async_llm_engine. While the fix is a bit hacky, I think this is acceptable since engine_use_ray is usually False. PTAL.
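
A rough sketch of the kind of guard this implies for the engine actor, illustrative only; the actual change is in commit 464dd98:

# Illustrative sketch, not the merged diff: pick num_gpus for the Ray
# engine actor based on the parallel configuration.
if parallel_config.tensor_parallel_size == 1:
    # A single worker can share a GPU, so a fractional reservation works.
    num_gpus = cache_config.gpu_memory_utilization
else:
    # With TP > 1, reserve a whole GPU so NCCL sees distinct devices.
    num_gpus = 1
engine_class = ray.remote(num_gpus=num_gpus)(self._engine_class).remote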

This was referenced Dec 3, 2023
WoosukKwon merged commit 464dd98 into main on Dec 3, 2023
2 checks passed
WoosukKwon deleted the fix-tp-mem branch on December 3, 2023 at 20:24
xjpang pushed a commit to xjpang/vllm that referenced this pull request Dec 4, 2023
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024