[RFC]: Interface and Abstraction for Distributed Inference Environment #3587
Comments
Thanks for the RFC! I have several questions.
It is possible, and it won't change the interface in this RFC. The burden would go to
I think it should go to the API of a distributed worker, an abstraction that comes with the launcher. A default implementation of `DistributedWorker` could look like this:

```python
class DistributedWorker:
    def __init__(self, args):
        self.coor = Coordinator()            # initialize the Coordinator, done by vllm
        self.comm = Communicator(self.coor)  # hardware vendors can use `coor` to initialize their communicator

    def init_model(self, args):
        pass

    def run_model(self, args):
        pass
```
The emphasis of this RFC is to disentangle control-plane communication (Coordinator) and data-plane communication (Communicator). Both of them can have their own default implementations. Communicators typically need control-plane communication to set up state before large chunks of data communication, e.g. in the design of nccl, where a unique token has to be shared across processes before data communication can start.
This is awesome!
In this case, do we decouple the communicator implementation from the coordinator's broadcast implementation? For example, let's say we have a new communicator implementation called Gccl. I am worried it is a leaky abstraction if the coordinator uses the communicator's broadcast implementation under the hood. In that case, it may make more sense for the coordinator to accept the communicator as an input?

```python
class DistributedWorker:
    def __init__(self, args):
        self.comm = Communicator()          # hardware vendors initialize their communicator
        self.coor = Coordinator(self.comm)  # the Coordinator is built on top of it, done by vllm

class Coordinator:
    def broadcast(self):
        self.comm.broadcast(...)
```
Yes, they are decoupled. The reference implementation of the Coordinator does not depend on any Communicator. The dependency chain is: the Coordinator is created first, and a Communicator can then use the Coordinator for its control-plane setup.
Gotcha, that makes sense! Thanks for the clarification!
QQ: How does the custom allreduce backend fit into this abstraction?
I'm not familiar with the custom allreduce kernel. At first glance, I feel it can be another implementation of the Communicator.
Progress tracker:
Summary of all the distributed inference cases we need to support: a cartesian product of (single vs. multiple vllm engines) × (single vs. multiple nodes).

Precondition: each vllm engine will require at least one GPU. GPUs cannot be shared across vllm engines.

Here are all the possible combinations:

Single vllm Engine, Single Node

Users can use CUDA_VISIBLE_DEVICES to control which GPUs the engine uses.

Single vllm Engine, Multiple Nodes

Currently we only support a ray-managed cluster. Only one node needs to execute the launch command; ray will then use all nodes in the cluster. In the future we may need to support a torchrun/mpi-style launcher, as described in #3902 (comment).

Multiple vllm Engines, Single Node

Users can use CUDA_VISIBLE_DEVICES to give each engine a disjoint set of GPUs, e.g. one set for the first engine and another set for the second engine (see the sketch below).

Multiple vllm Engines, Multiple Nodes

In this case, each vllm engine should have its own nodes; no overlap is allowed. We don't need to do anything for this case.

Conclusion

Ideally, we should respect CUDA_VISIBLE_DEVICES.
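As a concrete illustration of the single-node, multiple-engine case, here is a minimal sketch (the model name and GPU indices are arbitrary examples, not from the original comment):

```python
# Hypothetical example: two independent vllm engines on one node, each pinned
# to its own GPUs via CUDA_VISIBLE_DEVICES (set before CUDA is initialized).
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # first engine: GPUs 0 and 1
# a second engine, running in its own process, would instead set "2,3"

from vllm import LLM  # import after setting the environment variable

llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
```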
I think some companies are building their own XXCCL solution to
While this is separate from the vllm engine itself, I am thinking maybe it's worth extending into this area, or using an existing library to simplify the work if there is any.
I think the only use case here is the single-node, multiple-GPU case. In this case, could different threads just set the device in this way?
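A minimal sketch of what "set the device" per worker could look like, assuming each worker knows its local_rank (the helper name is made up):

```python
# Sketch (assumption): each worker on a node binds itself to one GPU
# based on its local_rank, so GPUs are not shared across workers.
import torch

def setup_device(local_rank: int) -> torch.device:
    torch.cuda.set_device(local_rank)          # pin this worker to one GPU
    return torch.device(f"cuda:{local_rank}")
```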
I didn't get this part.
I am wondering whether the motivation is strong enough. It seems the two pain points are: 1.
From the architecture level, splitting collective communication usage into a control plane and a data plane totally makes sense. I am just wondering whether it's possible to reuse one process group, which would be cleaner, while still addressing the issues you listed.
First of all, thank you for your interest and feedback!
A process group is just a way to organize processes; it does not create new processes. For tensor parallel of size
That could be the case, but the progress can be really slow, e.g. we would have to wait at least several months for a pytorch release. Meanwhile, writing wrappers in vllm is much faster to implement; we can do it in days.
Yeah, the communication thread is cheap. I was thinking about debuggability, ports, etc. (we have a limited number of ports that can be exposed in our environment). But anyway, considering that only two groups are needed here, I think my concern was unnecessary. I am buying into the idea.
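For reference, a sketch of the two-process-group setup discussed here, assuming torch.distributed is launched with the usual RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT environment variables (the cpu_group name is made up):

```python
import torch.distributed as dist

dist.init_process_group(backend="nccl")     # data plane: GPU collectives
cpu_group = dist.new_group(backend="gloo")  # control plane: cheap CPU-side coordination

# Large tensors go through the default (nccl) group; small control messages
# and synchronization can use the gloo group, e.g.:
# dist.barrier(group=cpu_group)
```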
Finished in #5293
This RFC describes a proposal for interfaces and abstractions for distributed inference environments. I plan to solicit discussions for a week (until March 31st) before I begin to actually refactor the code.
Motivation
The current distributed inference environment in vllm is quite tangled, and we often see deadlocks and hangs (see #3455, #2770, #3559, to name a few). The problem becomes prominent when we try to upgrade to pytorch 2.2.0 (see #3442), because pytorch 2.2.0 upgrades from nccl==2.18.1 to 2.19.3 (see https://pypi.org/pypi/torch/2.1.2/json and https://pypi.org/pypi/torch/2.2.0/json to compare the dependencies), and nccl==2.19.3 breaks vllm due to increased memory cost during cudagraph capture (from 10 MB per graph to 100 MB per graph, which adds up to several GBs because we have dozens of cudagraphs).

TL;DR: distributed inference in the current codebase is a headache. If it works, hooray; if not, don't be surprised.
Proposal
Abstraction
I think we should have three levels of abstraction:

- Launcher, responsible for launching the processes. The current implementation is ray, but we could also have other choices like Python's native multiprocessing in single-node cases. See [Core] Multiprocessing executor for single-node multi-GPU deployment #3466 for example.
- Coordinator, responsible for control-plane coordination among processes. The current practice is ad-hoc: use filelock to lock on filesystems ( [Bugfix] use SoftLockFile instead of LockFile #3578 ), use TCP to initialize communication in cupy ( Use CuPy for CUDA graphs #2811 ), use MPI to initialize communication in AMD's cupy version ( [ROCm] enable cupy in order to enable cudagraph mode for AMD GPUs #3123 ).
- Communicator, responsible for data-plane, cross-device communication of large tensors. The current implementation is nccl, and AMD also has its own communication library. Note that this is vendor-specific, and vendors usually have their own way of cross-device communication.

The messiest one, and the missing one, is the Coordinator abstraction level. More on this later.
Interface
Between each pair of consecutive abstractions lies an interface.
Interface between Launcher and Coordinator
After Launcher launches processes, it needs to at least tell the processes the following information:
- launch_id, used to distinguish the current launch from possibly concurrent launches (e.g. when 4 users want to set up 4 inference engines on the same node, each with 2 GPUs). Note: the launch_id can be used as a "random seed" to draw values for master_port, instead of keeping only one default master_port value and having to kill all processes after the last run crashes. A reference implementation would hash the launch_id to a port number and increase the port number until it finds the first free port; this is the strategy taken by Jupyter Notebook/Lab Server (see the sketch after this list).
- world_size, the number of processes participating in the current launch (may span multiple nodes)
- local_world_size, the number of processes participating in the current launch on the current node (not necessarily the same across nodes)
- rank, ranging from 0 (inclusive) to world_size (exclusive), unique for each process in the launch
- local_rank, ranging from 0 (inclusive) to local_world_size (exclusive), unique within each node; can be used to assign devices in a node!
- master_addr, the IP address of the master node, which should be reachable from all nodes
- master_port, a free port on the master node, reserved for possible coordination

How does the Launcher pass this information to each process? Basically we have two choices:
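Going back to the launch_id note above, here is a minimal sketch of deriving master_port from launch_id (the function name and port range are arbitrary choices, not part of the proposal):

```python
import hashlib
import socket

def port_from_launch_id(launch_id: str, low: int = 20000, high: int = 60000) -> int:
    """Hash launch_id into [low, high), then probe upward until a free port is found."""
    start = low + int(hashlib.sha256(launch_id.encode()).hexdigest(), 16) % (high - low)
    for offset in range(high - low):
        port = low + (start - low + offset) % (high - low)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("", port))  # binding succeeds only if the port is free
                return port
            except OSError:
                continue            # port already taken, try the next one
    raise RuntimeError("no free port found in the given range")
```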
Interface between Coordinator and Communicator
Device communicators (e.g. nccl) often need to initialize the communication by sharing some unique token (see the nccl documentation). In addition, processes sometimes need to coordinate resources within a node or across the cluster.

In light of the above considerations, the Coordinator should have at least the following interfaces:

- is_master(): tells whether the current process is the master process, i.e. a convenient wrapper for the boilerplate check rank == 0
- is_local_master(): tells whether the current process is a local master process, i.e. a convenient wrapper for the boilerplate check local_rank == 0
- broadcast(bytes, src): broadcasts some message (in the form of bytes) from rank src to all processes. The semantics are standard, no need for more explanation.
- barrier(): blocks until all processes reach this point. Also a standard communication primitive.

Note: more often than not, we want to execute something in just one process per node (e.g. creating directories, downloading files to the node). Inspired by this thread, we can write code like this:
Furthermore, there are more complicated requirements like "only one process in each node does something, but this something is different across nodes", essentially the requirement of local_barrier(), a function that blocks until all processes in the current node reach this point. It is debatable whether we want this (currently I don't see any requirement like this in vllm).

Communicator interface
The following communicator functionality is suggested (mostly taken from the nccl design):

- allreduce(char* input, size_t count, size_t dtype, size_t op). More functionality would be better (e.g. out-of-place allreduce, broadcast/reduce/scatter, etc.), but in-place allreduce is all we need currently.

The intended usage would be something like this:
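A self-contained sketch of that flow, with placeholder classes standing in for the real implementations (all names and signatures here are assumptions, not the final API):

```python
import os

class Coordinator:
    """Control plane: rank bookkeeping plus broadcasting small byte messages."""
    def __init__(self):
        self.rank = int(os.environ.get("RANK", "0"))

    def is_master(self) -> bool:
        return self.rank == 0

    def broadcast(self, msg: bytes, src: int = 0) -> bytes:
        # a real implementation would use e.g. torch.distributed with the gloo backend
        return msg

class Communicator:
    """Data plane: device collectives backed by a vendor library such as nccl."""
    def __init__(self, token: bytes):
        self.token = token  # unique setup token shared via the Coordinator

    def allreduce(self, buffer, count, dtype, op):
        pass  # would call the vendor collective, e.g. an in-place allreduce

coor = Coordinator()
token = b"unique-setup-token" if coor.is_master() else b""
token = coor.broadcast(token, src=0)  # control plane: share the communicator's setup state
comm = Communicator(token)            # data plane: ready for allreduce on device buffers
```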
A reference implementation of Coordinator
A reference implementation of the Coordinator can be torch.distributed with the gloo backend, which is designed to communicate CPU tensors.

Other considerations include MPI and a custom-implemented TCP store. However, since we live in the torch framework, torch.distributed is a natural choice without any new dependency.

Note: torch.distributed can also be used as a fully functional communicator for GPU devices. However, torch.distributed.all_reduce is way more complicated than just an allreduce operation. It might initialize the autograd engine, might keep track of gradients, and might dispatch to different device kernels. Even if we are in torch.inference_mode, its c10 engine might perform some additional operations that break functionality like cudagraph. Therefore, I prefer to call vendor-provided communication libraries directly to bypass the problem. After all, we just want an allreduce operation on dense tensors, without any hustle and bustle.
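For concreteness, a sketch of such a gloo-backed Coordinator (class and method names are assumptions; only the broadcast-of-bytes idea is from the proposal):

```python
import torch
import torch.distributed as dist

class TorchDistributedCoordinator:
    """Control-plane coordinator built on torch.distributed with the gloo backend."""

    def __init__(self):
        # relies on the launcher having set RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT
        dist.init_process_group(backend="gloo")
        self.rank = dist.get_rank()

    def is_master(self) -> bool:
        return self.rank == 0

    def broadcast(self, msg: bytes, src: int = 0) -> bytes:
        # send the length first so non-source ranks can allocate a buffer of the right size
        length = torch.tensor([len(msg)], dtype=torch.long)
        dist.broadcast(length, src=src)
        if self.rank == src:
            buf = torch.tensor(list(msg), dtype=torch.uint8)
        else:
            buf = torch.empty(int(length.item()), dtype=torch.uint8)
        dist.broadcast(buf, src=src)
        return bytes(buf.tolist())

    def barrier(self):
        dist.barrier()
```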
Benefits

After we have the above abstractions and interfaces, we can have the following benefits: to support vllm, hardware vendors only need to provide support for torch.Tensor (only forward computation ops are enough) and a C library (an .so file would be enough) for calling communication ops on raw data (i.e. char* in C). And if they want to move quickly, just one allreduce op would be enough for inference. No need to wait for the whole functionality to be completed within pytorch.

Things not to be considered
We don't aim for a fully-fledged distributed execution environment. And since inference tasks are almost stateless, we don't need to consider elasticity and fault tolerance. As opposed to training, we don't need to save checkpoints, and we don't need to resume from a previous failure ...