
Commit c121b6e

kv-cache : relax SWA masking condition
ggml-ci
1 parent 1f7d50b commit c121b6e

File tree

1 file changed: +39 -19 lines


src/llama-kv-cache-unified.cpp

Lines changed: 39 additions & 19 deletions
@@ -557,43 +557,33 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }

-        // keep track of what the minimum sequence positions would be if we accept the ubatch
-        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-            seq_pos_min[s] = cells.seq_pos_min(s);
-        }
-
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            const llama_pos    pos    = ubatch.pos[i];
-            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+            //const llama_pos    pos    = ubatch.pos[i];
+            //const llama_seq_id seq_id = ubatch.seq_id[i][0];

             // can we use this cell? either:
             // - the cell is empty
             // - the cell is occupied only by one sequence:
-            //   - mask causally, if the sequence is the same as the one we are inserting
+            //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //   - mask SWA, using current max pos for that sequence in the cache
             //     always insert in the cell with minimum pos
             bool can_use = cells.is_empty(head_cur + i);

             if (!can_use && cells.seq_count(head_cur + i) == 1) {
                 const llama_pos pos_cell = cells.pos_get(head_cur + i);

-                // causal mask
-                if (cells.seq_has(head_cur + i, seq_id)) {
-                    can_use = pos_cell >= pos;
-                }
+                // (disabled) causal mask
+                // note: it's better to purge any "future" tokens beforehand
+                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //    can_use = pos_cell >= pos;
+                //}

                 if (!can_use) {
                     const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);

                     // SWA mask
-                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
-                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
-                    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-                    if (pos_cell == seq_pos_min[seq_id_cell] &&
-                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        seq_pos_min[seq_id_cell]++;
+                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                         can_use = true;
                     }
                 }
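
Note: the relaxed condition reuses a cell as soon as the token it holds would already be masked out by sliding-window attention (SWA) for the owning sequence's next position (seq_pos_max + 1), instead of additionally requiring the cell to hold that sequence's minimum cached position. A minimal standalone sketch of the idea, assuming the standard sliding-window case with an illustrative window size n_swa (the helper names below are hypothetical, not the llama.cpp API):

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for the standard sliding-window check: a token at
// position p0 is no longer visible when attending from position p1 once it
// falls at least n_swa positions behind. (llama.cpp's is_masked_swa() also
// covers other SWA types; this only sketches the idea.)
static bool swa_would_mask(int32_t p0, int32_t p1, int32_t n_swa) {
    return p1 - p0 >= n_swa;
}

// The relaxed condition from this commit: a cell holding pos_cell for some
// sequence may be overwritten as soon as pos_cell is already invisible to that
// sequence's next token (seq_pos_max + 1), even if it is not the minimum
// cached position.
static bool cell_reusable(int32_t pos_cell, int32_t seq_pos_max, int32_t n_swa) {
    return swa_would_mask(pos_cell, seq_pos_max + 1, n_swa);
}

int main() {
    // window of 4, sequence max position 14:
    printf("%d\n", cell_reusable(10, 14, 4)); // 1 -> position 10 can be reclaimed
    printf("%d\n", cell_reusable(12, 14, 4)); // 0 -> position 12 is still visible
    return 0;
}
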
@@ -621,8 +611,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }

 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for non-SWA cache, this would be always empty
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
             cells.rm(head_cur + i);
         }

@@ -633,6 +637,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         }
     }

+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
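
Note: the purge above preserves the invariant that each sequence's cached positions form one contiguous range [pos_min, pos_max] (see the referenced comment in PR 13746). Since the relaxed find_slot() may now hand out cells from the middle of that range, apply_ubatch() records the largest overwritten position per sequence and removes everything below it. A minimal self-contained sketch of why that is needed, modelling one sequence's cached positions with a std::set (positions and names are illustrative, not the llama.cpp API):

#include <algorithm>
#include <cstdio>
#include <set>

int main() {
    // one sequence currently caches the contiguous range [10, 15]
    std::set<int> cached = {10, 11, 12, 13, 14, 15};

    // suppose the new ubatch overwrites the cells holding positions 12 and 13
    // (allowed by the relaxed SWA check above)
    int seq_pos_max_rm = -1;
    for (int overwritten : {12, 13}) {
        seq_pos_max_rm = std::max(seq_pos_max_rm, overwritten);
        cached.erase(overwritten);
    }

    // without a purge the sequence would cache {10, 11, 14, 15}: a hole.
    // mirroring apply_ubatch(), drop everything up to the largest overwritten
    // position so the remaining range [14, 15] is contiguous again.
    if (seq_pos_max_rm != -1 && *cached.begin() <= seq_pos_max_rm) {
        cached.erase(cached.begin(), cached.upper_bound(seq_pos_max_rm));
    }

    for (int pos : cached) {
        printf("cached pos %d\n", pos); // prints 14 and 15
    }

    return 0;
}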
