
ggml-backend: backend-agnostic tensor parallelism #13776

Draft · wants to merge 65 commits into master

Conversation

JohannesGaessler (Collaborator)
I'm currently working on support for backend-agnostic tensor parallelism. I've progressed to the point where I have a working prototype (though it only works for 2 GPUs and has bad performance). I'm making this PR in order to get early feedback on the way I would implement it; input from @slaren in particular would be appreciated. Specifically, I would:

  1. Add a backend-agnostic interface for split buffers to ggml-backend.cpp, e.g. to check whether a buffer is split, which backends are associated with it if so, and to retrieve the effective tensor for a given backend. I think this can be done without any backend-specific code. The input would be multiple backend buffers; when allocating a tensor on the split buffer, this would be translated to allocating slices of the tensor on the underlying backend buffers. A sketch of such an interface follows this list.
  2. Refactor the code for ggml_backend_sched to revolve more around splits instead of the nodes from the original graph. Without tensor parallelism there is effectively no change because the splits contain all nodes from the original graph in sequential order, so iterating over splits should produce the same results as iterating over nodes. A sketch of the split-wise execution loop follows this list.
  3. When using tensor parallelism, split the graph at additional points and duplicate the splits in such a way that some operations can run in parallel across multiple backends. The existing code for pipeline parallelism can be re-used to handle the scheduling, data transfer, and synchronization. To combine the results from multiple backends, the current solution is to copy the partial results from the other backends and then use GGML_CONCAT to combine them into a tensor that contains the correct data. For this I extended the functionality of ggml_backend_sched_split::inputs: tensors with GGML_OP_NONE use the existing code to retrieve data from other backends; tensors with other ops are executed prior to the actual nodes from the split. A sketch of the combination step follows this list.
  4. Extend the logic for split tensors to cover not just splits along dimension 1 but splits along the other dimensions as well as mirrored data. Not just weights but also nodes will be splittable. Define a function similar to the _supports_op functions that determines the split state of a tensor after some op, given the states of the inputs (a sketch follows this list). If an op cannot be meaningfully executed in parallel, synchronize the nodes as a fallback. This should ensure that correct results can always be produced, though with bad performance if the correct transformation logic is not defined. For the attention I think the graph should be split along dimension 2; for the FFN part I think it should be dimension 1 -> dimension 0 -> mirrored. In total there would need to be 4 synchronizations per layer.
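For point 1, here is a minimal sketch of what the split-buffer queries could look like. All of these names (ggml_backend_split_buffer_context, ggml_backend_buffer_is_split, ggml_backend_split_buffer_get_tensor) are hypothetical and not part of the current ggml-backend API:

```cpp
#include "ggml-backend.h"

#include <unordered_map>
#include <vector>

// Hypothetical sketch only: possible internal state and queries for a
// backend-agnostic split buffer in ggml-backend.cpp.
struct ggml_backend_split_buffer_context {
    // one underlying buffer per participating backend
    std::vector<ggml_backend_buffer_t> buffers;
    // per logical tensor, the slice allocated on each underlying buffer
    std::unordered_map<const ggml_tensor *, std::vector<ggml_tensor *>> slices;
};

// whether the buffer distributes its tensors across multiple backends
static bool ggml_backend_buffer_is_split(const ggml_backend_split_buffer_context * ctx) {
    return ctx->buffers.size() > 1;
}

// the effective tensor (slice) of a logical tensor for backend i
static ggml_tensor * ggml_backend_split_buffer_get_tensor(
        const ggml_backend_split_buffer_context * ctx, const ggml_tensor * tensor, int i) {
    return ctx->slices.at(tensor)[i];
}
```

For point 2, a sketch of what iterating over splits instead of nodes could look like, loosely following the existing ggml_backend_sched internals (the exact fields are simplified stand-ins):

```cpp
// Sketch only: executing the graph split by split. Without tensor
// parallelism the splits cover all graph nodes in sequential order, so this
// loop is equivalent to iterating over the nodes directly.
static void sched_compute_over_splits(ggml_backend_sched_t sched) {
    for (int i = 0; i < sched->n_splits; i++) {
        ggml_backend_sched_split * split = &sched->splits[i];
        ggml_backend_t backend = sched->backends[split->backend_id];
        // (input copies and synchronization omitted; handled by the
        // existing pipeline-parallelism code)
        ggml_backend_graph_compute_async(backend, &split->graph);
    }
}
```

For point 3, the combination step at the graph level. The helper combine_partials and the choice of split dimension are illustrative only, but ggml_concat is the real ggml op behind GGML_CONCAT:

```cpp
#include "ggml.h"

// Sketch: combining two partial results into the full tensor. partial1 is
// assumed to have already been copied to the backend that holds partial0;
// ggml_concat then assembles the full result along the split dimension.
static struct ggml_tensor * combine_partials(
        struct ggml_context * ctx,
        struct ggml_tensor  * partial0,   // slice computed locally
        struct ggml_tensor  * partial1,   // slice copied from the other backend
        int                   split_dim) { // dimension the tensor was split along
    return ggml_concat(ctx, partial0, partial1, split_dim);
}
```

For point 4, a sketch of the propagation function, analogous to the _supports_op functions. The enum, the names, and the MUL_MAT rule shown are all hypothetical:

```cpp
#include "ggml.h"

// Hypothetical sketch of the split-state propagation described in point 4.
enum tp_state {
    TP_STATE_MIRRORED, // every backend holds a full copy of the data
    TP_STATE_SPLIT_0,  // split along dimension 0
    TP_STATE_SPLIT_1,  // split along dimension 1
    TP_STATE_SPLIT_2,  // split along dimension 2
    TP_STATE_SYNC,     // cannot run in parallel -> synchronize as fallback
};

// Given an op and the split states of its inputs, return the split state of
// the result, analogous to the existing _supports_op functions.
static enum tp_state tp_propagate(enum ggml_op op, enum tp_state src0, enum tp_state src1) {
    switch (op) {
        case GGML_OP_MUL_MAT:
            // e.g. weights split along dim 1 x mirrored activations
            // -> result split along dim 0 (illustrative rule)
            if (src0 == TP_STATE_SPLIT_1 && src1 == TP_STATE_MIRRORED) {
                return TP_STATE_SPLIT_0;
            }
            return TP_STATE_SYNC;
        default:
            // no transformation logic defined: fall back to a
            // synchronization so that results are always correct
            return TP_STATE_SYNC;
    }
}
```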

Going forward, since ggml_backend_sched is a critical component, I would first make a separate PR to refactor it slightly so that it's easier to assert that no changes are being made for use without tensor parallelism. The approach I have in this PR is to first split the graph and create a vector of sequential splits, splits_no_tp, in which splits that need tensor parallelism are marked. Then, in a second pass, a vector splits_tp is created in which the tensor-parallel splits are duplicated; a sketch of this pass follows below. Only after this are inputs assigned. Finally, the vector splits_tp is copied to ggml_backend_sched::splits. So in effect I have split the 5th pass over the graph nodes into 2 passes, with the duplication of the tensor-parallel splits in between. I used vectors because they made the implementation the easiest, but it should be possible to do the same thing with one more allocation like ggml_backend_sched::splits that grows dynamically when needed. I assume the reason a vector is not used in the current code for ggml_backend_sched::splits is to ensure that the memory is never reallocated when repeatedly changing the number of splits.
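A sketch of that second pass; the sched_split type and its needs_tp field are simplified stand-ins for the real ggml_backend_sched internals:

```cpp
#include <vector>

// Sketch only: duplicating the marked splits so that their node ranges can
// be executed in parallel across n_tp_backends backends.
struct sched_split {
    int  backend_id; // backend the split is assigned to
    int  i_start;    // first graph node covered by the split
    int  i_end;      // one past the last graph node
    bool needs_tp;   // marked during the first pass
};

static std::vector<sched_split> duplicate_tp_splits(
        const std::vector<sched_split> & splits_no_tp, int n_tp_backends) {
    std::vector<sched_split> splits_tp;
    for (const sched_split & split : splits_no_tp) {
        if (!split.needs_tp) {
            splits_tp.push_back(split); // sequential split, kept as-is
            continue;
        }
        // one copy of the split per participating backend
        for (int b = 0; b < n_tp_backends; b++) {
            sched_split copy = split;
            copy.backend_id = b;
            splits_tp.push_back(copy);
        }
    }
    return splits_tp; // inputs are assigned only after this pass
}
```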

For the main PR, the goal would be to get an implementation that is at least as fast as the current CUDA code for --split-mode row but that does not need code specific to the CUDA backend. That would then make it possible to remove ggml_cuda_op_mul_mat without loss of functionality.

@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on May 25, 2025