@@ -582,43 +582,33 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }

-        // keep track of what the minimum sequence positions would be if we accept the ubatch
-        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-            seq_pos_min[s] = cells.seq_pos_min(s);
-        }
-
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            const llama_pos    pos    = ubatch.pos[i];
-            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+          //const llama_pos    pos    = ubatch.pos[i];
+          //const llama_seq_id seq_id = ubatch.seq_id[i][0];

             // can we use this cell? either:
             // - the cell is empty
             // - the cell is occupied only by one sequence:
-            //   - mask causally, if the sequence is the same as the one we are inserting
+            //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //   - mask SWA, using current max pos for that sequence in the cache
             //     always insert in the cell with minimum pos
             bool can_use = cells.is_empty(head_cur + i);

             if (!can_use && cells.seq_count(head_cur + i) == 1) {
                 const llama_pos pos_cell = cells.pos_get(head_cur + i);

-                // causal mask
-                if (cells.seq_has(head_cur + i, seq_id)) {
-                    can_use = pos_cell >= pos;
-                }
+                // (disabled) causal mask
+                // note: it's better to purge any "future" tokens beforehand
+                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //    can_use = pos_cell >= pos;
+                //}

                 if (!can_use) {
                     const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);

                     // SWA mask
-                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
-                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
-                    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-                    if (pos_cell == seq_pos_min[seq_id_cell] &&
-                            is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        seq_pos_min[seq_id_cell]++;
+                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                         can_use = true;
                     }
                 }
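In this first hunk the per-ubatch seq_pos_min bookkeeping and the causal-mask shortcut are dropped from find_slot(): an occupied, single-sequence cell is now reusable whenever its position is already SWA-masked with respect to that sequence's maximum cached position. A minimal standalone sketch of the simplified reuse test follows; is_masked_swa_stub() and can_reuse_cell() are illustrative stand-ins (not the llama.cpp helpers), assuming a plain sliding window of size n_swa:

#include <cstdint>
#include <cstdio>

// Sketch only: assumes a standard sliding window, where position p0 is masked
// once it falls out of the window that ends at position p1.
static bool is_masked_swa_stub(int64_t n_swa, int32_t p0, int32_t p1) {
    return p1 - p0 > n_swa;
}

// Simplified reuse test mirroring the hunk above: a cell can be reused if it is
// empty, or if it holds a single sequence whose position is already SWA-masked
// relative to that sequence's maximum cached position.
static bool can_reuse_cell(bool empty, int n_seq_in_cell, int32_t pos_cell, int32_t seq_pos_max, int64_t n_swa) {
    if (empty) {
        return true;
    }
    if (n_seq_in_cell != 1) {
        return false;
    }
    return is_masked_swa_stub(n_swa, pos_cell, seq_pos_max + 1);
}

int main() {
    // made-up example: window of 4, cell holds pos 3, sequence max pos is 10 -> reusable
    std::printf("%d\n", can_reuse_cell(false, 1, 3, 10, 4));
    return 0;
}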
@@ -646,8 +636,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }

 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for non-SWA cache, this would be always empty
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
             cells.rm(head_cur + i);
         }

@@ -658,6 +662,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         }
     }

+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
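The net effect of the apply_ubatch() hunks is that overwriting an occupied cell records the largest overwritten position per sequence, and any older cached positions of that sequence are then purged so that the range [pos_min, pos_max] stays contiguous. Below is a minimal, self-contained sketch of that purge decision; plain vectors and made-up positions replace the real cells/seq_rm() bookkeeping:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_seq = 2;

    std::vector<int> seq_pos_min = { 10, 37 };   // assumed current min cached pos per sequence
    std::vector<int> seq_pos_max_rm(n_seq, -1);  // max pos overwritten by the ubatch, -1 = none

    // pretend the incoming ubatch overwrote cells holding positions 12 and 14 of sequence 0
    for (int pos : { 12, 14 }) {
        seq_pos_max_rm[0] = std::max(seq_pos_max_rm[0], pos);
    }

    for (int s = 0; s < n_seq; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue; // nothing of this sequence was overwritten
        }

        if (seq_pos_min[s] <= seq_pos_max_rm[s]) {
            // purging [seq_pos_min, seq_pos_max_rm] keeps the cached range contiguous
            std::printf("purge seq %d positions [%d, %d]\n", s, seq_pos_min[s], seq_pos_max_rm[s]);
        }
    }

    return 0;
}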