@@ -2817,7 +2817,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV",
-    "RWKV_TOKEN_SHIFT",

     "UNARY",

@@ -2836,7 +2835,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2906,8 +2905,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
-    "rwkv_wkv(k, v, r, tf, td, s, sq)",
-    "rwkv_token_shift(xc, xn, sq)",
+    "rwkv_wkv(k, v, r, tf, td, s)",

     "unary(x)",

@@ -2926,7 +2924,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };

-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");

 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -7494,39 +7492,36 @@ struct ggml_tensor * ggml_rwkv_wkv(
         struct ggml_tensor * r,
         struct ggml_tensor * tf,
         struct ggml_tensor * td,
-        struct ggml_tensor * state,
-        struct ggml_tensor * state_seq) {
+        struct ggml_tensor * state) {
     GGML_ASSERT(ggml_is_contiguous(k));
     GGML_ASSERT(ggml_is_contiguous(v));
     GGML_ASSERT(ggml_is_contiguous(r));
     GGML_ASSERT(ggml_is_contiguous(tf));
     GGML_ASSERT(ggml_is_contiguous(td));
     GGML_ASSERT(ggml_is_contiguous(state));
-    GGML_ASSERT(ggml_is_contiguous(state_seq));
-    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);

     const int64_t S = k->ne[0];
     const int64_t H = k->ne[2];
     const int64_t n_tokens = k->ne[3];
-    const int64_t n_kv = state_seq->ne[0];
+    const int64_t n_seqs = state->ne[1];
     {
         GGML_ASSERT(k->ne[1] == 1);
         GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
         GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
         // TODO: RWKV v4 and v5
         GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
-        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_kv);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }

     bool is_node = false;

-    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad || state_seq->grad) {
+    if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
         GGML_ABORT("fatal error"); // TODO: implement backward
         is_node = true;
     }

     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S * n_kv, 1, 1 };
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

     result->op = GGML_OP_RWKV_WKV;
@@ -7537,48 +7532,6 @@ struct ggml_tensor * ggml_rwkv_wkv(
     result->src[3] = tf;
     result->src[4] = td;
     result->src[5] = state;
-    result->src[6] = state_seq;
-
-    return result;
-}
-
-// ggml_rwkv_token_shift
-
-struct ggml_tensor * ggml_rwkv_token_shift(
-        struct ggml_context * ctx,
-        struct ggml_tensor * x_carry,
-        struct ggml_tensor * x_norm,
-        struct ggml_tensor * state_seq) {
-    GGML_ASSERT(ggml_is_contiguous(x_carry));
-    GGML_ASSERT(ggml_is_contiguous(x_norm));
-    GGML_ASSERT(ggml_is_contiguous(state_seq));
-    GGML_ASSERT(state_seq->type == GGML_TYPE_I32);
-
-    const int64_t n_embd = x_norm->ne[0];
-    const int64_t n_kv = state_seq->ne[0];
-    const int64_t n_tokens = state_seq->ne[1];
-    {
-        GGML_ASSERT(x_norm->ne[0] == n_embd);
-        GGML_ASSERT(x_norm->ne[1] == n_tokens);
-        GGML_ASSERT(ggml_nelements(x_carry) == n_embd * n_kv);
-    }
-
-    bool is_node = false;
-
-    if (x_carry->grad || x_norm->grad || state_seq->grad) {
-        GGML_ABORT("fatal error"); // TODO: implement backward
-        is_node = true;
-    }
-
-    // concat output and new_state
-    const int64_t ne[4] = { n_embd, n_tokens + n_kv, 1, 1 };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    result->op = GGML_OP_RWKV_TOKEN_SHIFT;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = x_carry;
-    result->src[1] = x_norm;
-    result->src[2] = state_seq;

     return result;
 }
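For reference, a minimal sketch of how a caller might build the op after this change. The helper name, the tf/td shapes, and the use of the 2-D/4-D tensor constructors are assumptions inferred from the asserts above, not taken from upstream code.

#include "ggml.h"

// Hypothetical helper, only to illustrate the new call shape: a single state
// tensor with one contiguous S*S*H block per sequence replaces the old
// state + state_seq pair.
static struct ggml_tensor * build_rwkv_wkv(struct ggml_context * ctx,
        int64_t S, int64_t H, int64_t n_tokens, int64_t n_seqs) {
    struct ggml_tensor * k  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
    struct ggml_tensor * v  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
    struct ggml_tensor * r  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
    struct ggml_tensor * tf = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S, H);              // time_faaaa (assumed shape)
    struct ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens); // time_decay
    // ne[1] of the state tensor is read back as n_seqs by ggml_rwkv_wkv
    struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs);
    // result rows: n_tokens rows of S*H output, then S*n_seqs rows of updated state
    return ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
}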
@@ -16418,7 +16371,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     const size_t T = dst->src[1]->ne[3];
     const size_t C = dst->ne[0];
     const size_t H = dst->src[1]->ne[2];
-    const size_t n_kv = dst->src[6]->ne[0];
+    const size_t n_seqs = dst->src[5]->ne[1];

     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;
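The first C*T floats of dst hold the per-token output, and the updated states are appended after them, which is what the `state` pointer above encodes. A sketch of how one could locate a given sequence's new state once the op has run (the helper name is hypothetical):

// Sketch only: dst layout is C*T output floats followed by one (C/H)*C state
// block per sequence.
static float * wkv_state_out(struct ggml_tensor * dst, size_t C, size_t T, size_t H, size_t seq) {
    return ((float *) dst->data) + C * T + (C / H) * C * seq;
}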
@@ -16434,8 +16387,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * r = (float *) dst->src[2]->data;
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
-    int32_t * seq_data = (int32_t *) dst->src[6]->data;
-    memcpy(state, dst->src[5]->data, (C / H) * C * n_kv * sizeof(float));
+    memcpy(state, dst->src[5]->data, (C / H) * C * n_seqs * sizeof(float));

     size_t t_stride = H * (C / H);

@@ -16448,7 +16400,7 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     // recursive through each token
     for (size_t t = 0; t < T; t++) {
         size_t t_offset = t * t_stride;
-        float * state_cur = state + (C / H) * C * seq_data[t * n_kv];
+        float * state_cur = state + (C / H) * C * (t / (T / n_seqs));

         for (size_t h = 0; h < H; h++) {
             size_t h_offset = h * h_stride;
@@ -16480,15 +16432,6 @@ static void ggml_compute_forward_rwkv_wkv_f32(
             }
         }
     }
-
-    for (size_t t = 0; t < T; t++) {
-        for (size_t kv = 1; kv < n_kv; kv++) {
-            int64_t seq = seq_data[t * n_kv + kv];
-            if (seq >= 0 && seq_data[(t + 1) * n_kv + kv] != seq) {
-                memcpy(state + (C / H) * C * seq, state + (C / H) * C * seq_data[t * n_kv], (C / H) * C * sizeof(float));
-            }
-        }
-    }
 }

 static void ggml_compute_forward_rwkv_wkv(
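With state_seq gone, the kernel now assumes tokens are grouped contiguously by sequence and that every sequence has the same length T / n_seqs, which is why the old cross-sequence state copy loop can be dropped. A tiny illustration of the new index mapping (the values and helper name are made up):

// With T = 8 tokens and n_seqs = 2, each sequence spans T / n_seqs = 4 tokens,
// so tokens 0..3 use state block 0 and tokens 4..7 use state block 1.
static size_t wkv_seq_of_token(size_t t, size_t T, size_t n_seqs) {
    return t / (T / n_seqs); // same expression as the state_cur offset above
}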
@@ -16509,77 +16452,6 @@ static void ggml_compute_forward_rwkv_wkv(
     }
 }

-static void ggml_compute_forward_rwkv_token_shift_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-    const int64_t n_embd = dst->ne[0];
-    const int64_t n_kv = dst->src[2]->ne[0];
-    const int64_t n_tokens = dst->src[1]->ne[1];
-    float * dst_data = (float *) dst->data;
-    float * x_carry = (float *) dst->src[0]->data;
-    float * x_norm = (float *) dst->src[1]->data;
-    int32_t * sq_data = (int32_t *) dst->src[2]->data;
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    int32_t seq_start = 0;
-    int32_t seq_length = 0;
-
-    for (int i1 = 0; i1 < n_kv; ++i1) {
-        seq_start = -1;
-        // assume that the tokens for each sequence are contiguous
-        for (int i2 = 0; i2 < n_tokens; ++i2) {
-            int32_t seq = sq_data[i2*n_kv];
-            if (seq == i1 && seq_start < 0) {
-                seq_start = i2;
-            }
-
-            if ((seq_start >= 0 && seq != i1) || i2 == n_tokens - 1) {
-                seq_length = i2 - seq_start + (i2 == n_tokens - 1);
-                break;
-            }
-        }
-
-        if (seq_start >= 0) {
-            int32_t seq = sq_data[seq_start*n_kv];
-            memcpy(dst_data + seq_start*n_embd, x_carry + seq*n_embd, n_embd*sizeof(float));
-            memcpy(dst_data + (seq_start+1)*n_embd, x_norm + seq_start*n_embd, (seq_length-1)*n_embd*sizeof(float));
-        }
-    }
-
-    for (int i3 = 0; i3 < n_kv; ++i3) {
-        int32_t last_token_pos = 0;
-        for (int i4 = 0; i4 < n_tokens; ++i4) {
-            for (int i5 = 0; i5 < n_kv; ++i5) {
-                if (sq_data[i4*n_kv + i5] == i3) {
-                    last_token_pos = i4;
-                }
-            }
-        }
-        memcpy(dst_data + (n_tokens + i3)*n_embd, x_norm + last_token_pos*n_embd, n_embd*sizeof(float));
-    }
-}
-
-static void ggml_compute_forward_rwkv_token_shift(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_rwkv_token_shift_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(
@@ -17230,10 +17102,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv(params, tensor);
             } break;
-        case GGML_OP_RWKV_TOKEN_SHIFT:
-            {
-                ggml_compute_forward_rwkv_token_shift(params, tensor);
-            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -18305,7 +18173,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
        case GGML_OP_GET_REL_POS:
        case GGML_OP_ADD_REL_POS:
        case GGML_OP_RWKV_WKV:
-       case GGML_OP_RWKV_TOKEN_SHIFT:
        case GGML_OP_MAP_UNARY:
        case GGML_OP_MAP_BINARY:
        case GGML_OP_MAP_CUSTOM1_F32:
@@ -18876,7 +18743,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV:
-        case GGML_OP_RWKV_TOKEN_SHIFT:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32: