Commit 89a184f

kv-cache : relax SWA masking condition (#14119)
ggml-ci
1 parent 2baf077

1 file changed: +39 -19 lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 39 additions & 19 deletions
@@ -582,43 +582,33 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }

-        // keep track of what the minimum sequence positions would be if we accept the ubatch
-        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-            seq_pos_min[s] = cells.seq_pos_min(s);
-        }
-
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            const llama_pos    pos    = ubatch.pos[i];
-            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+            //const llama_pos    pos    = ubatch.pos[i];
+            //const llama_seq_id seq_id = ubatch.seq_id[i][0];

             // can we use this cell? either:
             // - the cell is empty
             // - the cell is occupied only by one sequence:
-            //   - mask causally, if the sequence is the same as the one we are inserting
+            //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //   - mask SWA, using current max pos for that sequence in the cache
             //     always insert in the cell with minimum pos
             bool can_use = cells.is_empty(head_cur + i);

             if (!can_use && cells.seq_count(head_cur + i) == 1) {
                 const llama_pos pos_cell = cells.pos_get(head_cur + i);

-                // causal mask
-                if (cells.seq_has(head_cur + i, seq_id)) {
-                    can_use = pos_cell >= pos;
-                }
+                // (disabled) causal mask
+                // note: it's better to purge any "future" tokens beforehand
+                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //    can_use = pos_cell >= pos;
+                //}

                 if (!can_use) {
                     const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);

                     // SWA mask
-                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
-                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
-                    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-                    if (pos_cell == seq_pos_min[seq_id_cell] &&
-                            is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        seq_pos_min[seq_id_cell]++;
+                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                         can_use = true;
                     }
                 }
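
To read the relaxed condition in isolation: the old code only reused a non-empty cell when it held the minimum cached position of its sequence, bumping a local seq_pos_min copy as it went; the new code reuses any cell whose token already falls outside the SWA window of the next token of its sequence (seq_pos_max + 1). Below is a minimal sketch of the new check, assuming a standard SWA window rule; Cell, can_use_cell_sketch, and the concrete window rule are illustrative stand-ins, not the actual llama.cpp API:

#include <cstdint>

// Hypothetical, simplified view of one KV cell; the real code queries
// the cells object (is_empty, pos_get, seq_pos_max, ...) instead.
struct Cell {
    bool    empty;       // no token stored in this cell
    int32_t pos;         // position of the stored token
    int32_t seq_pos_max; // max cached position of the cell's sequence
};

// Assumed standard-SWA rule: p0 is outside the attention window of p1
// once their distance reaches the window size n_swa.
static bool is_masked_swa_sketch(int32_t p0, int32_t p1, int32_t n_swa) {
    return p1 - p0 >= n_swa;
}

// Relaxed reuse check: an occupied cell may be overwritten as soon as
// the next token of its sequence (seq_pos_max + 1) can no longer attend
// to it, regardless of whether it holds the sequence's minimum position.
static bool can_use_cell_sketch(const Cell & c, int32_t n_swa) {
    if (c.empty) {
        return true;
    }
    return is_masked_swa_sketch(c.pos, c.seq_pos_max + 1, n_swa);
}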
@@ -646,8 +636,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }

 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for non-SWA cache, this would be always empty
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
             cells.rm(head_cur + i);
         }

@@ -658,6 +662,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         }
     }

+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
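
The two apply_ubatch hunks cooperate: the loop records, per sequence, the highest position the incoming ubatch overwrites (seq_pos_max_rm), and the new block afterwards purges every remaining position at or below it, so the cached positions of each sequence stay a contiguous [pos_min, pos_max] range. A self-contained sketch of that bookkeeping follows, where MAX_SEQS, OverwrittenCell, and the hard-coded positions are illustrative stand-ins for LLAMA_MAX_PARALLEL_SEQUENCES and the real cell queries:

#include <algorithm>
#include <cstdio>
#include <vector>

constexpr int MAX_SEQS = 4; // stand-in for LLAMA_MAX_PARALLEL_SEQUENCES

struct OverwrittenCell {
    int seq_id; // sequence the overwritten cell belonged to
    int pos;    // position that was stored in the cell
};

int main() {
    // suppose the ubatch lands on these occupied cells
    std::vector<OverwrittenCell> overwritten = { {0, 7}, {0, 9}, {1, 3} };

    // track the max overwritten position per sequence; -1 = none
    int seq_pos_max_rm[MAX_SEQS];
    std::fill(seq_pos_max_rm, seq_pos_max_rm + MAX_SEQS, -1);

    for (const auto & c : overwritten) {
        seq_pos_max_rm[c.seq_id] = std::max(seq_pos_max_rm[c.seq_id], c.pos);
    }

    // assumed current minimum cached position per sequence (-1 = empty)
    const int seq_pos_min[MAX_SEQS] = { 5, 3, -1, -1 };

    for (int s = 0; s < MAX_SEQS; ++s) {
        if (seq_pos_max_rm[s] == -1) {
            continue;
        }
        if (seq_pos_min[s] <= seq_pos_max_rm[s]) {
            // the committed code calls seq_rm(s, pos_min, pos_max_rm + 1) here
            std::printf("purge seq %d positions [%d, %d]\n",
                    s, seq_pos_min[s], seq_pos_max_rm[s]);
        }
    }
    return 0;
}

With this sample data, sequence 0 purges positions [5, 9] and sequence 1 purges [3, 3], mirroring the seq_rm(s, seq_pos_min(s), seq_pos_max_rm[s] + 1) call in the commit.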
