Skip to content

Commit 503dda2

Browse files
committed
memory : improve status handling
1 parent efe0bc9 commit 503dda2

7 files changed

+69
-27
lines changed

src/llama-context.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ bool llama_context::kv_self_update(bool optimize) {
451451

452452
const auto kv_state = kv_self->init_update(this, optimize);
453453
if (kv_state->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE) {
454-
// no updates have been performed
454+
// no updates need to be performed
455455
return false;
456456
}
457457

@@ -979,6 +979,7 @@ int llama_context::decode(llama_batch & inp_batch) {
979979
case LLAMA_MEMORY_STATUS_NO_UPDATE:
980980
{
981981
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
982+
982983
return -2;
983984
}
984985
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
@@ -993,12 +994,14 @@ int llama_context::decode(llama_batch & inp_batch) {
993994
}
994995
}
995996

996-
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
997+
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
997998

998999
return 1;
9991000
}
10001001
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
10011002
{
1003+
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
1004+
10021005
return -2;
10031006
}
10041007
}

src/llama-kv-cache-recurrent.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "llama-kv-cache-recurrent.h"
22

33
#include "llama-impl.h"
4+
#include "llama-io.h"
45
#include "llama-batch.h"
56
#include "llama-model.h"
67

src/llama-kv-cache-unified-iswa.cpp

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
167167
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
168168
state_base = kv->get_base()->init_full();
169169
state_swa = kv->get_swa ()->init_full();
170+
171+
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
170172
}
171173

172174
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
@@ -176,22 +178,7 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
176178
state_base = kv->get_base()->init_update(lctx, optimize);
177179
state_swa = kv->get_swa ()->init_update(lctx, optimize);
178180

179-
// TODO: this is very ugly - how to make it simpler?
180-
// the llama_memory_status enum is not very well designed
181-
if (state_base->get_status() != LLAMA_MEMORY_STATUS_SUCCESS && state_base->get_status() != LLAMA_MEMORY_STATUS_NO_UPDATE) {
182-
status = state_base->get_status();
183-
return;
184-
}
185-
186-
if (state_swa->get_status() != LLAMA_MEMORY_STATUS_SUCCESS && state_swa->get_status() != LLAMA_MEMORY_STATUS_NO_UPDATE) {
187-
status = state_swa->get_status();
188-
return;
189-
}
190-
191-
if (state_base->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE && state_swa->get_status() == LLAMA_MEMORY_STATUS_NO_UPDATE) {
192-
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
193-
return;
194-
}
181+
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
195182
}
196183

197184
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
@@ -200,13 +187,15 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
200187
std::vector<uint32_t> heads_base,
201188
std::vector<uint32_t> heads_swa,
202189
std::vector<llama_ubatch> ubatches)
203-
: status(LLAMA_MEMORY_STATUS_SUCCESS),
204-
sbatch(std::move(sbatch)),
205-
ubatches(std::move(ubatches)) {
206-
// note: here we copy the ubatches. not sure if this is ideal
207-
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
208-
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
209-
}
190+
: status(LLAMA_MEMORY_STATUS_SUCCESS),
191+
sbatch(std::move(sbatch)),
192+
ubatches(std::move(ubatches)) {
193+
// note: here we copy the ubatches. not sure if this is ideal
194+
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
195+
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
196+
197+
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
198+
}
210199

211200
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
212201

@@ -246,6 +235,7 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
246235

247236
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
248237
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
238+
249239
return ubatches[i_next];
250240
}
251241

src/llama-kv-cache-unified.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "llama-kv-cache-unified.h"
22

33
#include "llama-impl.h"
4+
#include "llama-io.h"
45
#include "llama-model.h"
56
#include "llama-context.h"
67

src/llama-kv-cache.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#pragma once
22

33
#include "llama.h"
4-
#include "llama-io.h"
54
#include "llama-memory.h"
65

6+
class llama_io_write_i;
7+
class llama_io_read_i;
8+
79
struct llama_kv_cache : public llama_memory_i {
810
virtual ~llama_kv_cache() = default;
911

src/llama-memory.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,42 @@
11
#include "llama-memory.h"
2+
3+
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
4+
bool has_update = false;
5+
6+
switch (s0) {
7+
case LLAMA_MEMORY_STATUS_SUCCESS:
8+
{
9+
has_update = true;
10+
break;
11+
}
12+
case LLAMA_MEMORY_STATUS_NO_UPDATE:
13+
{
14+
break;
15+
}
16+
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
17+
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
18+
{
19+
return s0;
20+
}
21+
}
22+
23+
switch (s1) {
24+
case LLAMA_MEMORY_STATUS_SUCCESS:
25+
{
26+
has_update = true;
27+
break;
28+
}
29+
case LLAMA_MEMORY_STATUS_NO_UPDATE:
30+
{
31+
break;
32+
}
33+
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
34+
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
35+
{
36+
return s1;
37+
}
38+
}
39+
40+
// if either status has an update, then the combined status has an update
41+
return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
42+
}

src/llama-memory.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ enum llama_memory_status {
4545
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
4646
};
4747

48+
// helper function for combining the status of two memory states
49+
// useful for implementing hybrid memory types (e.g. iSWA)
50+
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
51+
4852
// the interface for managing the memory state during batch processing
4953
// this interface is implemented per memory type. see:
5054
// - llama_kv_cache_unified_state
@@ -72,7 +76,7 @@ class llama_memory_state_i {
7276
// get the current ubatch
7377
virtual const llama_ubatch & get_ubatch() const = 0;
7478

75-
// get the status of the memory state
79+
// get the status of the memory state - used for error handling and checking if any updates would be applied
7680
virtual llama_memory_status get_status() const = 0;
7781
};
7882

0 commit comments

Comments (0)