llama : support Jamba hybrid Transformer-Mamba models #7531
base: master
Conversation
This will be necessary to support Jamba (and other recurrent models mixed with Attention). Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.
* llama : begin work on support for variable GQA. This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads.
* llama : gracefully fail when not finding a hybrid slot.
llama.cpp (Outdated)
```c
switch (hparams.n_layer) {
    // TODO: Jamba layers are a bit heterogenous, so naming this is hard.
    case 12: // 900M 8x???M
    case 32: // 51B 16x?B
    default: model.type = e_model::MODEL_UNKNOWN;
```
I'm not sure what model size type(s) I should give to Jamba models.
Great job! Works for me too, it's very fast. There were some warnings during compilation, but nothing major.
Amazing work!
ggml.c (Outdated)
```diff
 if (n_rs > 1) {
     // multiple sequences means it's hard to know when it's the first time a state is read,
     // so copy them all over to the destination, just to be sure.
-    for (int i3 = 0; i3 < n_kv; ++i3) {
+    for (int i3 = 0; i3 < n_rs; ++i3) {
```
I'm looking at adding the missing Metal kernels for SSM_CONV and SSM_SCAN. I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation.

Also, I still haven't understood the details of the computation, but if we find a way to express these ops via existing ops all together (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.
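For what it's worth, a minimal sketch of that kind of extraction, assuming a context ctx, a graph gf, and state tensors s_prev/s_dst provided by the surrounding graph-building code (illustrative only, not the actual PR code):

```c
// Hedged sketch: performing the src0 -> dst state copy with existing ops,
// outside of the SSM kernel itself. `ctx`, `gf`, `s_prev` and `s_dst` are
// assumed to come from the surrounding graph-building code.
#include "ggml.h"

static void copy_state_with_existing_ops(
        struct ggml_context * ctx,
        struct ggml_cgraph  * gf,
        struct ggml_tensor  * s_prev, // previous recurrent state
        struct ggml_tensor  * s_dst)  // destination buffer for the state
{
    // ggml_cpy creates a node that writes s_prev into s_dst when the graph runs
    ggml_build_forward_expand(gf, ggml_cpy(ctx, s_prev, s_dst));
}
```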
I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation
Yes, this is definitely possible. I'll find a way to extract the copies outside.
if we find a way to express these ops via existing ops all together (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.
For SSM_SCAN, I think there's a way to fully express it in terms of other ops, though it will use much more memory because of the big intermediate tensors, and new operators like SOFT_PLUS and EXP would be needed instead. But different lengths of simultaneous sequences might make a custom operator still necessary. I'll think about ways to make it simpler, especially since other recurrent architectures (like RWKV) will also need to work on multiple sequences per batch.

For simplifying SSM_CONV, I don't think ggml_conv supports working on independent 1D rolling windows with varying sequence lengths.

When working on a single sequence, though, it's quite simple to do the equivalent of ggml_ssm_conv with a self-overlapping view, as I did in my original implementation, which I described in more detail in #5328 (comment):
Lines 6973 to 6982 in 64fbce0:
```c
// prepare convolution for all tokens in the batch with a self-overlapping view,
// shifting by one column each ... depth? ... with a window of d_conv columns.
// {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, 1*ggml_element_size(conv_x), 0);

// perform convolution
// => {1, d_inner, n_tok}
x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
// => {d_inner, n_tok, 1}
x = ggml_permute(ctx0, x, 2, 0, 1, 3);
```
Setting nb[2] to the element size makes the view self-overlapping.

But this would create too many nodes in the compute graph when done with multiple sequences (unless they're always all the same length, in which case the 4th dimension could be used), so a custom operator is necessary.
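As an aside, a tiny stride-only illustration of that self-overlapping layout in plain C (independent of ggml; the sizes below are made up):

```c
// Minimal stride-only illustration (not ggml code) of the self-overlapping view:
// a buffer of (d_conv - 1) + n_tok elements is viewed as n_tok windows of
// d_conv elements each, where consecutive windows advance by one element,
// i.e. the "window" stride equals the element size.
#include <stdio.h>

int main(void) {
    enum { d_conv = 4, n_tok = 3 };
    float buf[(d_conv - 1) + n_tok] = {0, 1, 2, 3, 4, 5};

    for (int t = 0; t < n_tok; ++t) {
        const float * window = buf + t; // overlapping: windows share d_conv - 1 elements
        printf("window %d:", t);
        for (int i = 0; i < d_conv; ++i) {
            printf(" %g", window[i]);
        }
        printf("\n");
    }
    return 0;
}
```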
One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch.

The main goal would be to simplify the SSM operators, and potentially express them as other existing ops if possible. But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has its own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention. The main purpose of supporting this mode would be to achieve reproducible results during parallel decoding (currently, decoding the same sequence in parallel can yield slightly different results due to the unified KV cache).
Just throwing some thoughts that I have so far - will continue looking at the PR in the next days
Edit: I was writing this comment before I saw you posted - will take a look tomorrow
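A toy sketch of the same-sequence chunking idea above, with made-up token and sequence ids (illustrative only, not the llama.cpp batch format):

```c
// Hedged sketch: group the token indices of a batch by sequence id, so that
// each group could be run through the SSM ops (or the attention) on its own
// and the results concatenated back afterwards.
#include <stdio.h>

int main(void) {
    // seq_id of each token in an example batch of 8 tokens
    const int seq_id[8] = {0, 0, 1, 0, 2, 1, 0, 2};
    const int n_tokens  = 8;
    const int n_seqs    = 3;

    for (int s = 0; s < n_seqs; ++s) {
        printf("chunk for seq %d:", s);
        for (int i = 0; i < n_tokens; ++i) {
            if (seq_id[i] == s) {
                printf(" %d", i); // token index that would go into this chunk
            }
        }
        printf("\n");
    }
    return 0;
}
```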
One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch
Yes, this would be doable, but would make the number of compute graph nodes scale with the number of sequences. (EDIT: if it's split when making ubatches, then the number of compute graph nodes can stay constant)
Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.
The recurrent steps are simpler for ubatches with sequence lengths of 1, but prompt processing performance would be much slower than with a per-recurrent-architecture operator for longer sequences. Still thinking about ways to generalize this while keeping good performance.
But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has it's own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention.
For the transformer KV cache, if there's logic to make all sequences within a ubatch have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.
I also think there's a way to keep the unified KV cache (one buffer) and chunk it to make each sequence have its own independent contiguous reserved cells. Batching sequences together might still be possible though, if the KQ mask gets another dimension (the number of sequences in the ubatch, and the number of new tokens per sequence instead of the batch size) so that these equal-sized "chunks" get processed independently in parallel. But this might not work (because the newly-calculated KV cells have to be copied in a bunch of not-regularly-spaced places), unless... unless maybe with some kind of ggml_set_rows? Not sure about the transposed V cache, though.
A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).
if it's split when making ubatches, then the number of compute graph nodes can stay constant
No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance
Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.
For the transformer KV cache, if there's logic to make all sequences within a ubatch to have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.
Not sure how that would work. Adding dummy tokens sounds too much overhead (at least in the case of the regular transformer). Any other ideas?
A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).
From a broad PoV, if we have an implementation that works with a single-sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using separate cache for each sequence. I didn't consider the number of nodes until you noted - so that might be a problem indeed.
I'm currently working on a big refactor of how Mamba (and Jamba) works to make all sequences of a sub-batch be of the same length (initially only for models with recurrent states), and to make recurrent state slots contiguous, with the goal of simplifying the SSM operations (and removing GGML_OP_SSM_CONV), so that GPU support will be much easier to implement after that.
Looking forward to this!
No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance
It will sacrifice some performance, but only in the cases where a batch contains an unequal number of tokens for each affected sequence. So this should not affect large prompt processing or parallel text generation, if both are not done in the same batch.
Not sure how that would work. Adding dummy tokens sounds too much overhead (at least in the case of the regular transformer). Any other ideas?
This is not about adding dummy tokens, but about making the number of new tokens in each ubatch the same per sequence. I think the overhead will be minimal, though there is still some.
Let me illustrate.
Let's say there's a batch with new tokens for 4 sequences of length 16, 7, 1, 1, respectively.
0: ################
1: #######
2: #
3: #
Splitting that into equal-length sequences would make 3 ubatches, like so:
0: #
1: #
2: #
3: #
0: ######
1: ######
0: #########
Each of these shapes is nice and rectangular, which is good for recurrent architectures because their operations can be more easily batched across sequences this way.
But I'm not yet sure if it would also benefit Transformers, which is why I'm thinking of initially only enabling the equal-length splitting for recurrent (or hybrid) model architectures.
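A small sketch of this greedy equal-length splitting, reusing the example lengths above (illustrative only, not the actual ubatch-building code):

```c
// Hedged sketch: each ubatch takes the same number of new tokens from every
// sequence that still has tokens left, namely the minimum remaining length
// among them. With lengths {16, 7, 1, 1} this yields the 3 ubatches above.
#include <stdio.h>

int main(void) {
    int remaining[] = {16, 7, 1, 1}; // new tokens per sequence in the batch
    const int n_seqs = 4;

    int left = n_seqs; // sequences that still have new tokens
    while (left > 0) {
        // find the minimum remaining length among non-empty sequences
        int min_len = 0;
        for (int s = 0; s < n_seqs; ++s) {
            if (remaining[s] > 0 && (min_len == 0 || remaining[s] < min_len)) {
                min_len = remaining[s];
            }
        }
        // this ubatch takes min_len tokens from every non-empty sequence
        printf("ubatch: %d token(s) from each of %d sequence(s)\n", min_len, left);
        for (int s = 0; s < n_seqs; ++s) {
            if (remaining[s] > 0) {
                remaining[s] -= min_len;
                if (remaining[s] == 0) {
                    left--;
                }
            }
        }
    }
    return 0;
}
```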
From a broad PoV, if we have an implementation that works with a single-sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using separate cache for each sequence. I didn't consider the number of nodes until you noted - so that might be a problem indeed.
Doing this with a constant number of graph nodes is pretty much what using same-length sequences (as illustrated above) allows, because the split into same-sequence tokens can then simply become another tensor dimension.
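For instance, a minimal ggml sketch of what that extra tensor dimension could look like, assuming a contiguous activation cur of shape {d_model, n_seq_tokens * n_seqs} (names and shapes are illustrative, not the actual PR code):

```c
// Hedged sketch: when every sequence in the ubatch contributes the same number
// of new tokens, the per-sequence split can become a reshape instead of a loop
// over sequences. `ctx`, `cur`, `n_seq_tokens` and `n_seqs` are assumed to come
// from the surrounding graph-building code; `cur` must be contiguous.
#include "ggml.h"

static struct ggml_tensor * split_equal_sequences(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,          // {d_model, n_seq_tokens * n_seqs}
        int64_t               n_seq_tokens,
        int64_t               n_seqs) {
    // {d_model, n_seq_tokens * n_seqs} -> {d_model, n_seq_tokens, n_seqs}
    return ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs);
}
```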
Aha, got it. Good idea. I'm also not sure if this can help Transformers, but it's something to think about 👍
The change is quite big and I'm having a bit of trouble merging it all at once. Wonder if we should take a more step-by-step approach. The […]
I agree that this is quite big. Sorry about that. I'll see what I can do.
Unfortunately, the […]

I think I might be able to separate some parts of this PR. These are the main separable parts: […]
Could we extend this point a bit more and add support for OpenELM together with it? The PR for OpenELM is almost ready, but has some quick hacks that seem relevant to this: #7359
Now that variable GQA support is in (for some context, this allows splitting batches as described in #7531 (comment), and also single-sequence ubatches, as well as the current simple split used on […]
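As a toy illustration of what variable GQA enables for hybrid models, with a hypothetical per-layer KV head layout (not taken from a real Jamba config):

```c
// Hedged illustration (not actual llama.cpp code): with per-layer KV head
// counts, layers with 0 KV heads can be treated as recurrent (Mamba) layers
// and the rest as attention layers.
#include <stdio.h>

int main(void) {
    // hypothetical head_count_kv array for a small hybrid model
    const int n_head_kv[8] = {0, 0, 0, 8, 0, 0, 0, 8};

    for (int il = 0; il < 8; ++il) {
        if (n_head_kv[il] == 0) {
            printf("layer %d: recurrent (no KV cache rows)\n", il);
        } else {
            printf("layer %d: attention with %d KV heads\n", il, n_head_kv[il]);
        }
    }
    return 0;
}
```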
Any updates on this since Jamba 1.5 is now out?
Basically, since #8526 was merged, now I need to resolve a very big merge conflict because I didn't keep the code identical. This will probably take a few days.
Some progress update on Jamba: I began resolving the merge conflicts, and there were at least 2000+ lines of conflicts (basically half of this PR). This is manageable. While I've solved most of them, the result is not usable (and it doesn't build, so I did not push it here yet (sorry); I will push once it works) because of the state saving and restoring code which was changed in #8699, and this doesn't yet handle two caches.

My problem right now is with the single-sequence session restoring, which uses […] (EDIT: on further thought, […]).

To make single-sequence session restores simpler, I could either: […]
The least bad option (EDIT: apart from simply using […]) […]

But I'm starting to think that maybe state checkpoints add too much complexity. The current implementation uses a unified pool of recurrent state cells to allocate checkpoints and/or current states for each […].

Some alternatives: […]
Manual checkpoint management seems tempting, but would offload the complexity to […]

Meanwhile I will attempt to refactor KV cache defragmentation soon (which should be useful anyway).
Regarding the manual checkpoint management - recently, the commonly used APIs in the cloud (e.g. Anthropic, OpenAI, etc.) introduced "prompt caching" [0], which adds a "cache control" parameter to the requests. It can be used to cache prompts, but I guess it fits well with the idea of manual recurrent state checkpointing from the user code.

I'm thinking that the changes for Jamba should be kept to a minimum for now, even if this would require longer processing times for common use cases. The reason is that the architecture is not yet well adopted, so increasing the complexity of the codebase to support it is not very justified. The better approach would be to improve the support for the existing transformer and mamba arches, by refactoring the KV cache and state management implementation and adding unit tests.

I suppose a large part of the complexity with Jamba comes from the fact that we are trying to fit the logic into the existing KV cache implementation, which is not well-fit for this architecture.

[0] - https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
This also slightly reduces the diff from the master branch
For the first time state saving and reloading works for Jamba (both for the whole state and single-sequences). 🎉 This is implemented in fcb889c
Agreed. I'll start simplifying the code and will think about how to best approach manual/explicit checkpoints for a future PR. The implicit checkpoints implemented here are a bit over-engineered, and do not fit in the idea of a "minimal" change.
I did not find the existing KV cache implementation to be particularly limiting. Most of the complexity in the Jamba implementation here comes from the allocation of recurrent states and implicit checkpoints. The only complexity necessary for Jamba is that both the KV cache and the recurrent state cache should be kept in sync, and even then most of the complexity is in keeping the metadata of the tree of sequences consistent (some of which is only there to allow fairly allocating the cache between seq_ids).

My plan for this PR in the next days/weeks: […]

What will be left intact is: […]
As before, hybrid models of different architectures should be able to work on top of that (like how RWKV-v6 and Mamba can share the same recurrent state management code), as long as it's about hybrids between Attention and some recurrent block. This will mean models like Zamba (Mamba + Attention), Zamba2 (Mamba-2 + Attention), RecurrentGemma (RG-LRU + Attention), and others should be easier to implement without worrying about the KV cache API too much, and they will benefit from future improvements in state checkpoint management. (Note that mixing different recurrent architectures in the same model is out of scope, but I don't think this will be a problem)
- added support for MiniCPM3, RWKVv6, OLMoE, IBM Granite, and Jamba (conversion only: ggerganov/llama.cpp#7531) - update gguf library from upstream
How's this going?
This adds support for Jamba (fixes #6372). (https://arxiv.org/abs/2403.19887)
To complement llama_kv_cache, I propose to add llama_rs_cache, as well as a top-level llama_past to more easily manage both at once.

The current implementation of recurrent states (initially written for Mamba) re-uses the tensors allocated for the KV cache to store its recurrent states. Obviously, when both Attention and recurrent states are used at the same time, this previous approach does not work.

Note that since this uses some of the same operators as Mamba, this is CPU-only for now. (see #6758)

API changes

Most of the changes are backward-compatible, but the llama_kv_cache_seq_rm and llama_kv_cache_seq_cp functions have been renamed and now return the token position of the next token after the end of the sequence(s) they affect. This is necessary to properly handle recurrent state checkpoints with models that also use the KV cache (like Jamba (and eventually Griffin)), in case the last valid state doesn't line up with the requested removal when using, for example, llama_past_seq_rm.

- Renamed the llama_kv_cache_* functions to llama_past_*.
  - Removing _kv_ from the names could make them less confusing when working with pure or mixed recurrent models.
  - But llama_past_* might be a bit too different from the previous name (e.g. llama_kv_cache_seq_rm). It would also be confusing to figure out at a glance which functions are specific to the KV cache and which are specific to the recurrent state cache.
- llama_past_seq_rm and llama_past_seq_cp now return n_past, which is the number of previous tokens in the affected sequence (or it can also be interpreted as the next token position at the end of the sequence). A usage sketch is included at the end of this description.
  - (at least when -1 is passed to both p0 and p1)
- llama_past_seq_max_pos returns -1 when there are no cells matching the specified seq_id, to allow calculating n_past by adding one to its result.
  - llama_kv_cache_seq_max_pos previously returned 0 in this case, which makes it indistinguishable from when there's a single cell with pos 0.

New features

- Jamba support in llama.cpp
- State checkpoints for recurrent models
  - useful when n_parallel is at least 3 or 4 times the number of actual users
  - useful in the server example when trimming the stop string
- {model}.attention.head_count_kv can now also be an array of int32_t, one value per layer
  - layers with 0 kv heads are considered recurrent layers (Mamba, in the case of Jamba)

Internal changes

- Added llama_rs_cache, a ring-buffered tree of recurrent states
- Added llama_cache, which contains both llama_kv_cache and llama_rs_cache
- Changes to the SSM operations (ggml_ssm_* operators)
- llama_past_seq_cp doesn't use more RS cells the more sequences there are
- Added llama_ubatch for more metadata about sequences
- Outputs are reordered for llama_get_logits to match the old expected output. This is not a problem with llama_get_logits_ith, because there was already an indirection with lctx.output_ids which is reused.

TODO

- Settle on the naming (currently llama_past_*). Anybody has better name suggestions? Candidates considered so far:
  - llama_cache_*
  - llama_past_*
  - llama_kv_cache_*
  - llama_ctx_cache_*
  - llama_llm_cache_*
  - llama_seq_cache_*
  - llama_tok_cache_*
  - llama_comp_cache_*
  - llama_past_cache_*
  - llama_work_cache_*
  - llama_kvrs_cache_*
  - llama_causal_cache_*
  - llama_context_cache_*
- Use the returned n_past from the llama_past_* functions used in the various examples (server, main, speculative, lookup, lookahead)
- Add tests (tests/test-llama-past.cpp)

Future ideas

- Allow setting --parallel to a big value while not unnecessarily limiting the context size of the clients of the server if there aren't many. (related to Parallelization / Batching Explanation #4130 (reply in thread))

Testing

- convert-hf-to-gguf.py
- main
- server with backtracking
- parallel
- bge-small gives exactly the same embeddings as on master, and now it doesn't unnecessarily allocate the KV cache!

Example output of jamba-900M-v0.13-KIx2: […]
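As referenced in the API changes above, here is a minimal usage sketch of the proposed llama_past_seq_rm return value. The signature is assumed to mirror the existing llama_kv_cache_seq_rm (context, seq_id, p0, p1); this is illustrative only, not the final API:

```c
// Hedged sketch of the proposed API from this PR (not in master): the renamed
// functions are assumed to return n_past, the number of previous tokens left
// in the affected sequence, so callers no longer need to track it separately.
#include "llama.h"

// Hypothetical helper: roll a sequence back to keep only its first n_keep tokens.
static llama_pos rollback_sequence(struct llama_context * ctx, llama_seq_id seq_id, llama_pos n_keep) {
    // remove everything from position n_keep onwards; -1 means "up to the end"
    llama_pos n_past = llama_past_seq_rm(ctx, seq_id, n_keep, -1);
    // n_past is the next token position at the end of the sequence;
    // with state checkpoints it may be smaller than n_keep if the last
    // valid recurrent state doesn't line up with the requested removal.
    return n_past;
}
```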