@@ -1007,7 +1007,8 @@ struct llama_layer {
10071007};
10081008
10091009struct  llama_kv_cell  {
1010-     llama_pos pos = -1 ;
1010+     llama_pos pos   = -1 ;
1011+     llama_pos delta = 0 ;
10111012
10121013    std::set<llama_seq_id> seq_id;
10131014
@@ -1018,7 +1019,7 @@ struct llama_kv_cell {
10181019
10191020//  ring-buffer of cached KV data
10201021struct  llama_kv_cache  {
1021-     bool  is_roped  = false ;
1022+     bool  has_shift  = false ;
10221023
10231024    uint32_t  head = 0 ;
10241025    uint32_t  size = 0 ;
@@ -1223,6 +1224,8 @@ static bool llama_kv_cache_init(
12231224    const  int64_t  n_mem      = n_layer*n_ctx;
12241225    const  int64_t  n_elements = n_embd*n_mem;
12251226
1227+     cache.has_shift  = false ;
1228+ 
12261229    cache.head  = 0 ;
12271230    cache.size  = n_ctx;
12281231
@@ -1333,9 +1336,13 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
13331336    }
13341337}
13351338
1336- void  llama_kv_cache_rm_seq (struct  llama_kv_cache  & cache, llama_seq_id seq_id) {
1339+ void  llama_kv_cache_rm_seq (
1340+              struct  llama_kv_cache  & cache,
1341+                       llama_seq_id   seq_id,
1342+                          llama_pos   p0,
1343+                          llama_pos   p1) {
13371344    for  (uint32_t  i = 0 ; i < cache.size ; ++i) {
1338-         if  (cache.cells [i].has_seq_id (seq_id)) {
1345+         if  (cache.cells [i].has_seq_id (seq_id) && cache. cells [i]. pos  >= p0 && cache. cells [i]. pos  < p1 ) {
13391346            cache.cells [i].seq_id .erase (seq_id);
13401347            if  (cache.cells [i].seq_id .empty ()) {
13411348                cache.cells [i].pos  = -1 ;
@@ -1353,18 +1360,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
13531360    }
13541361}
13551362
1356- void  llama_kv_cache_shift (
1357-                struct  llama_context  & ctx ,
1363+ void  llama_kv_cache_shift_seq (
1364+              struct  llama_kv_cache  & cache ,
13581365                      llama_seq_id   seq_id,
13591366                         llama_pos   p0,
13601367                         llama_pos   p1,
13611368                         llama_pos   delta) {
1362-     auto  & hparams = ctx.model .hparams ;
1363-     auto  & cache   = ctx.kv_self ;
1364- 
13651369    for  (uint32_t  i = 0 ; i < cache.size ; ++i) {
13661370        if  (cache.cells [i].has_seq_id (seq_id) && cache.cells [i].pos  >= p0 && cache.cells [i].pos  < p1) {
13671371            cache.cells [i].pos  += delta;
1372+             if  (cache.cells [i].pos  < 0 ) {
1373+                 cache.cells [i].pos  = -1 ;
1374+                 cache.cells [i].seq_id .clear ();
1375+             } else  {
1376+                 cache.has_shift  = true ;
1377+                 cache.cells [i].delta  = delta;
1378+             }
13681379        }
13691380    }
13701381}
@@ -2595,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama(
25952606    const  int32_t  n_tokens = batch.n_tokens ;
25962607    const  int32_t  n_kv     = llama_kv_cache_cell_max (kv_self);
25972608
2609+     const  bool  do_rope_shift = kv_self.has_shift  || ggml_allocr_is_measure (lctx.alloc );
2610+ 
25982611    auto  & buf_compute = lctx.buf_compute ;
25992612
26002613    struct  ggml_init_params  params = {
@@ -2698,6 +2711,16 @@ static struct ggml_cgraph * llm_build_llama(
26982711        }
26992712    }
27002713
2714+     //  K_shift
2715+     struct  ggml_tensor  * K_shift = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_ctx);
2716+     ggml_allocr_alloc (lctx.alloc , K_shift);
2717+     if  (!ggml_allocr_is_measure (lctx.alloc )) {
2718+         int  * data = (int  *) K_shift->data ;
2719+         for  (int  i = 0 ; i < n_ctx; ++i) {
2720+             data[i] = kv_self.cells [i].delta ;
2721+         }
2722+     }
2723+ 
27012724    for  (int  il = 0 ; il < n_layer; ++il) {
27022725        ggml_format_name (inpL, " layer_inp_%d"  , il);
27032726
@@ -2723,6 +2746,17 @@ static struct ggml_cgraph * llm_build_llama(
27232746            ggml_set_name (cur, " attention_norm_0"  );
27242747        }
27252748
2749+         if  (do_rope_shift) {
2750+             ggml_build_forward_expand (gf,
2751+                     ggml_rope_custom_inplace (ctx0,
2752+                         ggml_view_3d (ctx0, kv_self.k ,
2753+                             n_embd_head, n_head_kv, n_ctx,
2754+                             ggml_element_size (kv_self.k )*n_embd_head,
2755+                             ggml_element_size (kv_self.k )*n_embd_gqa,
2756+                             ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il),
2757+                         K_shift, n_embd_head, 0 , 0 , freq_base, freq_scale));
2758+         }
2759+ 
27262760        //  self-attention
27272761        {
27282762            //  compute Q and K and RoPE them
@@ -4033,7 +4067,8 @@ static bool llama_eval_internal(
40334067#endif 
40344068
40354069    //  update the kv ring buffer
4036-     lctx.kv_self .head  += n_tokens;
4070+     lctx.kv_self .head       += n_tokens;
4071+     lctx.kv_self .has_shift   = false ;
40374072
40384073#ifdef  GGML_PERF
40394074    //  print timing information per ggml operation (for debugging purposes)
@@ -6562,10 +6597,6 @@ struct llama_context * llama_new_context_with_model(
65626597            return  nullptr ;
65636598        }
65646599
6565-         if  (model->arch  == LLM_ARCH_LLAMA) {
6566-             ctx->kv_self .is_roped  = true ;
6567-         }
6568- 
65696600        {
65706601            const  size_t  memory_size = ggml_nbytes (ctx->kv_self .k ) + ggml_nbytes (ctx->kv_self .v );
65716602            LLAMA_LOG_INFO (" %s: kv self size  = %7.2f MB\n "  , __func__, memory_size / 1024.0  / 1024.0 );
@@ -6803,16 +6834,16 @@ void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1
68036834    llama_kv_cache_rm_tokens (ctx->kv_self , c0, c1);
68046835}
68056836
6806- void  llama_kv_cache_rm_seq (struct  llama_context  * ctx, llama_seq_id seq_id) {
6807-     llama_kv_cache_rm_seq (ctx->kv_self , seq_id);
6837+ void  llama_kv_cache_rm_seq (struct  llama_context  * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1 ) {
6838+     llama_kv_cache_rm_seq (ctx->kv_self , seq_id, p0, p1 );
68086839}
68096840
68106841void  llama_kv_cache_keep_seq (struct  llama_context  * ctx, llama_seq_id seq_id) {
68116842    llama_kv_cache_keep_seq (ctx->kv_self , seq_id);
68126843}
68136844
6814- void  llama_kv_cache_shift (struct  llama_context  * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6815-     llama_kv_cache_shift (* ctx, seq_id, p0, p1, delta);
6845+ void  llama_kv_cache_shift_seq (struct  llama_context  * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
6846+     llama_kv_cache_shift_seq ( ctx-> kv_self , seq_id, p0, p1, delta);
68166847}
68176848
68186849//  Returns the *maximum* size of the state
0 commit comments