Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : support Mamba Selective State Space Models #5328

Merged
merged 43 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
8cd0a28
mamba : begin working on support for Mamba SSM
compilade Jan 26, 2024
5a69a26
mamba : begin figuring out how to (ab)use the kv cache for Mamba
compilade Jan 27, 2024
f680364
mamba : recurrent inference almost works, but incoherent
compilade Jan 28, 2024
54d3e48
mamba : recurrent inference WORKS!!!
compilade Jan 28, 2024
74eea85
convert : optionally use d_conv and d_state from config.json for Mamba
compilade Jan 29, 2024
9e77061
mamba : refactor recurrent conv, resulting in 20% perf increase
compilade Jan 29, 2024
3f7233b
ggml : parallelize ggml_exp
compilade Jan 29, 2024
e9cc45e
mamba : simplify the conv step with a self-overlapping view
compilade Jan 31, 2024
81b57bb
mamba : fix self-overlapping view depth stride
compilade Jan 31, 2024
ffc116f
mamba : handle batches of more than 1 token
compilade Feb 1, 2024
78a853b
ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation
compilade Feb 2, 2024
5816ae6
mamba : very basic quantization support
compilade Feb 2, 2024
a3f4a1c
mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator
compilade Feb 3, 2024
9f55809
convert : for Mamba, also consider the "MambaLMHeadModel" arch name
compilade Feb 4, 2024
cd0f33f
mamba : fix vocab size problems with official models
compilade Feb 4, 2024
de92f15
ggml : remove ggml_exp and ggml_soft_plus
compilade Feb 4, 2024
766db75
mamba : remove some useless comments
compilade Feb 4, 2024
c52fb3c
convert : fix flake8 linter errors
compilade Feb 5, 2024
6ff34da
mamba : apply suggestions from code review
compilade Feb 5, 2024
8a43ffc
mamba : multiple sequences, but one at a time
compilade Feb 14, 2024
e73eaa7
mamba : in comments, properly refer to KV cells instead of slots
compilade Feb 14, 2024
de50c54
mamba : reduce memory usage of ggml_ssm_scan
compilade Feb 18, 2024
9473ec2
mamba : simultaneous sequence processing
compilade Feb 19, 2024
3dcf798
mamba : support llama_kv_cache_seq_cp copy chains
compilade Feb 25, 2024
34e2fca
mamba : make the server and parallel examples work with whole sequences
compilade Feb 25, 2024
79d636c
mamba : dedicate an input tensor for state copy indices
compilade Feb 25, 2024
8f605cf
mamba : adapt perplexity, batched, and batched-bench examples
compilade Feb 27, 2024
206e8ee
mamba : stop abusing attention metadata
compilade Feb 28, 2024
1af1000
mamba : more correctly update the "used" field of the KV cache
compilade Mar 2, 2024
d52dd50
ggml : in ggml_ssm_scan, use a threshold for soft_plus
compilade Mar 3, 2024
b83fbc9
convert : for Mamba, fallback to internal NeoX tokenizer
compilade Mar 3, 2024
eefb794
mamba : support state saving and restoring
compilade Mar 3, 2024
2a99d1b
ggml : implicitly pass src tensors through dst for Mamba-related ops
compilade Mar 4, 2024
93fd4b8
mamba : clarify some comments
compilade Mar 4, 2024
5544f52
Merge branch 'master' into support-mamba-ssm
compilade Mar 5, 2024
916b586
Merge branch 'master' into support-mamba-ssm
compilade Mar 7, 2024
7cd5a1f
server : fix cache_tokens not getting correctly resized
compilade Mar 7, 2024
d8024a4
convert-hf : support new metadata keys for Mamba
compilade Mar 8, 2024
17e4d6c
mamba : rename metadata to be more similar to transformers library
compilade Mar 8, 2024
1c8ea55
mamba : add missing spaces
compilade Mar 8, 2024
d0d32dc
convert-hf : omit output.weight when identical with token_embd.weight
compilade Mar 8, 2024
3e5685f
readme : add Mamba to supported models, and add recent API changes
compilade Mar 8, 2024
39579d3
mamba : move state_seq and state_mask views outside layer loop
compilade Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
mamba : simplify the conv step with a self-overlapping view
Turns out the conv_state can be made smaller by one column.
Note that this breaks existing GGUFs of Mamba,
because the key_value_length field is tied to the conv_state size.

Convolution with a self-overlapping view is cool!
And it's much simpler than what I initially thought would be necessary
to make the convolution step work with more than 1 token at a time.

Next step is to make the SSM step work on batches of tokens too,
and thus I need to figure out a way to make a parallel selective scan
which will keep the ssm_state small and won't make it bigger
by a factor of (n_layer * batch_size).

* llama : fix Mamba KV self size wrongly displaying as f16 instead of f32

Relatedly, I also tried to see if other types than f32 worked for the states,
but they don't, because of the operators used.
It's probably better anyway to keep lots of precision there,
since the states are small anyway.
  • Loading branch information
