Inference support for T5 and FLAN-T5 model families #8141

Merged · 18 commits · Jul 4, 2024
Changes from 7 commits
19 changes: 18 additions & 1 deletion common/common.cpp
@@ -2061,7 +2061,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
if (params.warmup) {
LOG("warming up the model with an empty run\n");

std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
std::vector<llama_token> tmp;
llama_token bos = llama_token_bos(model);
llama_token eos = llama_token_eos(model);
// some models (e.g. T5) don't have a BOS token
if (bos != -1) {
tmp.push_back(bos);
}
tmp.push_back(eos);

if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
decoder_start_token_id = bos;
}
tmp.clear();
tmp.push_back(decoder_start_token_id);
}
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
46 changes: 36 additions & 10 deletions convert-hf-to-gguf.py
@@ -2775,29 +2775,47 @@ def write_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("T5ForConditionalGeneration")
@Model.register("T5WithLMHeadModel")
@Model.register("T5ForConditionalGeneration")
@Model.register("MT5ForConditionalGeneration")
@Model.register("UMT5ForConditionalGeneration")
class T5Model(Model):
model_arch = gguf.MODEL_ARCH.T5

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.shared_token_embeddings_found = False

def set_vocab(self):
# to avoid TypeError: Descriptors cannot be created directly
# exception when importing sentencepiece_model_pb2
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
from sentencepiece import SentencePieceProcessor
from sentencepiece import sentencepiece_model_pb2 as model

tokenizer_path = self.dir_model / 'spiece.model'
tokenizer_path = self.dir_model / 'tokenizer.model'

# many older models use spiece.model tokenizer model filename
if not tokenizer_path.is_file():
tokenizer_path = self.dir_model / 'spiece.model'

if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}")

sentencepiece_model = model.ModelProto()
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())

# some models like Pile-T5 family use BPE tokenizer instead of Unigram
if sentencepiece_model.trainer_spec.model_type == 2: # BPE
# assure the tokenizer model file name is correct
assert tokenizer_path.name == 'tokenizer.model'
return self._set_vocab_sentencepiece()
else:
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM

add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM

tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))
@@ -2867,7 +2885,10 @@ def set_vocab(self):

def set_gguf_parameters(self):
self.gguf_writer.add_name("T5")
self.gguf_writer.add_context_length(self.hparams["n_positions"])
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
n_ctx = 512
self.gguf_writer.add_context_length(n_ctx)
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"])
@@ -2883,12 +2904,17 @@ def set_gguf_parameters(self):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

# Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
# "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
# To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
return []
# T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
# "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
# in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
# and decoder and ignore the remaining ones.
if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
if not self.shared_token_embeddings_found:
name = "shared.weight"
self.shared_token_embeddings_found = True
else:
logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
return []

return [(self.map_tensor_name(name), data_torch)]

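As background for the set_vocab changes above, here is a small standalone sketch (not part of the diff) of the tokenizer-type check the converter relies on: the sentencepiece model protobuf is parsed and trainer_spec.model_type distinguishes Unigram (1, used by most T5 models) from BPE (2, used e.g. by Pile-T5). The helper name detect_tokenizer_type is made up for illustration.

```python
import os

# Same workaround as in the converter: force the pure-Python protobuf
# implementation so importing sentencepiece_model_pb2 does not raise
# "TypeError: Descriptors cannot be created directly".
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from sentencepiece import sentencepiece_model_pb2 as model


def detect_tokenizer_type(tokenizer_path: str) -> str:
    """Return 'UNIGRAM' or 'BPE' for a sentencepiece tokenizer model file."""
    proto = model.ModelProto()
    with open(tokenizer_path, "rb") as f:
        proto.ParseFromString(f.read())
    # trainer_spec.model_type: 1 = UNIGRAM, 2 = BPE
    return {1: "UNIGRAM", 2: "BPE"}.get(proto.trainer_spec.model_type, "OTHER")


print(detect_tokenizer_type("tokenizer.model"))
```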
21 changes: 20 additions & 1 deletion examples/main/main.cpp
@@ -255,7 +255,9 @@ int main(int argc, char ** argv) {
}

const bool add_bos = llama_should_add_bos_token(model);
GGML_ASSERT(llama_add_eos_token(model) != 1);
if (!llama_model_has_encoder(model)) {
GGML_ASSERT(llama_add_eos_token(model) != 1);
}
LOG("add_bos: %d\n", add_bos);

std::vector<llama_token> embd_inp;
@@ -517,6 +519,23 @@
exit(1);
}

if (llama_model_has_encoder(model)) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();

if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}

llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
decoder_start_token_id = llama_token_bos(model);
}
embd_inp.clear();
embd_inp.push_back(decoder_start_token_id);
}

while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (!embd.empty()) {
15 changes: 15 additions & 0 deletions include/llama.h
@@ -483,6 +483,13 @@ extern "C" {
// Get a llama model tensor
LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);

// Returns true if the model contains an encoder that requires llama_encode() call
LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);

// For encoder-decoder models, this function returns id of the token that must be provided
// to the decoder to start generating output sequence. For other models, it returns -1.
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);

// Returns 0 on success
LLAMA_API uint32_t llama_model_quantize(
const char * fname_inp,
@@ -768,6 +775,14 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);

// Processes a batch of tokens with the encoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.

[Inline review thread on llama_encode()]

vladfaust: In my case, a prompt consists of a static part, which is unchanged and makes use of the KV cache, and a dynamic part, which changes frequently. This works well with GPT, where I can call llama_kv_cache_seq_rm to clean up the dynamic part of the KV cache and start evaluating again. Would a similar approach work with T5? In other words, what's the degree of control over the encoder output? Thank you.

fairydreaming (Collaborator, Author): @vladfaust No, the encoder requires all input tokens to be present in the input batch. This is because attention in the encoder is not causal, so each token in the input sequence attends to all tokens in the input sequence. It doesn't even use the KV cache because there's no need to.

I guess theoretically it would be possible to implement it in a way that would allow "adding" tokens to the encoder output by calling llama_encode() multiple times, but the implementation would be much more complicated, and definitely outside the scope of this PR.

vladfaust: Just to clarify, @fairydreaming: one of my use cases is converting a growing chat history to some structured representation for each new message. Do I understand correctly that, for now, I'd have to encode the whole history again and again for each inference, without any form of caching? (No offence, obviously, as I'm very grateful for the T5 support at all!)

fairydreaming (Collaborator, Author): @vladfaust Yes, there's no caching in the encoder, so if the input sequence grows even by one token, you have to encode it again, and during this process all previous calculations for this token sequence are repeated.

// 0 - success
// < 0 - error
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);

// Positive return values do not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
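To tie the new declarations together (and the caching point raised in the review thread above), here is a minimal usage sketch, not part of the diff: it assumes `ctx`, `model`, and a tokenized prompt have already been set up in the usual way, `tokenize_prompt()` is a hypothetical helper, and error handling and sampling are omitted.

```cpp
// Sketch only: ctx and model are assumed to exist already.
std::vector<llama_token> tokens = tokenize_prompt(); // hypothetical helper

if (llama_model_has_encoder(model)) {
    // Encoder attention is non-causal and not cached, so the whole input
    // must be passed to llama_encode() whenever the prompt changes.
    if (llama_encode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0)) != 0) {
        // handle error
    }

    // The decoder starts from a dedicated start token; fall back to BOS
    // if the model does not define one.
    llama_token dec_start = llama_model_decoder_start_token(model);
    if (dec_start == -1) {
        dec_start = llama_token_bos(model);
    }
    tokens.clear();
    tokens.push_back(dec_start);
}

// Generation then continues as usual with llama_decode() on the decoder side.
llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0));
```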