@@ -3435,7 +3435,7 @@ struct ggml_tensor * ggml_reshape_4d(
         int64_t               ne2,
         int64_t               ne3) {
     GGML_ASSERT(ggml_is_contiguous(a));
-    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
+    GGML_ASSERT(ggml_nelements(a) == ne0 * ne1 * ne2 * ne3);
 
     const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
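(Aside: the assert only checks that the element count is preserved; together with the contiguity assert above, that is what makes the zero-copy reshape legal. A minimal usage sketch, with illustrative shapes that are not part of this change:

    // any 4D factorization of the element count is accepted
    struct ggml_tensor * t  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2*3*4*5);
    struct ggml_tensor * t4 = ggml_reshape_4d(ctx, t, 2, 3, 4, 5); // 120 elements either way
)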
@@ -5441,17 +5441,25 @@ struct ggml_tensor * ggml_delta_net(
     GGML_ASSERT(ggml_is_contiguous(beta));
     GGML_ASSERT(ggml_is_contiguous(state));
 
-    const int64_t S = k->ne[0];
-    const int64_t H = k->ne[1];
+    const int64_t S_k = k->ne[0];
+    const int64_t H_k = k->ne[1];
     const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs   = state->ne[1];
 
-    // Validate dimensions
-    GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
-    GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
-    GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
-    GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    // Validate dimensions - q/k may use a different head size and head count than v
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(q->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[2] == n_tokens);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && (beta->ne[2] == n_seqs || beta->ne[2] == 1));
+    GGML_ASSERT(ggml_nelements(state) == S_v * H_v * n_seqs);
+
+    // q and k must agree on head size and head count; g is shaped like v
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens);
 
     // Apply L2 normalization to query and key if requested
     struct ggml_tensor * q_norm = q;
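(Note on the relaxed shape checks: this is the grouped-query layout, where several v heads share one q/k head. The broadcast step later in the function assumes the head counts divide evenly; a quick sketch of the arithmetic with illustrative sizes, not taken from this change:

    // e.g. S_k = 128, H_k = 16, S_v = 128, H_v = 32
    GGML_ASSERT(H_v % H_k == 0);              // 32 % 16 == 0
    const int64_t repeat_factor = H_v / H_k;  // each q/k head serves 2 v heads
)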
@@ -5466,69 +5474,117 @@ struct ggml_tensor * ggml_delta_net(
 
     // Apply sigmoid to beta for gating
     struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
-
-    // Apply causal 1D convolution preprocessing to mixed QKV
-    // Concatenate q, k, v along the feature dimension
-    int64_t concat_ne[4] = { q->ne[0], q->ne[1], q->ne[2], q->ne[3] * 3 };
-    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 3);
-    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 3);
-
-    // Transpose for convolution: [S, H, n_tokens, n_seqs*3] -> [S, n_tokens, H, n_seqs*3]
-    mixed_qkv = ggml_permute(ctx, mixed_qkv, 0, 2, 1, 3);
-
-    // Apply causal 1D convolution
-    struct ggml_tensor * conv_out = ggml_conv_1d(
-        ctx,
-        conv_weight,
-        mixed_qkv,
-        1,                      // stride
-        conv_weight->ne[2] - 1, // padding (kernel_size - 1)
-        1                       // dilation
-    );
-
+    // Concatenate q, k, v along the head dimension
+    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 1);
+    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
+
+    // total feature size of the packed [q | k | v] tensor
+    const int64_t dim = (S_v * H_v) + 2 * (H_k * S_k);
+
+    // flatten to one feature column per token; pad ne0 by 3
+    // (kernel width - 1, assuming d_conv == 4) so the window is causal
+    mixed_qkv = ggml_reshape_3d(ctx, mixed_qkv, 1, dim, n_tokens);
+    struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, 3, 0, 0, 0);
+
+    // Apply SSM convolution
+    struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
+
     // Apply bias if provided
     if (conv_bias) {
         conv_out = ggml_add(ctx, conv_out, conv_bias);
     }
-
+
     // Apply SiLU activation
     conv_out = ggml_silu(ctx, conv_out);
-
-    // Transpose back: [S, n_tokens, H, n_seqs*3] -> [S, H, n_tokens, n_seqs*3]
+
+    // Reshape back to 4D: [dim, n_tokens, 1] -> [dim, n_tokens, 1, 1]
+    conv_out = ggml_reshape_4d(ctx, conv_out, dim, n_tokens, 1, 1);
+
+    // Swap dims 1 and 2: [dim, n_tokens, 1, 1] -> [dim, 1, n_tokens, 1]
     conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
+
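+    // The three views below slice the packed conv output back into q, k, v.
+    // Element offsets along ne0, with illustrative sizes (S_k = S_v = 128,
+    // H_k = 16, H_v = 32 -- example values, not mandated by this change):
+    //   q: floats [0,         S_k*H_k)             = [0,    2048)
+    //   k: floats [S_k*H_k,   2*S_k*H_k)           = [2048, 4096)
+    //   v: floats [2*S_k*H_k, 2*S_k*H_k + S_v*H_v) = [4096, 8192)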
+    // q projection view
+    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
+            S_k,                         // ne0
+            H_k,                         // ne1
+            conv_out->ne[1],             // ne2 = sequence length (1)
+            conv_out->ne[2],             // ne3 = batch (1)
+            S_k * sizeof(float),         // nb1 = byte stride from one head to the next
+            conv_out->nb[1],             // nb2 = stride along the sequence dim
+            conv_out->nb[2],             // nb3 = stride along the batch dim
+            0);                          // offset in bytes
+
+    // k projection view
+    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out,
+            S_k,                         // ne0
+            H_k,                         // ne1
+            conv_out->ne[1],             // ne2
+            conv_out->ne[2],             // ne3
+            S_k * sizeof(float),         // nb1
+            conv_out->nb[1],             // nb2
+            conv_out->nb[2],             // nb3
+            S_k * H_k * sizeof(float));  // offset in bytes = skip q
+
+    // v projection view
+    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out,
+            S_v,                            // ne0
+            H_v,                            // ne1
+            conv_out->ne[1],                // ne2
+            conv_out->ne[2],                // ne3
+            S_v * sizeof(float),            // nb1
+            conv_out->nb[1],                // nb2
+            conv_out->nb[2],                // nb3
+            2 * S_k * H_k * sizeof(float)); // offset in bytes = skip q and k
+
+    // Reorder and flatten each component:
+    // [S, H, 1, n_tokens] -> [S, 1, H, n_tokens] -> [S*H, 1, n_tokens]
+    q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
+    k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
+    v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);
+
+    q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, 1, n_tokens);
+    k_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, k_conv), S_k * H_k, 1, n_tokens);
+    v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, 1, n_tokens);
 
-    // Split the convolved output back into q, k, v components
-    // Split along the last dimension (3 * original size)
-    int64_t split_size = q->ne[3];
-    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, q->ne[0], q->ne[1], q->ne[2], split_size,
-            conv_out->nb[0], conv_out->nb[1], conv_out->nb[2], 0);
-
-    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, k->ne[0], k->ne[1], k->ne[2], split_size,
-            conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-            split_size * ggml_type_size(q->type));
+    // Broadcast q/k heads to match the v head count if needed (done after the convolution)
+    struct ggml_tensor * q_broadcast = q_conv;
+    struct ggml_tensor * k_broadcast = k_conv;
 
-    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, v->ne[0], v->ne[1], v->ne[2], split_size,
-            conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-            2 * split_size * ggml_type_size(q->type));
+    if (H_k != H_v) {
+        // the broadcast only works for an integer repeat factor H_v / H_k
+        GGML_ASSERT(H_v % H_k == 0);
+        int64_t repeat_factor = H_v / H_k;
+
+        // insert a singleton repeat dimension: [S_k*H_k, 1, n_tokens] -> [S_k, 1, H_k, n_tokens]
+        q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, 1, H_k, n_tokens);
+        k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, 1, H_k, n_tokens);
+
+        // repeat along the new dimension: [S_k, repeat_factor, H_k, n_tokens]
+        q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, repeat_factor, H_k, n_tokens);
+        k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, repeat_factor, H_k, n_tokens);
+
+        // collapse back so all H_v = repeat_factor * H_k heads carry q/k data: [S_k, H_v, n_tokens, 1]
+        q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_tokens, 1);
+        k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_tokens, 1);
+    }
 
     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    const int64_t ne[4] = { S_v * H_v, n_tokens + H_v * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     // Set operation parameters for the delta rule computation
     int32_t params[8] = {
         chunk_size,
         use_qk_l2norm ? 1 : 0,
         0, 0, // reserved
-        0, 0, 0, 0 // scale and other params
+        0, 0, 0 // params[4] receives scale via the memcpy below; the rest are unused
     };
     memcpy(params + 4, &scale, sizeof(float));
     ggml_set_op_params(result, params, sizeof(params));
 
     // Use custom operation for the gated delta rule computation
     result->op     = GGML_OP_DELTA_NET;
-    result->src[0] = q_conv;
-    result->src[1] = k_conv;
+    result->src[0] = q_broadcast;
+    result->src[1] = k_broadcast;
     result->src[2] = v_conv;
     result->src[3] = g;
     result->src[4] = beta_sigmoid;
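(For reference, a sketch of how a backend kernel could unpack these op params. The layout, with scale occupying 32-bit slot 4, follows the memcpy above; the variable names and the use of dst here are illustrative, not part of this change:

    const int32_t * op = (const int32_t *) dst->op_params;
    const int32_t chunk_size    = op[0];
    const bool    use_qk_l2norm = op[1] != 0; // slots 2-3 are reserved
    float scale;
    memcpy(&scale, op + 4, sizeof(float));    // written by the memcpy in ggml_delta_net
)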