compilade committed Mar 3, 2024
commit e9cc45ecae696e8f1fa15b8f355b9c2e1f984f80
6 changes: 4 additions & 2 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,10 +1858,12 @@ def set_gguf_parameters(self):
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(d_inner)
self.gguf_writer.add_head_count(d_inner) # the number of rows in conv_state and ssm_state
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4))
# NOTE: (ab)using the KV cache metadata to store dimensions for conv_state and ssm_state
# Since the first column of the conv_state is shifted out each time, it's not actually needed
self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4) - 1)
self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
self.gguf_writer.add_file_type(self.ftype)

Expand Down
104 changes: 59 additions & 45 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2069,9 +2069,6 @@ static bool llama_kv_cache_init(
if (model.arch == LLM_ARCH_MAMBA) {
// only one slot is needed for Mamba
n_ctx = 1;
// it's probably best to keep as much precision as possible for the states
ktype = GGML_TYPE_F32;
vtype = GGML_TYPE_F32;
}

cache.has_shift = false;
Expand Down Expand Up @@ -4681,7 +4678,7 @@ static bool llm_load_tensors(
} break;
case LLM_ARCH_MAMBA:
{
const int64_t d_conv = hparams.n_embd_head_k;
const int64_t d_conv = hparams.n_embd_head_k + 1;
const int64_t d_state = hparams.n_embd_head_v;
const int64_t d_inner = hparams.n_head;
// FIXME: ceiling instead of floor
Expand Down Expand Up @@ -7915,28 +7912,27 @@ struct llm_build_context {
struct ggml_cgraph * build_mamba() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

const bool use_conv = batch.n_tokens > 1;
GGML_ASSERT(use_conv == false); // TODO: implement
const int32_t n_tok = batch.n_tokens;

// hopefully the compiler does constant folding
const int64_t d_model = n_embd;
const int64_t d_inner = n_head;
GGML_ASSERT(2 * d_model == d_inner);
const int64_t d_conv = n_embd_head_k;
const int64_t d_conv = n_embd_head_k + 1;
const int64_t d_state = n_embd_head_v;
const int64_t dt_rank = d_model / 16;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;

// NOTE: not sure what's the difference between the sequence length and the batch size in the paper.
// {n_embd, batch}
// {n_embd, n_tok}
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
cb(inpL, "inp_embd", -1);

for (int il = 0; il < n_layer; ++il) {
// (ab)using the kv cache to store the state
ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner);
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);

// norm
Expand All @@ -7945,33 +7941,43 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);

// {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch}
// {n_embd, 2*d_inner} * {n_embd, n_tok} => {2*d_inner, n_tok}
struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
// split the above in two
// assuming it's contiguous
// {d_inner, batch}
// => {d_inner, n_tok}
struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);

cur = x;

// conv
{
// shift conv state left
conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0);

// update last column
// x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));

ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));

// rearrange and sum
// no need to rearrange the conv_state, since it's already in the right shape
// => {1, d_inner}
x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
// => {d_inner, 1}
x = ggml_transpose(ctx0, x);
// concat last (d_conv - 1) columns of conv_state, and x

// The following tensor is too big in order to avoid an assertion error when making an overlapping view.
// TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation
// This could then be a tensor with ne[] = {(d_conv-1)+n_tok, d_inner}
// which is around (d_conv-1) times as small as its current size.
struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tok);
const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x);

conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
// unfortunately, making x contiguous is necessary because ggml_set expects nb0 == sizeof(float)
conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x));

// store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)),
ggml_view_tensor(ctx0, kv_self.k_l[il])));

// prepare convolution for all tokens in the batch with a self-overlapping view
// {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, -(d_conv - 1)*d_inner*ggml_element_size(conv_x), 0);

// perform convolution
// => {1, d_inner, n_tok}
x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
// => {d_inner, n_tok, 1}
x = ggml_permute(ctx0, x, 2, 0, 1, 3);

// bias
x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
Expand All @@ -7981,23 +7987,24 @@ struct llm_build_context {

// ssm
{
// {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
// FIXME: handle batches of more than 1 token
struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
struct ggml_tensor * B = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
struct ggml_tensor * C = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));

// {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
// {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
// split
struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, x_db->ne[1], x_db->nb[1], 0);
struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*dt_rank);
struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));

// {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
dt = ggml_soft_plus(ctx0, dt);

// FIXME: support batches with more than 1 token
// => {d_state, d_inner}
struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));
struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));

// => {d_state, d_inner}
struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
struct ggml_tensor * dB = ggml_out_prod(ctx0, B, dt);

// => {d_state, d_inner}
cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
Expand All @@ -8012,7 +8019,7 @@ struct llm_build_context {
y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));

// {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
// {d_inner, n_embd} * {d_inner, 1} => {n_embd, 1}
cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
}

Expand Down Expand Up @@ -12327,8 +12334,15 @@ struct llama_context * llama_new_context_with_model(
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;

const ggml_type type_k = params.type_k;
const ggml_type type_v = params.type_v;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;

// Mamba (mis)uses the KV cache to store its states
if (model->arch == LLM_ARCH_MAMBA) {
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
type_v = GGML_TYPE_F32; // required by ggml_mul for Mamba's ssm_state
}

GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
Expand Down