llama : support Mamba Selective State Space Models #5328

Merged: 43 commits into master from support-mamba-ssm on Mar 8, 2024

Changes from 1 commit

Commits (43)
8cd0a28
mamba : begin working on support for Mamba SSM
compilade Jan 26, 2024
5a69a26
mamba : begin figuring out how to (ab)use the kv cache for Mamba
compilade Jan 27, 2024
f680364
mamba : recurrent inference almost works, but incoherent
compilade Jan 28, 2024
54d3e48
mamba : recurrent inference WORKS!!!
compilade Jan 28, 2024
74eea85
convert : optionally use d_conv and d_state from config.json for Mamba
compilade Jan 29, 2024
9e77061
mamba : refactor recurrent conv, resulting in 20% perf increase
compilade Jan 29, 2024
3f7233b
ggml : parallelize ggml_exp
compilade Jan 29, 2024
e9cc45e
mamba : simplify the conv step with a self-overlapping view
compilade Jan 31, 2024
81b57bb
mamba : fix self-overlapping view depth stride
compilade Jan 31, 2024
ffc116f
mamba : handle batches of more than 1 token
compilade Feb 1, 2024
78a853b
ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation
compilade Feb 2, 2024
5816ae6
mamba : very basic quantization support
compilade Feb 2, 2024
a3f4a1c
mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator
compilade Feb 3, 2024
9f55809
convert : for Mamba, also consider the "MambaLMHeadModel" arch name
compilade Feb 4, 2024
cd0f33f
mamba : fix vocab size problems with official models
compilade Feb 4, 2024
de92f15
ggml : remove ggml_exp and ggml_soft_plus
compilade Feb 4, 2024
766db75
mamba : remove some useless comments
compilade Feb 4, 2024
c52fb3c
convert : fix flake8 linter errors
compilade Feb 5, 2024
6ff34da
mamba : apply suggestions from code review
compilade Feb 5, 2024
8a43ffc
mamba : multiple sequences, but one at a time
compilade Feb 14, 2024
e73eaa7
mamba : in comments, properly refer to KV cells instead of slots
compilade Feb 14, 2024
de50c54
mamba : reduce memory usage of ggml_ssm_scan
compilade Feb 18, 2024
9473ec2
mamba : simultaneous sequence processing
compilade Feb 19, 2024
3dcf798
mamba : support llama_kv_cache_seq_cp copy chains
compilade Feb 25, 2024
34e2fca
mamba : make the server and parallel examples work with whole sequences
compilade Feb 25, 2024
79d636c
mamba : dedicate an input tensor for state copy indices
compilade Feb 25, 2024
8f605cf
mamba : adapt perplexity, batched, and batched-bench examples
compilade Feb 27, 2024
206e8ee
mamba : stop abusing attention metadata
compilade Feb 28, 2024
1af1000
mamba : more correctly update the "used" field of the KV cache
compilade Mar 2, 2024
d52dd50
ggml : in ggml_ssm_scan, use a threshold for soft_plus
compilade Mar 3, 2024
b83fbc9
convert : for Mamba, fallback to internal NeoX tokenizer
compilade Mar 3, 2024
eefb794
mamba : support state saving and restoring
compilade Mar 3, 2024
2a99d1b
ggml : implicitly pass src tensors through dst for Mamba-related ops
compilade Mar 4, 2024
93fd4b8
mamba : clarify some comments
compilade Mar 4, 2024
5544f52
Merge branch 'master' into support-mamba-ssm
compilade Mar 5, 2024
916b586
Merge branch 'master' into support-mamba-ssm
compilade Mar 7, 2024
7cd5a1f
server : fix cache_tokens not getting correctly resized
compilade Mar 7, 2024
d8024a4
convert-hf : support new metadata keys for Mamba
compilade Mar 8, 2024
17e4d6c
mamba : rename metadata to be more similar to transformers library
compilade Mar 8, 2024
1c8ea55
mamba : add missing spaces
compilade Mar 8, 2024
d0d32dc
convert-hf : omit output.weight when identical with token_embd.weight
compilade Mar 8, 2024
3e5685f
readme : add Mamba to supported models, and add recent API changes
compilade Mar 8, 2024
39579d3
mamba : move state_seq and state_mask views outside layer loop
compilade Mar 8, 2024
mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator
This increases performance on CPU by around 30% for prompt processing,
and by around 20% for text generation.

However, it also makes the ggml_exp and ggml_soft_plus operators unused.
Whether or not they should be kept will be decided later.
compilade committed Mar 3, 2024
commit a3f4a1c7dc9fc10082d5290b49505bc3d3db239c
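
For reference, the per-token update that this commit fuses into a single `ggml_ssm_scan` pass is the discretized selective-state update below (the notation is mine, summarizing the kernel changes that follow; it is not text from the PR):

```latex
% Per token t: softplus-discretize the time step, then update the state.
% This is what the fused kernel computes element-wise:
%   dest[i] = s[i]*expf(dt_soft_plus * A[i]) + B[i0]*(x[i1]*dt_soft_plus)
\Delta_t = \operatorname{softplus}(dt_t) = \log\!\left(1 + e^{dt_t}\right), \qquad
h_t = e^{\Delta_t A} \odot h_{t-1} + \left(\Delta_t x_t\right) B_t
```

Because the softplus and the exponential now happen inside the scan, the standalone `ggml_exp` and `ggml_soft_plus` graph nodes are no longer needed, which is why the commit message flags them as unused.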
ggml.c: 113 changes (77 additions, 36 deletions)
@@ -6156,31 +6156,45 @@ struct ggml_tensor * ggml_flash_attn_back(
 struct ggml_tensor * ggml_ssm_scan(
         struct ggml_context * ctx,
         struct ggml_tensor  * s,
-        struct ggml_tensor  * dA,
-        struct ggml_tensor  * dB_x) {
-    GGML_ASSERT(ggml_are_same_shape(dA, dB_x));
+        struct ggml_tensor  * x,
+        struct ggml_tensor  * dt,
+        struct ggml_tensor  * A,
+        struct ggml_tensor  * B) {
+    GGML_ASSERT(ggml_is_contiguous(s));
+    GGML_ASSERT(ggml_is_contiguous(x));
+    GGML_ASSERT(ggml_is_contiguous(dt));
+    GGML_ASSERT(ggml_is_contiguous(A));
+    GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+    GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
 
-    GGML_ASSERT(   s->nb[0] == ggml_type_size(   s->type));
-    GGML_ASSERT(  dA->nb[0] == ggml_type_size(  dA->type));
-    GGML_ASSERT(dB_x->nb[0] == ggml_type_size(dB_x->type));
+    {
+        const int64_t d_state = s->ne[0];
+        const int64_t d_inner = s->ne[1];
+        const int64_t n_tok   = x->ne[1];
 
-    GGML_ASSERT(s->ne[0] == dA->ne[0]);
-    GGML_ASSERT(s->ne[1] == dA->ne[1]);
-    GGML_ASSERT(s->ne[2] == 1 && s->ne[3] == 1); // the ssm_state should be 2D
+        GGML_ASSERT(x->ne[0] == d_inner);
+        GGML_ASSERT(A->ne[0] == d_state);
+        GGML_ASSERT(A->ne[1] == d_inner);
+        GGML_ASSERT(B->ne[0] == d_state);
+        GGML_ASSERT(B->ne[1] == n_tok);
+    }
 
     bool is_node = false;
 
-    if (s->grad || dA->grad || dB_x->grad) {
+    if (s->grad || x->grad || dt->grad || A->grad || B->grad) {
         is_node = true;
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, dA->ne[0], dA->ne[1], dA->ne[2]);
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, s->ne[0], s->ne[1], x->ne[1]);
 
     result->op   = GGML_OP_SSM_SCAN;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = s;
-    result->src[1] = dA;
-    result->src[2] = dB_x;
+    result->src[1] = x;
+    result->src[2] = dt;
+    result->src[3] = A;
+    result->src[4] = B;
 
     return result;
 }
@@ -14795,9 +14809,11 @@ static void ggml_compute_forward_flash_attn_back(
 
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        const struct ggml_tensor * src2,
+        const struct ggml_tensor * src0, // s
+        const struct ggml_tensor * src1, // x
+        const struct ggml_tensor * src2, // dt
+        const struct ggml_tensor * src3, // A
+        const struct ggml_tensor * src4, // B
         struct ggml_tensor * dst) {
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14806,18 +14822,19 @@ static void ggml_compute_forward_ssm_scan_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t nc  = src1->ne[0];
-    const int64_t n_t = src1->ne[2]; // number of tokens in the batch
+    const int64_t nc  = src0->ne[0];
+    const int64_t n_t = src1->ne[1]; // number of tokens in the batch
     const int64_t nr0 = ggml_nrows(src0);
 
-    GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(src1));
+    GGML_ASSERT(nc*n_t*nr0 == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
     // allow merging multiple rows in the same vec operation
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    GGML_ASSERT(src1->nb[1] == src1->ne[0]*sizeof(float));
-    GGML_ASSERT(src2->nb[1] == src2->ne[0]*sizeof(float));
+    GGML_ASSERT(src3->nb[1] == src3->ne[0]*sizeof(float));
 
     // rows per thread
     const int dr = (nr0 + nth - 1)/nth;
@@ -14829,22 +14846,44 @@ static void ggml_compute_forward_ssm_scan_f32(
 
     // first batch
     {
-        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1]));
-        float * s    = (float *) ((char *) src0->data + ir0*(src0->nb[1]));
-        float * dA   = (float *) ((char *) src1->data + ir0*(src1->nb[1]));
-        float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]));
-        ggml_vec_mul_f32(nc*ir, dest, s, dA);
-        ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
+        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * s    = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
+        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
+        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
+        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float * B    = (float *) ((char *) src4->data); // {d_state, n_tok}
+        // d_inner
+        for (int i1 = 0; i1 < ir; ++i1) {
+            float dt_soft_plus = log1pf(expf(dt[i1]));
+            float x_dt = x[i1] * dt_soft_plus;
+            // d_state
+            for (int i0 = 0; i0 < nc; ++i0) {
+                int i = i0 + i1*nc;
+                // ssm_state * dA + dB * x
+                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            }
+        }
     }
 
     // compute state for rest of tokens, previous state comes from dest
-    for (int i2 = 1; i2 < n_t; i2++) {
-        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) +    i2 *( dst->nb[2]));
-        float * s    = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2]));
-        float * dA   = (float *) ((char *) src1->data + ir0*(src1->nb[1]) +    i2 *(src1->nb[2]));
-        float * dB_x = (float *) ((char *) src2->data + ir0*(src2->nb[1]) +    i2 *(src2->nb[2]));
-        ggml_vec_mul_f32(nc*ir, dest, s, dA);
-        ggml_vec_add_f32(nc*ir, dest, dest, dB_x);
+    for (int i2 = 1; i2 < n_t; ++i2) {
+        float * dest = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) +    i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * s    = (float *) ((char *)  dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
+        float * x    = (float *) ((char *) src1->data + ir0*(src1->nb[0]) +    i2 *(src1->nb[1])); // {d_inner, n_tok}
+        float * dt   = (float *) ((char *) src2->data + ir0*(src2->nb[0]) +    i2 *(src2->nb[1])); // {d_inner, n_tok}
+        float * A    = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+        float * B    = (float *) ((char *) src4->data +    i2*(src4->nb[1])); // {d_state, n_tok}
+        // d_inner
+        for (int i1 = 0; i1 < ir; ++i1) {
+            float dt_soft_plus = log1pf(expf(dt[i1]));
+            float x_dt = x[i1] * dt_soft_plus;
+            // d_state
+            for (int i0 = 0; i0 < nc; ++i0) {
+                int i = i0 + i1*nc;
+                // ssm_state * dA + dB * x
+                dest[i] = s[i]*(expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            }
+        }
     }
 }

@@ -14853,11 +14892,13 @@ static void ggml_compute_forward_ssm_scan(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         const struct ggml_tensor * src2,
+        const struct ggml_tensor * src3,
+        const struct ggml_tensor * src4,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, dst);
+                ggml_compute_forward_ssm_scan_f32(params, src0, src1, src2, src3, src4, dst);
             } break;
         default:
             {
@@ -15927,7 +15968,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+                ggml_compute_forward_ssm_scan(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
             } break;
        case GGML_OP_WIN_PART:
            {
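
To make the fused kernel easier to follow outside the diff, here is a minimal single-threaded restatement of what `ggml_compute_forward_ssm_scan_f32` computes, assuming fully contiguous tensors. The function name `ssm_scan_ref` and the flat memory layout are illustrative assumptions, not code from the PR:

```c
#include <math.h>

// Single-threaded reference for the fused scan, assuming contiguous layouts:
//   s0 {d_state, d_inner}  initial state
//   x  {d_inner, n_tok}    inner activations
//   dt {d_inner, n_tok}    raw time steps (soft_plus is applied here, fused)
//   A  {d_state, d_inner}  state-transition parameters
//   B  {d_state, n_tok}    selective input projection
// dst holds one {d_state, d_inner} state per token, like the kernel's output.
static void ssm_scan_ref(float * dst, const float * s0, const float * x,
                         const float * dt, const float * A, const float * B,
                         int d_state, int d_inner, int n_tok) {
    for (int t = 0; t < n_tok; ++t) {
        // the previous token's output is the input state of the current token
        const float * s = (t == 0) ? s0 : dst + (t - 1)*d_state*d_inner;
        for (int i1 = 0; i1 < d_inner; ++i1) {
            const float dt_sp = log1pf(expf(dt[t*d_inner + i1])); // soft_plus(dt)
            const float x_dt  = x[t*d_inner + i1] * dt_sp;
            for (int i0 = 0; i0 < d_state; ++i0) {
                const int i = i0 + i1*d_state;
                // ssm_state * dA + dB * x, with dA = exp(dt*A) and dB = dt*B
                dst[t*d_state*d_inner + i] = s[i]*expf(dt_sp * A[i]) + B[t*d_state + i0]*x_dt;
            }
        }
    }
}
```

The actual kernel additionally splits the d_inner rows across threads (`dr`, `ir0`, `ir1`) but performs the same arithmetic per element.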
ggml.h: 6 changes (4 additions, 2 deletions)

@@ -1724,8 +1724,10 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_ssm_scan(
             struct ggml_context * ctx,
             struct ggml_tensor  * s,
-            struct ggml_tensor  * dA,
-            struct ggml_tensor  * dB_x);
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * dt,
+            struct ggml_tensor  * A,
+            struct ggml_tensor  * B);
 
     // partition into non-overlapping windows with padding if needed
     // example:
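
As a usage illustration of the new declaration, a minimal sketch of building the node in a ggml graph follows; `build_ssm_scan_node` and the concrete dimensions are hypothetical, but the shapes match the asserts added in ggml.c above:

```c
#include "ggml.h"

// Hypothetical helper: wire up a ggml_ssm_scan node with the shapes the
// operator expects. Real inputs come from the model graph, not fresh tensors.
//   s {d_state, d_inner}, x {d_inner, n_tok}, dt {d_inner, n_tok},
//   A {d_state, d_inner}, B {d_state, n_tok}
static struct ggml_tensor * build_ssm_scan_node(
        struct ggml_context * ctx, int64_t d_state, int64_t d_inner, int64_t n_tok) {
    struct ggml_tensor * s  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
    struct ggml_tensor * x  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_inner, n_tok);
    struct ggml_tensor * dt = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_inner, n_tok);
    struct ggml_tensor * A  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
    struct ggml_tensor * B  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, n_tok);
    // dt is passed raw: the soft_plus discretization happens inside the operator
    return ggml_ssm_scan(ctx, s, x, dt, A, B);
}
```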
llama.cpp: 49 changes (6 additions, 43 deletions)

@@ -7999,55 +7999,18 @@ struct llm_build_context {
             struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
             // split
             struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tok, x_db->nb[1], 0);
-            struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
-            struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+            struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+            struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tok, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
 
             // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
             dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
             dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
-            dt = ggml_soft_plus(ctx0, dt);
 
-            struct ggml_tensor * dA;
-            struct ggml_tensor * dB;
-            if (n_tok == 1) {
-                // => {d_state, d_inner}
-                dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
-
-                // {d_state} * {d_inner} => {d_state, d_inner}
-                dB = ggml_out_prod(ctx0, B, dt);
-            } else {
-                // {d_state, d_inner} * {d_inner, n_tok} => {d_state, d_inner, n_tok} * {1, d_inner, n_tok}
-                // => {d_state, d_inner, n_tok}
-                // Trying to do the equivalent of
-                // dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate)
-                struct ggml_tensor * A = model.layers[il].ssm_a;
-                dA = ggml_exp(ctx0,
-                    ggml_mul(ctx0,
-                        ggml_repeat(ctx0, A, ggml_new_tensor_3d(ctx0, A->type, d_state, d_inner, n_tok)),
-                        // {d_inner, n_tok} => {1, d_inner, n_tok}
-                        ggml_permute(ctx0, dt, 1, 2, 0, 3))
-                );
-
-                // {d_state, 1, n_tok} * {d_inner, 1, n_tok} => {d_state, d_inner, n_tok}
-                dB = ggml_out_prod(ctx0,
-                    // {d_state, n_tok} => {d_state, 1, n_tok}
-                    ggml_permute(ctx0, B, 0, 2, 1, 3),
-                    // {d_state, n_tok} => {d_state, 1, n_tok}
-                    ggml_permute(ctx0, dt, 0, 2, 1, 3));
-            }
-
-            // {d_state, d_inner, n_tok} * {1, d_inner, n_tok} => {d_state, d_inner, n_tok}
-            cur = ggml_mul(ctx0, dB, ggml_permute(ctx0, x, 1, 2, 0, 3));
-
-            // The selective scan seems inherently sequential...
-            // To avoid making (n_layer * n_tok) graph nodes, let's use a custom operator.
-            // When n_tok == 1, it's equivalent to the following:
-            // ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
-            // When n_tok is bigger, it's the same thing, but iterated n_tok times,
-            // with the correct dA and cur for each token.
-            // The resulting states are layered on the ne[2] dimension.
+            // Custom operator to implement some of the optimizations
+            // described in the Annex D of the Mamba paper.
+            // TODO: maybe also optimize step 4 of the Speed section of Annex D (the mul_mat with C)
             // => {d_state, d_inner, n_tok}
-            ssm_state = ggml_ssm_scan(ctx0, ssm_state, dA, cur);
+            ssm_state = ggml_ssm_scan(ctx0, ssm_state, x, dt, model.layers[il].ssm_a, B);
 
             // only store last state
             ggml_build_forward_expand(gf,
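
Finally, a small self-contained check (hypothetical; not part of the PR) that the fused update matches the unfused `dA`/`dB` formulation the removed code built as separate graph nodes, for a single token:

```c
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// Hypothetical check: the fused update in ggml_ssm_scan matches the removed
// graph-level formulation (dt = soft_plus(dt); dA = exp(dt*A);
// dB = outer(B, dt); state = state*dA + dB*x) for one token.
int main(void) {
    enum { d_state = 4, d_inner = 3 };
    float s[d_state*d_inner], A[d_state*d_inner], dA[d_state*d_inner], dB_x[d_state*d_inner];
    float x[d_inner], dt[d_inner], B[d_state];
    srand(42);
    for (int i = 0; i < d_state*d_inner; ++i) { s[i] = (float)rand()/RAND_MAX; A[i] = -(float)rand()/RAND_MAX; }
    for (int i = 0; i < d_inner; ++i) { x[i] = (float)rand()/RAND_MAX; dt[i] = (float)rand()/RAND_MAX; }
    for (int i = 0; i < d_state; ++i) { B[i] = (float)rand()/RAND_MAX; }

    // old path: materialize dA and dB*x as if they were separate graph nodes
    for (int i1 = 0; i1 < d_inner; ++i1) {
        const float dt_sp = log1pf(expf(dt[i1]));  // ggml_soft_plus, as a separate step
        for (int i0 = 0; i0 < d_state; ++i0) {
            const int i = i0 + i1*d_state;
            dA[i]   = expf(dt_sp * A[i]);          // ggml_exp(ggml_mul(A, dt))
            dB_x[i] = (B[i0] * dt_sp) * x[i1];     // ggml_mul(ggml_out_prod(B, dt), x)
        }
    }
    // new path: everything fused into one pass, as in ggml_ssm_scan
    for (int i1 = 0; i1 < d_inner; ++i1) {
        const float dt_sp = log1pf(expf(dt[i1]));  // soft_plus applied inside the scan
        const float x_dt  = x[i1] * dt_sp;
        for (int i0 = 0; i0 < d_state; ++i0) {
            const int i = i0 + i1*d_state;
            const float unfused = s[i]*dA[i] + dB_x[i];
            const float fused   = s[i]*expf(dt_sp * A[i]) + B[i0]*x_dt;
            if (fabsf(fused - unfused) > 1e-6f) { printf("mismatch at %d\n", i); return 1; }
        }
    }
    printf("fused scan matches the removed dA/dB_x formulation\n");
    return 0;
}
```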