
[RFC]: A Flexible Architecture for Distributed Inference #5775

Closed
youkaichao opened this issue Jun 24, 2024 · 7 comments

youkaichao (Member) commented Jun 24, 2024

Motivation.

The current vLLM architecture for distributed inference is not flexible enough. We had a difficult time adding speculative decoding with a different tensor parallel size (see #5414). Much the same problem appears when users want vLLM processes to collaborate with additional processes, e.g. when RLHF frameworks want to sync weights with vLLM processes (see #5723). This RFC tries to improve the distributed inference architecture so that it is flexible enough to support more possibilities.

what's the difference between distributed training and distributed inference?

Distributed training is a well-studied area with many optimized communication primitives that vLLM already uses, such as allreduce. Distributed training usually happens at large scale and follows the SPMD (single program, multiple data) style: all processes run the same code. Datasets are sharded before training, and iteration steps, batch size, model size and architecture are all known to every process. As a result, distributed training is essentially a for-loop that all processes virtually execute in lockstep:

for i, (batch, label) in enumerate(my_shard(dataset, batch_size, drop_last=True)):
    output = model(batch)          # collective communication inside
    loss = loss_fn(output, label)
    loss.backward()                # collective communication inside
    model.update_weight()          # collective communication inside

Put simply, distributed training deals with static data, which either exists in the code as hyper-parameters (e.g. batch size) or exists in the dataset and is sharded across workers before training.

Distributed training trades flexibility for scalability. For example, it is common for distributed training to use a drop_last dataloader, which drops the final remainder batch of arbitrary size. This is because a batch of, say, size 7 cannot be sharded evenly across a data parallel size of, say, 16. All workers execute the same code, and they expect the same amount of work.

Things are different for distributed inference: we need to deal with dynamic data. We either handle data from web requests, or appear as an engine object living in the user's process, handling data the user may send at any time. In either case, this is an RPC problem: there is a driver process that receives and processes data, sends it to workers to kick off the inference job, and passes the output back to users.

the ideal architecture for distributed inference

Ideally, the architecture of distributed inference should look like this:

[figure: a server process dispatching requests to independent model process groups]

There would be a server process dispatching requests to different models. Each model would occupy its own process group and would not interfere with the other models. If the models need to interact, we can stitch the communication together via the server process, or discuss how they can interact without the server process's intervention.

In this architecture, what would the model A process group do while the model B process group is executing requests? What would the model A and model B process groups do when the server process is idle (no requests)? After all, from the operating system's perspective, a process must always be executing something. The answer is that they should sit in a loop, waiting for commands. That is where RPC (remote procedure call) comes into play: we want the server process to execute functions (procedures) inside the model processes. To achieve this goal:

  • the server process must establish some connection with the model processes
  • the model processes sit in a loop, waiting to receive commands from the server process
  • for each RPC call, the server process serializes the arguments and sends them together with the command to the model process; the model process deserializes the arguments, executes the function, serializes the output, and sends it back; the server process then deserializes the output (a minimal sketch follows this list)
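
As a minimal sketch of this round trip (illustrative only, using a multiprocessing Pipe as the connection; none of the names below are vLLM APIs):

from multiprocessing import Pipe, Process

def execute_model(batch):
    # placeholder for real model execution
    return [token + 1 for token in batch]

HANDLERS = {"execute_model": execute_model}

def worker_main(conn):
    # the waiting loop: block until the server sends a command
    while True:
        func_name, args = conn.recv()        # the Pipe handles (de)serialization via pickle
        result = HANDLERS[func_name](*args)  # look up and execute the requested function
        conn.send(result)                    # serialize and return the output

if __name__ == "__main__":
    # server side: establish the connection, then issue an RPC call
    parent, child = Pipe()
    Process(target=worker_main, args=(child,), daemon=True).start()
    parent.send(("execute_model", ([1, 2, 3],)))  # send the command and the arguments
    print(parent.recv())                          # receive the output -> [2, 3, 4]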

Put simply, we need to answer the following questions for distributed inference:

  • How do we launch the model process group?
  • What do the workers do when the server is idle?
  • How do they know that requests have arrived?
  • How do they get the arguments (the current requests to run)?
  • How do they return the outputs to the server?

With these questions in mind, let's continue the discussion.

the evolution of vLLM's architecture for distributed inference

ray rpc call, ray argument passing

Prior to #2221, vLLM's architecture for distributed inference was quite similar to the figure above. The critical part of the code is:

[code screenshots: the driver issuing per-worker ray RPC calls]

The problem with this code is that the server process sequentially sends all the arguments (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) to each worker. Because seq_group_metadata_list can be quite large, the cost of launching these RPC calls is very high. What's more, the workers execute tensor parallel programs that involve blocking collective communication such as allreduce, so no worker can proceed until every worker is ready.
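
The pattern was roughly as follows (a hedged reconstruction for illustration, not the actual vLLM code; the argument names are those listed above):

import ray

def execute_model_sequentially(workers, seq_group_metadata_list,
                               blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy):
    # one ray RPC per worker, each call shipping the full (large) arguments
    futures = [
        worker.execute_model.remote(
            seq_group_metadata_list,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )
        for worker in workers  # workers: a list of ray actor handles
    ]
    return ray.get(futures)    # block until every worker has finished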

[figure]

When the tensor parallel size is large, the latency of the RPC calls increases proportionally, even exceeding the model execution time itself. Therefore, this approach is not efficient.

ray rpc call, broadcast for argument passing

To avoid the ray overhead for argument passing, #2221 changed the architecture as follows:

[figure: the server process merged with the rank 0 tensor-parallel worker]

It merges the server process and the rank 0 TP process into the same process. As a result, rank 0 in TP naturally has the arguments. It still uses ray to launch the RPC call on the remaining workers:

[code screenshot]

Non-driver workers receive None as their arguments:

[code screenshot]

The real arguments for non-driver workers are retrieved from a broadcast operation, with the driver worker as the broadcast source.

In this case, ray is only responsible for launching the RPC calls and passes empty arguments; all non-driver workers receive the real arguments from the driver simultaneously, which reduces the latency of the RPC calls as a whole. The benefit is significant: #2221 reports over 50% throughput improvement.
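
A hedged sketch of this driver / non-driver split (using torch.distributed.broadcast_object_list purely for illustration; the actual code path differs):

import torch.distributed as dist

def execute_model(model_runner, execute_model_req=None):
    if dist.get_rank() == 0:
        # driver: merged with the server process, so it already owns the real inputs
        payload = [execute_model_req]
    else:
        # non-driver: invoked with None; the real inputs arrive via the broadcast below
        payload = [None]
    dist.broadcast_object_list(payload, src=0)  # rank 0 (the driver) is the broadcast source
    execute_model_req = payload[0]
    return model_runner(execute_model_req)      # model_runner: caller-supplied callable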

The performance gain comes at the cost of flexibility. Because the rank 0 TP process is merged with the server process, we cannot easily create another model TP process group. Because we use broadcast for argument passing, with a mix of CPU and GPU tensors, we have to manually insert a blocking operation into the model process group, which does not extend to pipeline parallelism.

broadcast for both rpc call and argument passing

When we execute the model, we use ray to launch the RPC calls; although the "arguments" we pass are empty and thus small, the ray RPC calls are still sequential. Then came #4894, which puts workers in a loop around a broadcast. It exploits the fact that if the broadcast source has not issued the broadcast yet, the receivers block and wait for the data. In other words, we somewhat abuse the broadcast semantics, using a single broadcast both to trigger the function call and to pass the arguments.

[code screenshot: workers looping on a broadcast]

Because broadcast is not designed to wait indefinitely, it is prone to timeouts. I suspect it can cause some weird issues (see #5084, not confirmed yet), where users have to send a request every half an hour to keep the server alive.

optimizing broadcast

The broadcast operation was originally designed for passing tensors: the sender and receivers need to know the size of the broadcast in advance. To use it for argument passing, we use a trick: first broadcast a size-1 int tensor that indicates the payload size, and then broadcast a tensor of that size. #4440 discusses it in detail.
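
A minimal sketch of this size-then-data trick, assuming CPU tensors over the gloo backend (illustrative, not the vLLM implementation):

import pickle
import torch
import torch.distributed as dist

def broadcast_object(obj=None, src=0):
    if dist.get_rank() == src:
        payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long)
        dist.broadcast(size, src)     # 1st broadcast: tell every rank the payload size
        dist.broadcast(payload, src)  # 2nd broadcast: the serialized object itself
        return obj
    size = torch.empty(1, dtype=torch.long)
    dist.broadcast(size, src)         # learn the payload size first
    payload = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(payload, src)      # then receive the serialized object
    return pickle.loads(payload.numpy().tobytes())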

To reduce this cost, #5399 added a shared-memory transport, which achieves the same effect of passing Python objects across processes without multiple broadcast calls.

NOTE: currently we use this shared memory transport to pass Python objects, and still use gloo/nccl to pass CPU/GPU tensors.

Proposed Change.

message-queue based async batch rpc and argument passing

Having traced the evolution of vLLM's architecture, it becomes clear that what we want is a good RPC mechanism with:

  • batched RPC calls: the cost of kicking off n RPC calls is the same as kicking off 1
  • batched argument passing: the cost of passing the same arguments to n processes is the same as passing them to 1 process
  • async RPC calls: after we kick off the calls, results are stored somewhere and we can fetch them later

An architecture with these properties does exist. One possible design is a message queue:

[figure: the proposed message-queue architecture, with a single-publisher, multiple-subscriber queue for inputs]

The defining characteristic is that we use a single-publisher, multiple-subscriber message queue for the input path. The publisher (server process) only needs to publish once (enqueue the arguments), and all the subscribers (model processes) can read the message independently. In contrast, with a naive queue the server process would need to enqueue the arguments n times and guarantee that no worker reads a second item before all the other workers have finished reading the first.

Within one machine, the solution is shared memory: the publisher writes data to shared memory, and all worker processes read it directly. This tool is already available in #5399.
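
To make the idea concrete, here is a toy single-publisher / multiple-subscriber channel over shared memory. It ignores the synchronization, ring buffering, and overflow handling that the real transport in #5399 has to deal with; the channel name and buffer size are arbitrary:

import pickle
import struct
import time
from multiprocessing import shared_memory

BUF_SIZE = 1 << 20  # one slot: [8-byte seq][8-byte len][payload]

class Publisher:
    def __init__(self, name="demo_channel"):
        self.shm = shared_memory.SharedMemory(name=name, create=True, size=BUF_SIZE)
        self.seq = 0

    def publish(self, obj):
        payload = pickle.dumps(obj)
        self.shm.buf[16:16 + len(payload)] = payload
        self.seq += 1
        # write the length, then the sequence number last, so readers see a complete message
        self.shm.buf[8:16] = struct.pack("<Q", len(payload))
        self.shm.buf[0:8] = struct.pack("<Q", self.seq)

class Subscriber:
    def __init__(self, name="demo_channel"):
        self.shm = shared_memory.SharedMemory(name=name)
        self.seen = 0

    def recv(self):
        # every subscriber reads the same message independently; the publisher writes only once
        while True:
            seq = struct.unpack("<Q", bytes(self.shm.buf[0:8]))[0]
            if seq > self.seen:
                self.seen = seq
                length = struct.unpack("<Q", bytes(self.shm.buf[8:16]))[0]
                return pickle.loads(bytes(self.shm.buf[16:16 + length]))
            time.sleep(0.0001)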

Across multiple machines, the solution might be to set up a message server: the publisher writes data to the message server (living in another thread), and the server either pushes the data to workers or waits for workers to pull it. In either case, the server process only needs to write the data once. This part is ongoing in #5755.

NOTE: we can also use a mix of both: the server process writes data to shared memory for workers inside its node, and sends data to the message server for workers on the other nodes. In this case, we can kick off the workers on the same node faster, which might help pipeline parallelism (the first pipeline stage can be scheduled on the same node as the server process).
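
For the cross-node path, one possible message-server flavor is ZeroMQ PUB/SUB, where the publisher writes each message once and every connected worker receives its own copy. This is only a sketch under that assumption and may well differ from what #5755 implements:

import pickle
import zmq

def make_publisher(endpoint="tcp://*:5555"):
    sock = zmq.Context.instance().socket(zmq.PUB)
    sock.bind(endpoint)
    return sock

def make_subscriber(endpoint="tcp://head-node:5555"):
    sock = zmq.Context.instance().socket(zmq.SUB)
    sock.connect(endpoint)
    sock.setsockopt(zmq.SUBSCRIBE, b"")  # subscribe to everything
    return sock

# server process: publish the batched RPC payload once
#   pub = make_publisher()
#   pub.send(pickle.dumps(payload))
# worker process: block until the next payload arrives
#   sub = make_subscriber()
#   payload = pickle.loads(sub.recv())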

what about the output?

Most of the discussion in this RFC focuses on how we pass the input; we rarely talk about collecting the output from workers. This is because we want to send inputs to the GPU as quickly as possible to saturate GPU performance. The server process can retrieve the output when it is idle (i.e., has no more input to feed to the GPU). For this part, a simple queue should be enough.

what about functions other than execute_model, e.g. add_lora?

We can turn this into a full-fledged RPC architecture. The payload we pass to workers can follow a scheme like:

{
    "method": "execute_model",
    "args": args,
    "kwargs": kwargs,
}

And the method can be add_lora, too.
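
On the worker side, dispatch for this scheme can be a simple attribute lookup; the handle function and its arguments here are illustrative names, not existing vLLM code:

def handle(worker, msg):
    # msg is the dict shown above; worker is the object exposing the callable methods
    method = getattr(worker, msg["method"])  # e.g. worker.execute_model or worker.add_lora
    return method(*msg["args"], **msg["kwargs"])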

how far are we away from the desired architecture?

The main challenge is that we need to redesign the input sent to worker processes. Currently it contains seq_group_metadata_list, which is quite large. Ideally, we should only send token ids to the workers, and the workers should own their KV cache and block tables. With this, we can even send token ids incrementally, to further reduce the amount of data transferred between the server process and the worker processes.

Essentially, we would move all the heavy parts of the code (scheduler, block manager) into the worker and keep the server process as lightweight as possible.
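
As a rough illustration of the incremental idea (hypothetical names, not an existing vLLM interface), the server would only ship the token ids produced since the last step, and the worker would append them to the state it already owns:

class WorkerState:
    def __init__(self):
        self.token_ids = {}  # request_id -> all token ids seen so far (worker-owned)

    def apply_delta(self, delta):
        # delta: {request_id: [new token ids since the last step]}, sent by the server
        for request_id, new_tokens in delta.items():
            self.token_ids.setdefault(request_id, []).extend(new_tokens)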

Feedback Period.

June 23rd - June 30th; we can extend the discussion if needed.

CC List.

The full cc list will be very long. It involves almost all aspects of vLLM. Below is an incomplete list:

cc @zhuohan123 @njhill who made this evolution possible.
cc @WoosukKwon I think this should unify the role of executors across different hardware backends. After we finish this re-arch, they should only need to touch worker_xxx.py and model_runner_xxx.py.
cc @simon-mo moving the scheduler and block manager into the worker will make it possible to handle complicated memory management without the server process in the critical path.
cc @KuntaiDu this architecture might make prefill/decode disaggregation easier; we can have a prefill worker group and a decode worker group.
cc @LiuXiaoxuanPKU @cadedaniel for spec decode
cc @stephanie-wang @Yard1 @rkooo567 for ray DAG, a closely related acceleration technique.
cc @andoorve for pipeline parallel

cc ...
Feel free to add someone by adding comments.

Any Other Things.

Although I think this architecture is good, we should reach it step by step, i.e., we cannot sacrifice performance or usability during the re-arch. If we do it the right way, I believe we can even improve performance and usability along the way.

zhuohan123 (Member) commented Jun 24, 2024

Thanks for the great writeup! Please find my comments inline below:

A random thought: can we just use redis here? Or would it be too slow because it may introduce an additional data copy?

> what about the output?
>
> Most of the discussion in this RFC focuses on how we pass the input; we rarely talk about collecting the output from workers. This is because we want to send inputs to the GPU as quickly as possible to saturate GPU performance. The server process can retrieve the output when it is idle (i.e., has no more input to feed to the GPU). For this part, a simple queue should be enough.

I'm not so sure about this. LLM inference is an autoregressive process, and the output may be used as input at the next step. What do you think about this?

> what about functions other than execute_model, e.g. add_lora?
>
> We can turn this into a full-fledged RPC architecture. The payload we pass to workers can follow a scheme like:
>
>     {
>         "method": "execute_model",
>         "args": args,
>         "kwargs": kwargs,
>     }
>
> And the method can be add_lora, too.

> how far are we away from the desired architecture?
>
> The main challenge is that we need to redesign the input sent to worker processes. Currently it contains seq_group_metadata_list, which is quite large. Ideally, we should only send token ids to the workers, and the workers should own their KV cache and block tables. With this, we can even send token ids incrementally, to further reduce the amount of data transferred between the server process and the worker processes.
>
> Essentially, we would move all the heavy parts of the code (scheduler, block manager) into the worker and keep the server process as lightweight as possible.

For this part I'm confused. I think our main refactoring goal is to keep the GPU always busy without waiting for CPU execution. If we move all the heavy management logic into the worker, which owns the GPU, won't we face the same problem, where the GPU worker still needs to wait for the scheduler inside the worker to finish before it can run anything on the GPU?

My original thought was to keep the scheduler and block manager in the driver, but make them asynchronous to the GPU worker. In this case, the driver can schedule the next step for the worker while the worker is running the current step on the GPU:

-----------------------------------------------time------------------------------------------------>
Driver | schedule step 1 | schedule step 2 |     | schedule step 3 |    | schedule step 4 | ...
Worker |       idle      |       run step 1      |       run step 2     |       run step 3     | ...

What do you think about the performance and GPU utilization under this RFC?


youkaichao (Member, Author) commented

NOTE: to keep this thread readable, please DM me about edits/typos.

Jeffwan (Contributor) commented Jun 25, 2024

I suggest moving the content to a Google Doc for review and discussion. You could also submit a PR instead.

youkaichao (Member, Author) commented

> I suggest moving the content to a Google Doc for review and discussion.

Thanks, will do.

njhill (Member) commented Jun 26, 2024

> Then came #4894, which puts workers in a loop around a broadcast. It exploits the fact that if the broadcast source has not issued the broadcast yet, the receivers block and wait for the data. In other words, we somewhat abuse the broadcast semantics, using a single broadcast both to trigger the function call and to pass the arguments.
>
> Because broadcast is not designed to wait indefinitely, it is prone to timeouts. I suspect it can cause some weird issues (see #5084, not confirmed yet), where users have to send a request every half an hour to keep the server alive.

Just to clarify here: the workers only remain in the broadcast loop while processing sequences; they do not remain blocked on broadcast when there are no more sequences to process. So we should never wait on the collective ops indefinitely, and I wouldn't consider this use an abuse :)

The referenced issue does seem related (though I have been unable to reproduce it so far), but it must be some kind of bug to be fixed rather than an inherent problem with the approach.

youkaichao (Member, Author) commented

This RFC is intended for discussion rather than immediate action. It seems to have caused some confusion. I will close it for now, convert it to a design doc, and ask for comments there.

Sorry for the confusion.

youkaichao (Member, Author) commented

> Can we just use redis here? Or would it be too slow because it may introduce an additional data copy?

@zhuohan123 I did some early prototyping. Redis is a system-level package and cannot be installed via pip (we can only install a Redis client). Also, I think it might be slower than zmq, which operates at the raw socket level. The fastest implementation of broadcast is still through shared memory, but that only applies within a single node.
