Tensor Parallelism vs Data Parallelism #367

Closed
xcxhy opened this issue Jul 5, 2023 · 14 comments

xcxhy commented Jul 5, 2023

Hi, thanks! I used vLLM to run inference with the llama-7B model on a single GPU, and with tensor parallelism on 2 and 4 GPUs. We found that vLLM is about 10 times faster than HF on a single GPU, but with tensor parallelism there is no significant increase in token throughput. We understand that with data parallelism the memory can be expanded and the batch of processed samples can be increased, but that the communication between the graphics cards may reduce speed. Still, with 2 GPUs there should be a speedup of around 1.5x, yet the throughput has remained basically unchanged. Is it because our GPU KV cache usage is full, or are there other reasons? Looking forward to your reply!

xcxhy (Author) commented Jul 7, 2023

I would also like to know: is the tensor parallelism you use 1D, 2D, or 3D?

PythonNut commented Jul 10, 2023

As an example of this, I ran the following script with different values of WORLD_SIZE:

import time

import torch
from vllm import LLM, SamplingParams

WORLD_SIZE = 1      # tensor_parallel_size; I varied this between runs
BATCH_SIZE = 2048   # large enough that generate() runs many scheduling iterations

llm = LLM(
    model="lmsys/vicuna-7b-v1.3",
    tokenizer="hf-internal-testing/llama-tokenizer",
    tensor_parallel_size=WORLD_SIZE,
    gpu_memory_utilization=0.85,
)

start = time.perf_counter()

# Random prompts of 120 token ids each (llama vocabulary size is 32000).
batch = torch.randint(32000, (BATCH_SIZE, 120))

out = llm.generate(
    prompt_token_ids=[tokens.tolist() for tokens in batch],
    use_tqdm=False,
    sampling_params=SamplingParams(
        max_tokens=20,
        ignore_eos=True,
    ),
)

print(time.perf_counter() - start)

My results were:

  • 1x A100 40G: 33.2 seconds
  • 2x A100 40G: 28.3 seconds
  • 4x A100 40G: 25.6 seconds
  • 8x A100 40G: 27.8 seconds
  • 16x A100 40G: 36.3 seconds

So there is very little gain from parallelism (and too many GPUs actually make things run slower!).
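
For a rough sense of scale, those wall-clock times can be converted into generation throughput. The sketch below assumes all 2048 × 20 output tokens are generated and ignores prompt processing:

# Back-of-the-envelope throughput for the runs above:
# 2048 prompts x 20 generated tokens = 40,960 output tokens per run.
TOKENS = 2048 * 20
for gpus, seconds in [(1, 33.2), (2, 28.3), (4, 25.6), (8, 27.8), (16, 36.3)]:
    print(f"{gpus:>2}x A100: {TOKENS / seconds:6.0f} tok/s ({33.2 / seconds:.2f}x vs 1 GPU)")

Even the best case (4 GPUs) is only about a 1.3x throughput gain over a single GPU.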

irasin (Contributor) commented Jul 11, 2023

> I would also like to know: is the tensor parallelism you use 1D, 2D, or 3D?

I think vLLM only uses 1D tensor parallelism.
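
For context, 1D tensor parallelism here means Megatron-style sharding: each pair of linear layers is split along one dimension (the first column-wise, the second row-wise), and an all-reduce sums the partial results after each such block. A minimal single-process sketch of just the matrix math, with list elements standing in for GPU ranks and a plain Python sum standing in for the all-reduce (a real implementation would use torch.distributed; the elementwise activation between the two layers is omitted for brevity):

import torch

torch.manual_seed(0)
tp = 2                           # number of tensor-parallel "ranks"
x = torch.randn(4, 8)            # activations: [batch, hidden]
w1 = torch.randn(8, 16)          # first linear layer, sharded column-wise
w2 = torch.randn(16, 8)          # second linear layer, sharded row-wise

w1_shards = w1.chunk(tp, dim=1)  # each rank holds [hidden, 16 // tp]
w2_shards = w2.chunk(tp, dim=0)  # each rank holds [16 // tp, hidden]

# Each rank computes its partial output independently...
partials = [(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]

# ...and the all-reduce sums the partials into the full result.
y_tp = sum(partials)
y_ref = (x @ w1) @ w2
print(torch.allclose(y_tp, y_ref, atol=1e-4))  # True

That per-layer all-reduce is the communication cost being discussed in this thread.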

xcxhy (Author) commented Jul 11, 2023

@PythonNut I think your max_tokens is set too small, so you hit the compute bottleneck of the A100. vLLM batches requests dynamically during inference: if you hit a memory bottleneck, using two cards can improve speed because the compute is not yet fully utilized; but if you hit a compute bottleneck, it slows down because of the communication between the cards. This is my guess from my own testing and analysis.
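
One way to check which regime you are in is to sample SM utilization while generate is running; sustained ~100% utilization suggests a compute bottleneck, while the GPU KV cache usage figure mentioned above is the more direct signal for a memory bottleneck. A small sketch with pynvml (assuming it is installed; watching nvidia-smi in a second terminal works just as well):

import time
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

# Sample GPU 0 once per second while the benchmark runs in another process.
for _ in range(30):
    util = pynvml.nvmlDeviceGetUtilizationRates(handle)
    print(f"SM utilization: {util.gpu:3d}%")
    time.sleep(1)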

xcxhy (Author) commented Jul 11, 2023

@irasin I think so; the communication between layers is what causes the slowdown.

PythonNut commented

@xcxhy

Right, I guess I'm showing a case where tensor parallelism is not the best form of parallelism. In particular, I am measuring throughput, not latency; I specifically chose a batch size that is too big to fit in GPU memory, so there are many iterations of inference happening inside generate. Since I am very happy with the throughput on a single GPU, it seems like data parallelism would scale better in this situation than tensor parallelism.

irasin (Contributor) commented Jul 12, 2023

> @xcxhy
>
> Right, I guess I'm showing a case where tensor parallelism is not the best form of parallelism. In particular, I am measuring throughput, not latency; I specifically chose a batch size that is too big to fit in GPU memory, so there are many iterations of inference happening inside generate. Since I am very happy with the throughput on a single GPU, it seems like data parallelism would scale better in this situation than tensor parallelism.

Hi @PythonNut, I wonder whether you have tested pipeline parallelism vs tensor parallelism in any other frameworks?

xcxhy (Author) commented Jul 12, 2023

@PythonNut I agree with you. In my test data each text is very long, so a single GPU's memory cannot hold a large enough batch; in that case, using tensor parallelism does give some speedup. You could test this: assuming one GPU can only accommodate a batch of 10, set the batch size to 10 on one GPU and to 20 on two GPUs. That should verify whether tensor parallelism can provide acceleration. (A sketch of that experiment follows below.)
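
A minimal sketch of that comparison, reusing the vicuna-7B model from the script earlier in the thread; the prompts are placeholders. Run it once with TENSOR_PARALLEL=1, BATCH=10 and once with TENSOR_PARALLEL=2, BATCH=20, then compare prompts per second:

import time
from vllm import LLM, SamplingParams

TENSOR_PARALLEL = 1   # set to 2 for the second run
BATCH = 10            # set to 20 for the second run

llm = LLM(model="lmsys/vicuna-7b-v1.3", tensor_parallel_size=TENSOR_PARALLEL)

# Placeholder prompts; substitute your own long texts to reproduce the
# "a single GPU cannot hold a big enough batch" situation described above.
prompts = ["Tell me a very long story about large language models."] * BATCH
params = SamplingParams(max_tokens=512, ignore_eos=True)

start = time.perf_counter()
llm.generate(prompts, sampling_params=params)
elapsed = time.perf_counter() - start
print(f"TP={TENSOR_PARALLEL}, batch={BATCH}: {BATCH / elapsed:.2f} prompts/s")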

zhuohan123 (Member) commented

@xcxhy This is actually due to Python overhead, which bottlenecks tensor parallel performance for smaller models. We are thinking about implementing the models in C++ to overcome this issue. You can also track the progress in #42

Peter-Devine commented Nov 1, 2023

Just dropping this as a simple example of how to do data parallel inference in Python that I've found to be effective.

Obviously I'd appreciate the eventual implementation of proper data parallel processing in the actual package, but this works decently as a stop-gap just now.

import os
import multiprocessing

from vllm import LLM, SamplingParams

NUM_GPUS = 4

# Each worker pins itself to one GPU, loads its own copy of the model,
# and generates completions for its share of the prompts.
def run_inference_one_gpu(gpu_id, prompt_list, model_name, sampling_params):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    llm = LLM(model=model_name)
    return llm.generate(prompt_list, sampling_params)

# Splits a list into roughly equally sized pieces
# split_list(["a", "b", "c", "d", "e", "f", "g"], 3) -> [['a', 'b'], ['c', 'd'], ['e', 'f', 'g']]
def split_list(l, n):
    return [l[i * len(l) // n:(i + 1) * len(l) // n] for i in range(n)]

def run_inference_multi_gpu(model_name, prompts, sampling_params):
    split_prompts = split_list(prompts, NUM_GPUS)
    inputs = [(i, p, model_name, sampling_params) for i, p in enumerate(split_prompts)]

    # One worker process per GPU; starmap preserves the order of the chunks.
    with multiprocessing.Pool(processes=NUM_GPUS) as pool:
        results = pool.starmap(run_inference_one_gpu, inputs)

    # Flatten the per-GPU results back into a single list.
    outputs = []
    for result in results:
        outputs.extend(result)

    return outputs

if __name__ == "__main__":
    prompts = [f"Write me a story about why {i} is your favourite number.\n\n{i} is my favourite number because " for i in range(100)]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
    model_name = "mistralai/Mistral-7B-v0.1"
    outputs = run_inference_multi_gpu(model_name, prompts, sampling_params)
    print(outputs)

This requires the whole model to fit on one GPU (as with the usual data parallel setup) and will doubtless have a higher RAM overhead (I haven't checked, but it shouldn't be massive, depending on your text size). However, it does seem to run at roughly N times the speed of running on one GPU (where N = number of GPUs), compared to <N times for the tensor parallel implementation.
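
If the model does not fit on a single GPU, the same pattern could in principle be extended by giving each worker a group of GPUs and running tensor parallelism inside the group. This is an untested, hypothetical variation of run_inference_one_gpu above:

# Hypothetical variation: each worker receives a comma-separated GPU group,
# e.g. "0,1" and "2,3", and shards the model across that group.
def run_inference_one_group(gpu_group, prompt_list, model_name, sampling_params):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_group
    llm = LLM(model=model_name, tensor_parallel_size=len(gpu_group.split(",")))
    return llm.generate(prompt_list, sampling_params)

# With 4 GPUs and a model that needs 2 of them, use two data-parallel workers:
# inputs = [("0,1", prompts_a, model_name, sampling_params),
#           ("2,3", prompts_b, model_name, sampling_params)]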

I hope this is useful to someone!

appliedml42 commented

@xcxhy

I think I can confirm your intuition with data. I have 4 A6000 cards with much lower I/O bandwidth than A100s, so this example is large enough to be I/O-bound rather than compute-bound. Same example with vLLM 0.2.2 on the NVIDIA PyTorch 23.10 container:

1 GPU: 47.28260640997905 secs
2 GPU: 38.49804261501413 secs
4 GPU: 35.09966081101447 secs
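
(For reference, that works out to roughly a 1.23x speedup on 2 GPUs and 1.35x on 4 GPUs relative to 1 GPU, again well below linear scaling.)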

Best,
Abhishek

SunLemuria commented

I think fastchat supports this feature: fastchat scalability

yhyu13 commented Dec 20, 2023

> I think fastchat supports this feature: fastchat scalability

How can this be done natively in the vLLM project rather than through fastchat?

hmellor (Collaborator) commented Mar 8, 2024

Closing this in favour of #689

hmellor closed this as completed Mar 8, 2024