
Commit 178230e

Getting to decode stage...
1 parent c78f9fc commit 178230e

2 files changed (+235 -120 lines)


ggml/src/ggml.c

Lines changed: 104 additions & 48 deletions
@@ -3435,7 +3435,7 @@ struct ggml_tensor * ggml_reshape_4d(
         int64_t ne2,
         int64_t ne3) {
     GGML_ASSERT(ggml_is_contiguous(a));
-    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
 
     const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
@@ -5441,17 +5441,25 @@ struct ggml_tensor * ggml_delta_net(
     GGML_ASSERT(ggml_is_contiguous(beta));
     GGML_ASSERT(ggml_is_contiguous(state));
 
-    const int64_t S = k->ne[0];
-    const int64_t H = k->ne[1];
+    const int64_t S_k = k->ne[0];
+    const int64_t H_k = k->ne[1];
     const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
 
-    // Validate dimensions
-    GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
-    GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
-    GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
-    GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    // Validate dimensions - allow different head dimensions for q/k vs v
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(q->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[2] == n_tokens);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && (beta->ne[2] == n_seqs || beta->ne[2] == 1));
+    GGML_ASSERT(ggml_nelements(state) == S_v * H_v * n_seqs);
+
+    // Check that q and k have the same dimensions
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens);
 
     // Apply L2 normalization to query and key if requested
     struct ggml_tensor * q_norm = q;
@@ -5466,69 +5474,117 @@ struct ggml_tensor * ggml_delta_net(
 
     // Apply sigmoid to beta for gating
     struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
-
-    // Apply causal 1D convolution preprocessing to mixed QKV
-    // Concatenate q, k, v along the feature dimension
-    int64_t concat_ne[4] = { q->ne[0], q->ne[1], q->ne[2], q->ne[3] * 3 };
-    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 3);
-    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 3);
-
-    // Transpose for convolution: [S, H, n_tokens, n_seqs*3] -> [S, n_tokens, H, n_seqs*3]
-    mixed_qkv = ggml_permute(ctx, mixed_qkv, 0, 2, 1, 3);
-
-    // Apply causal 1D convolution
-    struct ggml_tensor * conv_out = ggml_conv_1d(
-        ctx,
-        conv_weight,
-        mixed_qkv,
-        1,                      // stride
-        conv_weight->ne[2] - 1, // padding (kernel_size - 1)
-        1                       // dilation
-    );
-
+    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 1);
+    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
+
+    u_int32_t dim = (S_v * H_v) + 2 * (H_k * S_k);
+
+    mixed_qkv = ggml_reshape_3d(ctx, mixed_qkv, 1, dim, n_tokens);
+    struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, 3, 0, 0, 0);
+
+    // Apply SSM convolution
+    struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
+
     // Apply bias if provided
     if (conv_bias) {
        conv_out = ggml_add(ctx, conv_out, conv_bias);
    }
-
+
     // Apply SiLU activation
     conv_out = ggml_silu(ctx, conv_out);
-
-    // Transpose back: [S, n_tokens, H, n_seqs*3] -> [S, H, n_tokens, n_seqs*3]
+
+    // Reshape back to 4D: [dim, n_tokens, 1] -> [dim, n_tokens, 1, 1]
+    conv_out = ggml_reshape_4d(ctx, conv_out, dim, n_tokens, 1, 1);
+
+    // Transpose to get the right layout: [dim, n_tokens, 1] -> [dim, 1, n_tokens, 1]
     conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
+
+    // q projection view
+    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
+        S_k,                  // ne0
+        H_k,                  // ne1
+        conv_out->ne[1],      // ne2 = sequence length (1)
+        conv_out->ne[2],      // ne3 = batch (1)
+        H_k * sizeof(float),  // nb1 = stride along H_k
+        conv_out->nb[1],      // nb2 = stride along sequence dim
+        conv_out->nb[2],      // nb3 = stride along batch dim
+        0                     // offset in bytes
+    );
+
+    // k projection view
+    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out,
+        S_k,                         // ne0
+        H_k,                         // ne1
+        conv_out->ne[1],             // ne2
+        conv_out->ne[2],             // ne3
+        H_k * sizeof(float),         // nb1
+        conv_out->nb[1],             // nb2
+        conv_out->nb[2],             // nb3
+        S_k * H_k * sizeof(q->type)  // offset = skip q_out
+    );
+
+    // v projection view
+    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out,
+        S_v,                               // ne0
+        H_v,                               // ne1
+        conv_out->ne[1],                   // ne2
+        conv_out->ne[2],                   // ne3
+        H_v * sizeof(float),               // nb1
+        conv_out->nb[1],                   // nb2
+        conv_out->nb[2],                   // nb3
+        (2 * S_k * H_k) * sizeof(q->type)  // offset = skip q_out + k_out
+    );
+
+    // Transpose each component back to original layout: [S_v, 1, token_split_size, 1] -> [S_v, token_split_size, 1, 1]
+    q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
+    k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
+    v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);
+
+    q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, 1, n_tokens);
+    k_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, k_conv), S_k * H_k, 1, n_tokens);
+    v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, 1, n_tokens);
 
-    // Split the convolved output back into q, k, v components
-    // Split along the last dimension (3 * original size)
-    int64_t split_size = q->ne[3];
-    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, q->ne[0], q->ne[1], q->ne[2], split_size,
-        conv_out->nb[0], conv_out->nb[1], conv_out->nb[2], 0);
-
-    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, k->ne[0], k->ne[1], k->ne[2], split_size,
-        conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-        split_size * ggml_type_size(q->type));
+    // NOW we repeat query and key to match value head dimensions if needed (after convolution)
+    struct ggml_tensor * q_broadcast = q_conv;
+    struct ggml_tensor * k_broadcast = k_conv;
 
-    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, v->ne[0], v->ne[1], v->ne[2], split_size,
-        conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-        2 * split_size * ggml_type_size(q->type));
+    if (H_k != H_v) {
+        // Calculate the repeat factor: H_v / H_k
+        GGML_ASSERT(H_v % H_k == 0);
+        int64_t repeat_factor = H_v / H_k;
+
+        // Repeat query and key along the head dimension
+        // First reshape to separate the repeat dimension: [S_k, H_k, n_tokens, 1] -> [S_k, 1, H_k, n_tokens]
+        q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, 1, H_k, n_tokens);
+        k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, 1, H_k, n_tokens);
+
+        // Repeat along the new dimension: [S_k, repeat_factor, H_k, n_tokens]
+        q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, repeat_factor, H_k, n_tokens);
+        k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, repeat_factor, H_k, n_tokens);
+
+        // Reshape back to original dimensions but with H_v heads: [S_k, H_v, n_tokens, 1]
+        q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_tokens, 1);
+        k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_tokens, 1);
+    }
 
     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    const int64_t ne[4] = { S_v * H_v, n_tokens + H_v * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     // Set operation parameters for the delta rule computation
     int32_t params[8] = {
         chunk_size,
         use_qk_l2norm ? 1 : 0,
         0, 0,       // reserved
-        0, 0, 0, 0  // scale and other params
+        0, 0, 0     // scale and other params
     };
     memcpy(params + 4, &scale, sizeof(float));
     ggml_set_op_params(result, params, sizeof(params));
 
     // Use custom operation for the gated delta rule computation
     result->op = GGML_OP_DELTA_NET;
-    result->src[0] = q_conv;
-    result->src[1] = k_conv;
+    result->src[0] = q_broadcast;
+    result->src[1] = k_broadcast;
     result->src[2] = v_conv;
     result->src[3] = g;
     result->src[4] = beta_sigmoid;
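
Note on the layout assumed by the three ggml_view_4d calls above: per token, the convolved buffer packs q, k, and v back to back, so dim = 2*S_k*H_k + S_v*H_v, and (as the "skip q_out" / "skip q_out + k_out" offset comments indicate) the slices start at element offsets 0, S_k*H_k, and 2*S_k*H_k. A minimal standalone C sketch of that offset arithmetic, using plain arrays instead of ggml tensors; the sizes S_K/H_K/S_V/H_V are made-up illustrative values, not taken from any model:

#include <stdio.h>

// Illustrative (hypothetical) sizes only.
#define S_K 4   // q/k head dimension
#define H_K 2   // number of q/k heads
#define S_V 4   // v head dimension
#define H_V 4   // number of v heads

int main(void) {
    // Per-token width of the mixed qkv buffer: q and k contribute S_K*H_K each, v contributes S_V*H_V.
    float conv_out[2 * S_K * H_K + S_V * H_V];
    const int dim = (int) (sizeof(conv_out) / sizeof(conv_out[0]));
    for (int i = 0; i < dim; ++i) conv_out[i] = (float) i;

    // The three projections live back to back inside the buffer:
    //   q: elements [0,          S_K*H_K)
    //   k: elements [S_K*H_K,    2*S_K*H_K)
    //   v: elements [2*S_K*H_K,  2*S_K*H_K + S_V*H_V)
    const float * q = conv_out;
    const float * k = conv_out +     S_K * H_K;
    const float * v = conv_out + 2 * S_K * H_K;

    // With the sizes above this prints: dim=32 q[0]=0 k[0]=8 v[0]=16
    printf("dim=%d q[0]=%g k[0]=%g v[0]=%g\n", dim, q[0], k[0], v[0]);
    return 0;
}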

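For the H_k != H_v branch, the reshape -> ggml_repeat_4d -> reshape sequence amounts to a GQA-style broadcast under ggml's usual layout (ne0 fastest): each q/k head is duplicated repeat_factor = H_v / H_k times, so output head h reads from input head h / repeat_factor. A hedged standalone C sketch of that index mapping, with flat arrays standing in for ggml tensors and hypothetical sizes:

#include <stdio.h>

#define S_K 4   // head dimension of q/k (hypothetical)
#define H_K 2   // number of q/k heads (hypothetical)
#define H_V 4   // number of v heads; must be a multiple of H_K

int main(void) {
    const int repeat_factor = H_V / H_K;

    // One token's k projection: H_K heads of S_K elements, head-major layout.
    float k[H_K * S_K];
    for (int i = 0; i < H_K * S_K; ++i) k[i] = (float) i;

    // Broadcast target: H_V heads of S_K elements.
    float k_broadcast[H_V * S_K];

    // Output head h copies input head h / repeat_factor, matching
    // reshape [S_K,1,H_K] -> repeat [S_K,repeat_factor,H_K] -> reshape [S_K,H_V].
    for (int h = 0; h < H_V; ++h) {
        for (int s = 0; s < S_K; ++s) {
            k_broadcast[h * S_K + s] = k[(h / repeat_factor) * S_K + s];
        }
    }

    // With these sizes: heads 0 and 1 start with 0, heads 2 and 3 start with 4.
    for (int h = 0; h < H_V; ++h) {
        printf("broadcast head %d starts at %g\n", h, k_broadcast[h * S_K]);
    }
    return 0;
}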