diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 14207d5b..4a5616ff 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1085,7 +1085,9 @@ impl Model { f16_: _, } = self.hparams; - let mut buf_size = 512 * 1024 * 1024; + // For the first run, we need to guess a maximum buffer size so we can measure + // the actual memory consumption of the temporary ggml context. + let mut buf_size = 1024 * 1024 * 1024; if session.mem_per_token > 0 && session.mem_per_token * n > buf_size { // add 10% to account for ggml object overhead buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;