@@ -557,43 +557,33 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
            continue;
        }

-       // keep track of what the minimum sequence positions would be if we accept the ubatch
-       llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-       for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-           seq_pos_min[s] = cells.seq_pos_min(s);
-       }
-
        bool found = true;
        for (uint32_t i = 0; i < n_tokens; i++) {
-           const llama_pos    pos    = ubatch.pos[i];
-           const llama_seq_id seq_id = ubatch.seq_id[i][0];
+           //const llama_pos    pos    = ubatch.pos[i];
+           //const llama_seq_id seq_id = ubatch.seq_id[i][0];

            // can we use this cell? either:
            // - the cell is empty
            // - the cell is occupied only by one sequence:
-           //   - mask causally, if the sequence is the same as the one we are inserting
+           //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
            //   - mask SWA, using current max pos for that sequence in the cache
            //     always insert in the cell with minimum pos
            bool can_use = cells.is_empty(head_cur + i);

            if (!can_use && cells.seq_count(head_cur + i) == 1) {
                const llama_pos pos_cell = cells.pos_get(head_cur + i);

-               // causal mask
-               if (cells.seq_has(head_cur + i, seq_id)) {
-                   can_use = pos_cell >= pos;
-               }
+               // (disabled) causal mask
+               // note: it's better to purge any "future" tokens beforehand
+               //if (cells.seq_has(head_cur + i, seq_id)) {
+               //    can_use = pos_cell >= pos;
+               //}

                if (!can_use) {
                    const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);

                    // SWA mask
-                   // note: we insert only in the cell with minimum pos in order to preserve the invariant that
-                   //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
-                   // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-                   if (pos_cell == seq_pos_min[seq_id_cell] &&
-                       is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                       seq_pos_min[seq_id_cell]++;
+                   if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                        can_use = true;
                    }
                }
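
For context, a minimal sketch (illustrative only, not part of this patch) of the sliding-window condition that the remaining is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1) check relies on. It assumes the standard SWA variant with a window of n_swa positions; the exact boundary convention and the other SWA variants handled by the real helper are assumptions here.

    // illustrative sketch, standard sliding-window case only
    #include <cstdint>

    using llama_pos = int32_t;

    // a cell holding p0 = pos_cell can be overwritten once the next position to be
    // inserted for that sequence (p1 = seq_pos_max + 1) can no longer attend to it
    static bool is_masked_swa_standard(llama_pos n_swa, llama_pos p0, llama_pos p1) {
        return p1 - p0 >= n_swa; // assumed boundary: positions at distance >= n_swa are masked
    }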
@@ -621,8 +611,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
    }

void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+   // keep track of the max sequence position that we would overwrite with this ubatch
+   // for non-SWA cache, this would be always empty
+   llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
+   for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+       seq_pos_max_rm[s] = -1;
+   }
+
    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
        if (!cells.is_empty(head_cur + i)) {
+           assert(cells.seq_count(head_cur + i) == 1);
+
+           const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+           const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+           seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
            cells.rm(head_cur + i);
        }

@@ -633,6 +637,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
        }
    }

+   // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+   //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+   // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+   for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+       if (seq_pos_max_rm[s] == -1) {
+           continue;
+       }
+
+       if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+           LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+               __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+           seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+       }
+   }
+
    // move the head at the end of the slot
    head = head_cur + ubatch.n_tokens;
}
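
As a worked illustration of the purge above (a hypothetical standalone helper, not part of the patch): once the ubatch overwrites cells holding positions of sequence s up to seq_pos_max_rm[s], every older position that the sequence still holds must also be removed so that the cached positions stay contiguous in [pos_min, pos_max].

    // illustrative sketch, not the llama.cpp API
    #include <cstdint>
    #include <cstdio>

    // given the minimum position a sequence still holds and the maximum position that the
    // new ubatch overwrites, return the half-open range [p0, p1) that has to be purged;
    // p0 == p1 means nothing needs to be removed
    static void compute_purge_range(int32_t seq_pos_min, int32_t seq_pos_max_rm,
                                    int32_t & p0, int32_t & p1) {
        if (seq_pos_max_rm == -1 || seq_pos_min > seq_pos_max_rm) {
            p0 = p1 = 0;
            return;
        }
        p0 = seq_pos_min;        // mirrors cells.seq_pos_min(s)
        p1 = seq_pos_max_rm + 1; // mirrors seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1)
    }

    int main() {
        // example: sequence holds positions [4, 8] and the ubatch overwrites cells up to position 7
        int32_t p0, p1;
        compute_purge_range(/*seq_pos_min=*/4, /*seq_pos_max_rm=*/7, p0, p1);
        std::printf("purge [%d, %d)\n", p0, p1); // prints: purge [4, 8)
        return 0;
    }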