Skip to content

Commit

Permalink
update gpt-2/main-backend.cpp from master
Browse files Browse the repository at this point in the history
  • Loading branch information
slaren committed Oct 21, 2023
1 parent 62a06c6 commit 2105a49
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions examples/gpt-2/main-backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ struct gpt2_model {
};

// load the model's weights from a file
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_gpu_layers) {
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_ctx, int n_gpu_layers) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());

auto fin = std::ifstream(fname, std::ios::binary);
Expand Down Expand Up @@ -338,6 +338,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
}
}

// override the default training context with the user-provided
model.hparams.n_ctx = n_ctx;

// key + value memory
{
const auto & hparams = model.hparams;
Expand Down Expand Up @@ -859,7 +862,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_us = ggml_time_us();

if (!gpt2_model_load(params.model, model, vocab, params.n_gpu_layers)) {
if (!gpt2_model_load(params.model, model, vocab, params.n_ctx, params.n_gpu_layers)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
Expand Down Expand Up @@ -972,7 +975,7 @@ int main(int argc, char ** argv) {
fflush(stdout);

// end of text token
if (embd.back() == 50256) {
if (!params.ignore_eos && embd.back() == 50256) {
break;
}
}
Expand Down

0 comments on commit 2105a49

Please sign in to comment.