llama : assert pooling tensor
ggerganov committed Mar 4, 2024
1 parent 79e4eed commit fc9af15
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions llama.cpp
@@ -6246,6 +6246,7 @@ struct llm_build_context {
 
         // final output
         cur = inpL;
+        cb(cur, "result_embd", -1);
 
         // pooling layer
         switch (pooling_type) {
@@ -6256,17 +6257,18 @@
             case LLAMA_POOLING_TYPE_MEAN:
                 {
                     cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
+                    cb(cur, "result_embd_pooled", -1);
                 } break;
             case LLAMA_POOLING_TYPE_CLS:
                 {
                     cur = ggml_get_rows(ctx0, cur, inp_cls);
+                    cb(cur, "result_embd_pooled", -1);
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     GGML_ASSERT(false && "Invalid pooling type");
                 } break;
         }
-        cb(cur, "result_embd", -1);
 
         ggml_build_forward_expand(gf, cur);
 
@@ -8281,7 +8283,7 @@ static int llama_decode_internal(
             // token or sequence embeddings
             embd = gf->nodes[gf->n_nodes - 1];
 
-            GGML_ASSERT(strcmp(embd->name, "result_embd") == 0);
+            GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
         } else {
             if (strcmp(res->name, "result_output") == 0) {
                 // the token embeddings could be the second to last tensor, or the third to last tensor
@@ -8413,6 +8415,8 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_MEAN:
                     {
+                        GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
+
                         // extract sequence embeddings
                         auto & embd_seq_out = lctx.embd_seq;
                         embd_seq_out.clear();
