@@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
     }
 }
 
-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_self_clear(ctx);
 
     // run model
     LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_decode(ctx, batch) < 0) {
-        LOG_ERR("%s : failed to decode\n", __func__);
+    if (llama_encode(ctx, batch) < 0) {
+        LOG_ERR("%s : failed to encode\n", __func__);
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
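For context (not part of the diff): llama_decode drives autoregressive decoding, while llama_encode is the entry point for encoder-style (embedding) models, so the rename keeps the helper's name honest. The loop opened by the hunk's last context line presumably reads one pooled embedding per flagged position, along these lines (a minimal sketch modeled on llama.cpp's other embedding examples; the guard on batch.logits and the use of common_embd_normalize are assumptions):

    for (int i = 0; i < batch.n_tokens; i++) {
        if (!batch.logits[i]) {
            continue; // only positions flagged for output carry embeddings
        }
        // sequence-level pooled embedding (requires pooling_type != NONE)
        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        if (embd == NULL) {
            LOG_ERR("%s: failed to get embeddings for seq %d\n", __func__, batch.seq_id[i][0]);
            continue;
        }
        // L2-normalize into the caller's buffer, one row per sequence
        float * out = output + batch.seq_id[i][0] * n_embd;
        common_embd_normalize(embd, out, n_embd, 2);
    }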
@@ -233,7 +233,7 @@ int main(int argc, char ** argv) {
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
             float * out = emb + p * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd);
+            batch_encode(ctx, batch, out, s, n_embd);
             common_batch_clear(batch);
             p += s;
             s = 0;
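The surrounding loop (not shown in the diff) packs one chunk per sequence id and flushes whenever the next chunk would overflow n_batch: p counts sequences already encoded, s counts sequences in the in-flight batch, so each flush writes s rows of n_embd floats starting at row p of the output buffer. A sketch of the whole pattern (the buffer allocation and the chunks[k].tokens field are assumptions):

    // one pooled embedding (n_embd floats) per chunk
    std::vector<float> embeddings(n_chunks * n_embd, 0.0f);
    float * emb = embeddings.data();

    int p = 0; // number of sequences encoded so far
    int s = 0; // number of sequences in the current batch
    for (int k = 0; k < n_chunks; k++) {
        const int n_toks = (int) chunks[k].tokens.size();
        // encode if at capacity
        if (batch.n_tokens + n_toks > n_batch) {
            batch_encode(ctx, batch, emb + p * n_embd, s, n_embd);
            common_batch_clear(batch);
            p += s;
            s = 0;
        }
        batch_add_seq(batch, chunks[k].tokens, s);
        s += 1;
    }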
@@ -246,7 +246,7 @@ int main(int argc, char ** argv) {
 
     // final batch
     float * out = emb + p * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd);
+    batch_encode(ctx, batch, out, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
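The loop opened above presumably copies each row of the flat buffer back into its chunk; a sketch, assuming the chunk struct carries an embedding field:

    for (int i = 0; i < n_chunks; i++) {
        // copy row i of the flat buffer into the chunk
        chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd);
        // token ids are no longer needed once the embedding is stored
        chunks[i].tokens.clear();
    }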
@@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
         batch_add_seq(query_batch, query_tokens, 0);
 
         std::vector<float> query_emb(n_embd, 0);
-        batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
+        batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd);
 
         common_batch_clear(query_batch);
 
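With the query embedded through the same batch_encode path, retrieval reduces to ranking chunks by similarity to query_emb. A sketch of that step (common_embd_similarity_cos is the cosine-similarity helper from llama.cpp's common library; requires <algorithm> and <utility>):

    // rank chunks by cosine similarity to the query embedding
    std::vector<std::pair<int, float>> similarities;
    for (int i = 0; i < n_chunks; i++) {
        const float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
        similarities.push_back(std::make_pair(i, sim));
    }
    std::sort(similarities.begin(), similarities.end(),
              [](const std::pair<int, float> & a, const std::pair<int, float> & b) {
                  return a.second > b.second; // highest similarity first
              });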