
Commit f715a85

tests: Initial unit tests for memory hierarchy
These only test the basics so far, but should allow for more expansive tests to come.

Branch: MemoryTests
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent a4090d1 commit f715a85

1 file changed: +175 -0 lines changed


tests/test-memory.cpp

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
/*------------------------------------------------------------------------------
 * Unit tests for llama-memory.h and derived memory implementations. It contains
 * a number of tests which can be run all together or separately.
 *
 * USAGE: ./bin/test-memory <test_name1> <test_name2>
 *
 * When adding a new test, do the following:
 *
 *  1. Add the new test_<memory_type>_description function under the
 *     appropriate memory type section
 *
 *  2. Add `RUN_TEST(test_<memory_type>_description);` to main
 *----------------------------------------------------------------------------*/

#include "../src/llama-arch.h"
#include "../src/llama-batch.h"
#include "../src/llama-hparams.h"
#include "../src/llama-impl.h"
#include "../src/llama-kv-cache.h"
#include "../src/llama-model.h"

#include "common.h"
#include "llama.h"

#include <algorithm>
#include <cstdio>
#include <memory>

/*- Helpers ------------------------------------------------------------------*/

static std::shared_ptr<llama_model> _make_model(
    llm_arch arch = LLM_ARCH_LLAMA,
    uint32_t n_layer = 4,
    uint32_t n_embd_head_k = 4,
    uint32_t n_embd_head_v = 4,
    uint32_t n_head = 8,
    uint32_t n_head_kv = 2) {

    llama_model_params params;
    params.tensor_buft_overrides = nullptr;
    std::shared_ptr<llama_model> model(new llama_model(params));
    model->hparams = llama_hparams();
    model->arch = arch;

    model->hparams.n_layer = n_layer;
    model->hparams.n_embd_head_k = n_embd_head_k;
    model->hparams.n_embd_head_v = n_embd_head_v;

    // If set to 0, assume the test will fill out the array elementwise (hybrid)
    if (n_head > 0) {
        auto& n_head_arr = model->hparams.n_head_arr;
        std::fill(n_head_arr.begin(), n_head_arr.end(), n_head);
    }
    if (n_head_kv > 0) {
        auto& n_head_kv_arr = model->hparams.n_head_kv_arr;
        std::fill(n_head_kv_arr.begin(), n_head_kv_arr.end(), n_head_kv);
    }

    return model;
}

struct log_scope {
    const char * name;
    explicit log_scope(const char * name) : name(name) {
        LLAMA_LOG_INFO("--------\n");
        LLAMA_LOG_INFO("START: %s\n", name);
    }
    ~log_scope() {
        LLAMA_LOG_INFO("END: %s\n", name);
        LLAMA_LOG_INFO("--------\n");
    }
};

#define RUN_TEST(test_name) \
    do { \
        bool run_test = argc < 2; \
        std::vector<std::string> args(argv + 1, argv + argc); \
        if (std::find(args.begin(), args.end(), #test_name) != args.end()) \
            run_test = true; \
        if (run_test) { \
            log_scope __log_scope(#test_name); \
            test_name(); \
        } \
    } while (0)

/*- Unified Cache ------------------------------------------------------------*/

/* Test that the unified cache can be constructed and destructed safely */
static void test_llama_kv_cache_unified_constructor() {
    auto model = _make_model();
    llama_kv_cache_unified cache(
        /* model    */ *model,
        /* filter   */ nullptr,
        /* type_k   */ GGML_TYPE_F32,
        /* type_v   */ GGML_TYPE_F16,
        /* v_trans  */ false,
        /* offload  */ false,
        /* kv_size  */ 10,
        /* padding  */ 10,
        /* n_swa    */ 0,
        /* swa_type */ LLAMA_SWA_TYPE_NONE
    );
}

/* Test that the unified cache can operate with a single seq */
static void test_llama_kv_cache_unified_single_seq() {
    auto model = _make_model();
    llama_kv_cache_unified cache(
        /* model    */ *model,
        /* filter   */ nullptr,
        /* type_k   */ GGML_TYPE_F32,
        /* type_v   */ GGML_TYPE_F16,
        /* v_trans  */ false,
        /* offload  */ false,
        /* kv_size  */ 10,
        /* padding  */ 10,
        /* n_swa    */ 0,
        /* swa_type */ LLAMA_SWA_TYPE_NONE
    );
    GGML_ASSERT(cache.get_used_cells() == 0);

    // Create the micro batch with a single 3-token sequence
    //
    // NOTE: A bunch of these asserts were just me figuring out how the batches
    // relate to each other, but they're left for future readers to help in the
    // same understanding process.
    llama_seq_id seq_id = 42;
    llama_batch batch = llama_batch_init(3, 0, 1);
    common_batch_add(batch, 101, 0, {seq_id}, false);
    common_batch_add(batch,   1, 1, {seq_id}, false);
    common_batch_add(batch, 102, 2, {seq_id}, false);
    llama_sbatch sbatch(batch, 0, true, false);
    GGML_ASSERT(batch.n_tokens == 3);
    GGML_ASSERT(sbatch.n_tokens == 3);
    GGML_ASSERT(!sbatch.seq.empty());
    llama_ubatch ubatch = sbatch.split_simple(4);
    printf("ubatch.n_seqs=%d\n", ubatch.n_seqs);
    GGML_ASSERT(ubatch.n_seqs == 3);
    GGML_ASSERT(ubatch.n_seq_tokens == 1);
    GGML_ASSERT(ubatch.n_tokens == 3);
    GGML_ASSERT(ubatch.seq_id[0][0] == seq_id);
    GGML_ASSERT(ubatch.seq_id[1][0] == seq_id);
    GGML_ASSERT(ubatch.seq_id[2][0] == seq_id);

    // Find a slot for a new sequence
    GGML_ASSERT(cache.find_slot(ubatch));

    // Clean up
    llama_batch_free(batch);
}

/*- Recurrent Cache ----------------------------------------------------------*/

/* Test that the recurrent cache can be constructed and destructed safely */
static void test_llama_kv_cache_recurrent_constructor() {
    auto model = _make_model(LLM_ARCH_MAMBA);
    llama_kv_cache_recurrent cache(
        /* model   */ *model,
        /* type_k  */ GGML_TYPE_F32,
        /* type_v  */ GGML_TYPE_F16,
        /* offload */ false,
        /* kv_size */ 10
    );
}

/*- Main ---------------------------------------------------------------------*/

int main(int argc, char* argv[]) {
    // Unified Cache Tests
    RUN_TEST(test_llama_kv_cache_unified_constructor);
    RUN_TEST(test_llama_kv_cache_unified_single_seq);
    // Recurrent Cache Tests
    RUN_TEST(test_llama_kv_cache_recurrent_constructor);
    return 0;
}
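
The file header spells out the two-step recipe for extending this suite. As a rough sketch of what that looks like in practice (the test name and the v_trans parameter choice below are hypothetical and not part of this commit), a new unified-cache test would be added under the Unified Cache section and then registered in main:

/* Hypothetical step 1: add the new test under the appropriate memory type
 * section. This variant only flips v_trans relative to the existing
 * constructor test, so it uses nothing beyond what the file already exercises. */
static void test_llama_kv_cache_unified_v_trans_constructor() {
    auto model = _make_model();
    llama_kv_cache_unified cache(
        /* model    */ *model,
        /* filter   */ nullptr,
        /* type_k   */ GGML_TYPE_F32,
        /* type_v   */ GGML_TYPE_F16,
        /* v_trans  */ true,
        /* offload  */ false,
        /* kv_size  */ 10,
        /* padding  */ 10,
        /* n_swa    */ 0,
        /* swa_type */ LLAMA_SWA_TYPE_NONE
    );
}

/* Hypothetical step 2: register the test in main so RUN_TEST can run it */
// RUN_TEST(test_llama_kv_cache_unified_v_trans_constructor);

Per the USAGE line and the RUN_TEST macro, invoking ./bin/test-memory with no arguments runs every registered test, while passing one or more test names (for example ./bin/test-memory test_llama_kv_cache_unified_single_seq) runs only the named tests.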
