
Commit c3564d8 (1 parent: 9ffa40d)

llama: rwkv6: Use the new advanced batch splits

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

3 files changed: +66 -204 lines

ggml/include/ggml.h
Lines changed: 1 addition & 9 deletions

@@ -509,7 +509,6 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV,
-        GGML_OP_RWKV_TOKEN_SHIFT,
 
         GGML_OP_UNARY,
 
@@ -1857,14 +1856,7 @@ extern "C" {
             struct ggml_tensor * r,
             struct ggml_tensor * tf,
             struct ggml_tensor * td,
-            struct ggml_tensor * state,
-            struct ggml_tensor * state_seq);
-
-    GGML_API struct ggml_tensor * ggml_rwkv_token_shift(
-            struct ggml_context * ctx,
-            struct ggml_tensor * x_carry,
-            struct ggml_tensor * x_norm,
-            struct ggml_tensor * state_seq);
+            struct ggml_tensor * state);
 
     // custom operators
 
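Note: with state_seq gone, the sequence count now travels in the state tensor itself. A minimal sketch of the new call, assuming tensors k, v, r, tf, td already built for head size S, H heads, and n_seqs sequences (hypothetical variable names, not code from this commit):

    // one S*S*H wkv state per sequence; state->ne[1] carries the sequence count
    struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs);

    // old: ggml_rwkv_wkv(ctx, k, v, r, tf, td, state, state_seq);
    struct ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);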
ggml/src/ggml.c
Lines changed: 11 additions & 145 deletions

@@ -2817,7 +2817,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV",
-    "RWKV_TOKEN_SHIFT",
 
     "UNARY",
 
@@ -2836,7 +2835,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2906,8 +2905,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
-    "rwkv_wkv(k, v, r, tf, td, s, sq)",
-    "rwkv_token_shift(xc, xn, sq)",
+    "rwkv_wkv(k, v, r, tf, td, s)",
 
     "unary(x)",
 
@@ -2926,7 +2924,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -7494,39 +7492,36 @@ struct ggml_tensor * ggml_rwkv_wkv(
         struct ggml_tensor * r,
         struct ggml_tensor * tf,
         struct ggml_tensor * td,
-        struct ggml_tensor * state,
-        struct ggml_tensor * state_seq) {
+        struct ggml_tensor * state) {
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
     GGML_ASSERT(ggml_is_contiguous(r));
     GGML_ASSERT(ggml_is_contiguous(tf));
     GGML_ASSERT(ggml_is_contiguous(td));
     GGML_ASSERT(ggml_is_contiguous(state));
-    GGML_ASSERT(ggml_is_contiguous(state_seq));
-    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);
 
     const int64_t S = k->ne[0];
     const int64_t H = k->ne[2];
     const int64_t n_tokens = k->ne[3];
-    const int64_t n_kv = state_seq->ne[0];
+    const int64_t n_seqs = state->ne[1];
     {
         GGML_ASSERT(k->ne[1] == 1);
         GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
         GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
         // TODO: RWKV v4 and v5
         GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
-        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_kv);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }
 
     bool is_node = false;
 
-    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad || state_seq->grad) {
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
         GGML_ABORT("fatal error"); // TODO: implement backward
         is_node = true;
     }
 
     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S * n_kv, 1, 1 };
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     result->op = GGML_OP_RWKV_WKV;
@@ -7537,48 +7532,6 @@ struct ggml_tensor * ggml_rwkv_wkv(
     result->src[3] = tf;
     result->src[4] = td;
     result->src[5] = state;
-    result->src[6] = state_seq;
-
-    return result;
-}
-
-// ggml_rwkv_token_shift
-
-struct ggml_tensor * ggml_rwkv_token_shift(
-        struct ggml_context * ctx,
-        struct ggml_tensor * x_carry,
-        struct ggml_tensor * x_norm,
-        struct ggml_tensor * state_seq) {
-    GGML_ASSERT(ggml_is_contiguous(x_carry));
-    GGML_ASSERT(ggml_is_contiguous(x_norm));
-    GGML_ASSERT(ggml_is_contiguous(state_seq));
-    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);
-
-    const int64_t n_embd = x_norm->ne[0];
-    const int64_t n_kv = state_seq->ne[0];
-    const int64_t n_tokens = state_seq->ne[1];
-    {
-        GGML_ASSERT(x_norm->ne[0] == n_embd);
-        GGML_ASSERT(x_norm->ne[1] == n_tokens);
-        GGML_ASSERT(ggml_nelements(x_carry) == n_embd * n_kv);
-    }
-
-    bool is_node = false;
-
-    if (x_carry->grad || x_norm->grad || state_seq->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
-    // concat output and new_state
-    const int64_t ne[4] = { n_embd, n_tokens + n_kv, 1, 1 };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    result->op = GGML_OP_RWKV_TOKEN_SHIFT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = x_carry;
-    result->src[1] = x_norm;
-    result->src[2] = state_seq;
 
     return result;
 }
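Note: the fused result still packs the per-token output and the updated states into one buffer: the first n_tokens rows of width S * H are the output, followed by S * n_seqs rows holding the new states. A sketch of how a caller might split it with views (illustrative only, not taken from this commit):

    // result->ne = { S * H, n_tokens + S * n_seqs, 1, 1 }
    struct ggml_tensor * wkv_out = ggml_view_2d(ctx, result, S * H, n_tokens,
            result->nb[1], 0);                        // first n_tokens rows
    struct ggml_tensor * new_state = ggml_view_2d(ctx, result, S * H, S * n_seqs,
            result->nb[1], n_tokens * result->nb[1]); // trailing state rows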
@@ -16418,7 +16371,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     const size_t T = dst->src[1]->ne[3];
     const size_t C = dst->ne[0];
     const size_t H = dst->src[1]->ne[2];
-    const size_t n_kv = dst->src[6]->ne[0];
+    const size_t n_seqs = dst->src[5]->ne[1];
 
     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;
@@ -16434,8 +16387,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * r = (float *) dst->src[2]->data;
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
-    int32_t * seq_data = (int32_t *) dst->src[6]->data;
-    memcpy(state, dst->src[5]->data, (C / H) * C * n_kv * sizeof(float));
+    memcpy(state, dst->src[5]->data, (C / H) * C * n_seqs * sizeof(float));
 
     size_t t_stride = H * (C / H);
 
@@ -16448,7 +16400,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     // recursive through each token
     for (size_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
-        float * state_cur = state + (C / H) * C * seq_data[t * n_kv];
+        float * state_cur = state + (C / H) * C * (t / (T / n_seqs));
 
         for (size_t h = 0; h < H; h++) {
             size_t h_offset = h * h_stride;
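Note: the per-token sequence id is no longer looked up in seq_data; it is derived from the token position. This only works because the new advanced batch splits guarantee equal-split ubatches: the same number of tokens for every sequence, laid out contiguously. A worked example of the index t / (T / n_seqs):

    // T = 8 tokens, n_seqs = 2  =>  T / n_seqs = 4 tokens per sequence:
    //   t = 0..3  ->  t / 4 == 0  ->  state slot 0 (first sequence)
    //   t = 4..7  ->  t / 4 == 1  ->  state slot 1 (second sequence)
    // each slot is one full wkv state of (C / H) * C floats

This is also why the cross-sequence state-forwarding loop removed in the next hunk is obsolete: each sequence now owns exactly one state slot for the whole ubatch.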
@@ -16480,15 +16432,6 @@ static void ggml_compute_forward_rwkv_wkv_f32(
             }
         }
     }
-
-    for (size_t t = 0; t < T; t++) {
-        for (size_t kv = 1; kv < n_kv; kv++) {
-            int64_t seq = seq_data[t * n_kv + kv];
-            if (seq >= 0 && seq_data[(t + 1) * n_kv + kv] != seq) {
-                memcpy(state + (C / H) * C * seq, state + (C / H) * C * seq_data[t * n_kv], (C / H) * C * sizeof(float));
-            }
-        }
-    }
 }
 
 static void ggml_compute_forward_rwkv_wkv(
@@ -16509,77 +16452,6 @@ static void ggml_compute_forward_rwkv_wkv(
     }
 }
 
-static void ggml_compute_forward_rwkv_token_shift_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    const int64_t n_embd = dst->ne[0];
-    const int64_t n_kv = dst->src[2]->ne[0];
-    const int64_t n_tokens = dst->src[1]->ne[1];
-    float * dst_data = (float *) dst->data;
-    float * x_carry = (float *) dst->src[0]->data;
-    float * x_norm = (float *) dst->src[1]->data;
-    int32_t * sq_data = (int32_t *) dst->src[2]->data;
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    int32_t seq_start = 0;
-    int32_t seq_length = 0;
-
-    for (int i1 = 0; i1 < n_kv; ++i1) {
-        seq_start = -1;
-        // assume that the tokens for each sequence are contiguous
-        for (int i2 = 0; i2 < n_tokens; ++i2) {
-            int32_t seq = sq_data[i2*n_kv];
-            if (seq == i1 && seq_start < 0) {
-                seq_start = i2;
-            }
-
-            if ((seq_start >= 0 && seq != i1) || i2 == n_tokens - 1) {
-                seq_length = i2 - seq_start + (i2 == n_tokens - 1);
-                break;
-            }
-        }
-
-        if (seq_start >= 0) {
-            int32_t seq = sq_data[seq_start*n_kv];
-            memcpy(dst_data + seq_start*n_embd, x_carry + seq*n_embd, n_embd*sizeof(float));
-            memcpy(dst_data + (seq_start+1)*n_embd, x_norm + seq_start*n_embd, (seq_length-1)*n_embd*sizeof(float));
-        }
-    }
-
-    for (int i3 = 0; i3 < n_kv; ++i3) {
-        int32_t last_token_pos = 0;
-        for (int i4 = 0; i4 < n_tokens; ++i4) {
-            for (int i5 = 0; i5 < n_kv; ++i5) {
-                if (sq_data[i4*n_kv + i5] == i3) {
-                    last_token_pos = i4;
-                }
-            }
-        }
-        memcpy(dst_data + (n_tokens + i3)*n_embd, x_norm + last_token_pos*n_embd, n_embd*sizeof(float));
-    }
-}
-
-static void ggml_compute_forward_rwkv_token_shift(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rwkv_token_shift_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
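Note: ggml_rwkv_token_shift is deleted rather than ported. With contiguous per-sequence ubatches, the shift reduces to "prepend the carried-in embedding, drop the last token", which can be expressed with ordinary ggml ops in the model graph instead of a dedicated kernel. One plausible single-sequence sketch (hypothetical names, not the graph code added by this commit; the real graph repeats this per sequence):

    // x_prev[i] = embedding of token i-1; token 0 takes the carried-in shift state.
    // assumes x_norm is [n_embd, n_tokens] and token_shift is [n_embd, 1]
    struct ggml_tensor * x_prev = ggml_concat(ctx,
            token_shift,                                    // carry-in for token 0
            ggml_view_2d(ctx, x_norm, n_embd, n_tokens - 1,
                         x_norm->nb[1], 0),                 // tokens 0 .. n_tokens-2
            1);                                             // concat along dim 1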
@@ -17230,10 +17102,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv(params, tensor);
             } break;
-        case GGML_OP_RWKV_TOKEN_SHIFT:
-            {
-                ggml_compute_forward_rwkv_token_shift(params, tensor);
-            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -18305,7 +18173,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
         case GGML_OP_RWKV_WKV:
-        case GGML_OP_RWKV_TOKEN_SHIFT:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -18876,7 +18743,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
        case GGML_OP_RWKV_WKV:
-        case GGML_OP_RWKV_TOKEN_SHIFT:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
