61 changes: 21 additions & 40 deletions examples/reranker/src/main.rs
@@ -12,7 +12,6 @@ use std::time::Duration;

use anyhow::{bail, Context, Result};
use clap::Parser;
use hf_hub::api::sync::ApiBuilder;

use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
use llama_cpp_2::context::LlamaContext;
@@ -92,13 +91,11 @@ fn main() -> Result<()> {
.with_n_threads_batch(std::thread::available_parallelism()?.get().try_into()?)
.with_embeddings(true)
.with_pooling_type(pooling_type);
println!("ctx_params: {:?}", ctx_params);
println!("ctx_params: {ctx_params:?}");
let mut ctx = model
.new_context(&backend, ctx_params)
.with_context(|| "unable to create the llama_context")?;

let n_embd = model.n_embd();

let prompt_lines = {
let mut lines = Vec::new();
for doc in documents {
@@ -108,13 +105,13 @@ fn main() -> Result<()> {
lines
};

println!("prompt_lines: {:?}", prompt_lines);
println!("prompt_lines: {prompt_lines:?}");
// tokenize the prompt
let tokens_lines_list = prompt_lines
.iter()
.map(|line| model.str_to_token(line, AddBos::Always))
.collect::<Result<Vec<_>, _>>()
.with_context(|| format!("failed to tokenize {:?}", prompt_lines))?;
.with_context(|| format!("failed to tokenize {prompt_lines:?}"))?;

let n_ctx = ctx.n_ctx() as usize;
let n_ctx_train = model.n_ctx_train();
@@ -156,7 +153,6 @@ fn main() -> Result<()> {
// } else {
// tokens_lines_list.len()
// };
let mut embeddings_stored = 0;
let mut max_seq_id_batch = 0;
let mut output = Vec::with_capacity(tokens_lines_list.len());

@@ -169,16 +165,10 @@
&mut ctx,
&mut batch,
max_seq_id_batch,
n_embd,
&mut output,
normalise,
pooling.clone(),
&pooling,
)?;
embeddings_stored += if pooling == "none" {
batch.n_tokens()
} else {
max_seq_id_batch
};
max_seq_id_batch = 0;
batch.clear();
}
@@ -191,34 +181,23 @@
&mut ctx,
&mut batch,
max_seq_id_batch,
n_embd,
&mut output,
normalise,
pooling.clone(),
&pooling,
)?;

let t_main_end = ggml_time_us();

for (j, embeddings) in output.iter().enumerate() {
if pooling == "none" {
eprintln!("embedding {j}: ");
for i in 0..n_embd as usize {
if !normalise {
eprint!("{:6.5} ", embeddings[i]);
} else {
eprint!("{:9.6} ", embeddings[i]);
}
}
eprintln!();
} else if pooling == "rank" {
if pooling == "rank" {
eprintln!("rerank score {j}: {:8.3}", embeddings[0]);
} else {
eprintln!("embedding {j}: ");
for i in 0..n_embd as usize {
if !normalise {
eprint!("{:6.5} ", embeddings[i]);
for embedding in embeddings {
if normalise {
eprint!("{embedding:9.6} ");
} else {
eprint!("{:9.6} ", embeddings[i]);
eprint!("{embedding:6.5} ");
}
}
eprintln!();
@@ -243,10 +222,9 @@ fn batch_decode(
ctx: &mut LlamaContext,
batch: &mut LlamaBatch,
s_batch: i32,
n_embd: i32,
output: &mut Vec<Vec<f32>>,
normalise: bool,
pooling: String,
pooling: &str,
) -> Result<()> {
eprintln!(
"{}: n_tokens = {}, n_seq = {}",
@@ -266,9 +244,9 @@
.with_context(|| "Failed to get sequence embeddings")?;
let normalized = if normalise {
if pooling == "rank" {
normalize_embeddings(&embeddings, -1)
normalize_embeddings(embeddings, -1)
} else {
normalize_embeddings(&embeddings, 2)
normalize_embeddings(embeddings, 2)
}
} else {
embeddings.to_vec()
@@ -291,27 +269,30 @@ fn normalize_embeddings(input: &[f32], embd_norm: i32) -> Vec<f32> {
0 => {
// max absolute
let max_abs = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max) / 32760.0;
max_abs as f64
f64::from(max_abs)
}
2 => {
// euclidean norm
input
.iter()
.map(|x| (*x as f64).powi(2))
.map(|x| f64::from(*x).powi(2))
.sum::<f64>()
.sqrt()
}
p => {
// p-norm
let sum = input.iter().map(|x| (x.abs() as f64).powi(p)).sum::<f64>();
sum.powf(1.0 / p as f64)
let sum = input
.iter()
.map(|x| f64::from(x.abs()).powi(p))
.sum::<f64>();
sum.powf(1.0 / f64::from(p))
}
};

let norm = if sum > 0.0 { 1.0 / sum } else { 0.0 };

for i in 0..n {
output[i] = (input[i] as f64 * norm) as f32;
output[i] = (f64::from(input[i]) * norm) as f32;
}

output
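Note on the reranker change: the refactored `normalize_embeddings` keeps the same three modes (`0` = max-absolute, `2` = Euclidean, any other `p` = p-norm) and only swaps `as` casts for `f64::from`. A minimal, self-contained sketch of the logic after this change follows; the `main` with its 3-4-5 check is purely illustrative and is not part of the example binary.

```rust
// Standalone sketch of the normalisation modes used by the reranker example.
// embd_norm: 0 = max-absolute, 2 = Euclidean (L2), any other p = p-norm.
// A non-positive norm falls back to norm = 0, i.e. an all-zero output.
fn normalize_embeddings(input: &[f32], embd_norm: i32) -> Vec<f32> {
    let sum = match embd_norm {
        0 => {
            // max absolute, scaled as in the example
            let max_abs = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max) / 32760.0;
            f64::from(max_abs)
        }
        2 => {
            // Euclidean norm
            input
                .iter()
                .map(|x| f64::from(*x).powi(2))
                .sum::<f64>()
                .sqrt()
        }
        p => {
            // p-norm
            let sum = input
                .iter()
                .map(|x| f64::from(x.abs()).powi(p))
                .sum::<f64>();
            sum.powf(1.0 / f64::from(p))
        }
    };

    let norm = if sum > 0.0 { 1.0 / sum } else { 0.0 };
    input.iter().map(|&x| (f64::from(x) * norm) as f32).collect()
}

fn main() {
    // L2-normalising a 3-4 vector should give a unit-length 0.6-0.8 result.
    let v = normalize_embeddings(&[3.0, 4.0], 2);
    assert!((v[0] - 0.6).abs() < 1e-6 && (v[1] - 0.8).abs() < 1e-6);
}
```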
2 changes: 1 addition & 1 deletion examples/simple/src/main.rs
@@ -97,7 +97,7 @@ struct Args {
fn parse_key_val(s: &str) -> Result<(String, ParamOverrideValue)> {
let pos = s
.find('=')
.ok_or_else(|| anyhow!("invalid KEY=value: no `=` found in `{}`", s))?;
.ok_or_else(|| anyhow!("invalid KEY=value: no `=` found in `{s}`"))?;
let key = s[..pos].parse()?;
let value: String = s[pos + 1..].parse()?;
let value = i64::from_str(&value)
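For context on the `parse_key_val` hunk above: it splits a `KEY=value` argument at the first `=` and reports a formatted error when none is found. A simplified sketch of that behaviour is below; the real example converts the value into a `ParamOverrideValue`, which is omitted here, and the `n_gpu_layers` key is only an illustrative input, not necessarily a real override key.

```rust
use std::str::FromStr;

use anyhow::{anyhow, Result};

// Simplified KEY=value parser mirroring the example's error handling.
fn parse_key_val(s: &str) -> Result<(String, i64)> {
    let pos = s
        .find('=')
        .ok_or_else(|| anyhow!("invalid KEY=value: no `=` found in `{s}`"))?;
    let key = s[..pos].to_string();
    let value = i64::from_str(&s[pos + 1..])?;
    Ok((key, value))
}

fn main() -> Result<()> {
    assert_eq!(parse_key_val("n_gpu_layers=35")?, ("n_gpu_layers".to_string(), 35));
    assert!(parse_key_val("no-equals-sign").is_err());
    Ok(())
}
```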
1 change: 1 addition & 0 deletions llama-cpp-2/src/lib.rs
@@ -449,6 +449,7 @@ pub struct LogOptions {
impl LogOptions {
/// If enabled, logs are sent to tracing. If disabled, all logs are suppressed. Default is for
/// logs to be sent to tracing.
#[must_use]
pub fn with_logs_enabled(mut self, enabled: bool) -> Self {
self.disabled = !enabled;
self
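The new `#[must_use]` matters because `with_logs_enabled` is a by-value builder: it consumes `self` and returns the updated options, so dropping the return value silently discards the setting. A minimal sketch of the pattern, using a stand-in struct rather than the crate's actual `LogOptions`:

```rust
#[derive(Default, Debug)]
struct LogOptions {
    disabled: bool,
}

impl LogOptions {
    /// Consumes self and returns the updated options, so the result must be used.
    #[must_use]
    fn with_logs_enabled(mut self, enabled: bool) -> Self {
        self.disabled = !enabled;
        self
    }
}

fn main() {
    // Correct: the returned value is kept.
    let opts = LogOptions::default().with_logs_enabled(false);
    assert!(opts.disabled);

    // Without #[must_use] the call below would compile silently and the
    // setting would be lost; with it, the compiler warns about the unused result.
    // LogOptions::default().with_logs_enabled(false);
}
```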
24 changes: 12 additions & 12 deletions llama-cpp-2/src/model.rs
@@ -36,9 +36,9 @@ pub struct LlamaLoraAdapter {
pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}

/// A performance-friendly wrapper around [LlamaModel::chat_template] which is then
/// fed into [LlamaModel::apply_chat_template] to convert a list of messages into an LLM
/// prompt. Internally the template is stored as a CString to avoid round-trip conversions
/// A performance-friendly wrapper around [`LlamaModel::chat_template`] which is then
/// fed into [`LlamaModel::apply_chat_template`] to convert a list of messages into an LLM
/// prompt. Internally the template is stored as a `CString` to avoid round-trip conversions
/// within the FFI.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);
@@ -55,7 +55,7 @@ impl LlamaChatTemplate {
&self.0
}

/// Attempts to convert the CString into a Rust str reference.
/// Attempts to convert the `CString` into a Rust str reference.
pub fn to_str(&self) -> Result<&str, Utf8Error> {
self.0.to_str()
}
@@ -569,7 +569,7 @@ impl LlamaModel {

/// Get chat template from model by name. If the name parameter is None, the default chat template will be returned.
///
/// You supply this into [Self::apply_chat_template] to get back a string with the appropriate template
/// You supply this into [`Self::apply_chat_template`] to get back a string with the appropriate template
/// substitution applied to convert a list of messages into a prompt the LLM can use to complete
/// the chat.
///
@@ -666,11 +666,11 @@ impl LlamaModel {
/// There are many ways this can fail. See [`LlamaContextLoadError`] for more information.
// we intentionally do not derive Copy on `LlamaContextParams` to allow llama.cpp to change the type to be non-trivially copyable.
#[allow(clippy::needless_pass_by_value)]
pub fn new_context(
&self,
pub fn new_context<'a>(
&'a self,
_: &LlamaBackend,
params: LlamaContextParams,
) -> Result<LlamaContext, LlamaContextLoadError> {
) -> Result<LlamaContext<'a>, LlamaContextLoadError> {
let context_params = params.context_params;
let context = unsafe {
llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
@@ -681,14 +681,14 @@
}

/// Apply the models chat template to some messages.
/// See https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
/// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
///
/// Unlike the llama.cpp apply_chat_template which just randomly uses the ChatML template when given
/// Unlike the llama.cpp `apply_chat_template` which just randomly uses the ChatML template when given
/// a null pointer for the template, this requires an explicit template to be specified. If you want to
/// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
/// string.
///
/// Use [Self::chat_template] to retrieve the template baked into the model (this is the preferred
/// Use [`Self::chat_template`] to retrieve the template baked into the model (this is the preferred
/// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
///
/// You probably want to set `add_ass` to true so that the generated template string ends with the
@@ -764,7 +764,7 @@
let mut buffer = vec![0u8; capacity];

// call the foreign function
let result = c_function(buffer.as_mut_ptr() as *mut c_char, buffer.len());
let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
if result < 0 {
return Err(MetaValError::NegativeReturn(result));
}
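One small change in this file replaces an `as *mut c_char` pointer cast with `.cast::<c_char>()` when handing a byte buffer to the FFI; `cast` can only change the pointee type, so it cannot accidentally alter mutability or turn into an integer cast. A standalone sketch of the same idiom using a dummy function in place of the crate's real C call:

```rust
use std::os::raw::c_char;

// Stand-in for a C function that fills a caller-provided buffer and returns
// the number of bytes written, or a negative value on error.
fn fake_c_function(buf: *mut c_char, len: usize) -> i32 {
    let bytes = b"hello";
    let n = bytes.len().min(len);
    unsafe { std::ptr::copy_nonoverlapping(bytes.as_ptr().cast::<c_char>(), buf, n) };
    n as i32
}

fn main() {
    let mut buffer = vec![0u8; 16];
    // `.cast::<c_char>()` is preferred over `as *mut c_char`: the cast stays
    // correct even if the buffer's element type ever changes.
    let written = fake_c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
    assert_eq!(written, 5);
    assert_eq!(&buffer[..5], b"hello");
}
```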
4 changes: 2 additions & 2 deletions llama-cpp-2/src/model/params.rs
@@ -149,7 +149,7 @@ impl LlamaModelParams {
/// assert_eq!(count, 0);
/// ```
#[must_use]
pub fn kv_overrides(&self) -> KvOverrides {
pub fn kv_overrides<'a>(&'a self) -> KvOverrides<'a> {
KvOverrides::new(self)
}

@@ -235,7 +235,7 @@ impl LlamaModelParams {
);

// There should be some way to do this without iterating over everything.
for (_i, &c) in key.to_bytes_with_nul().iter().enumerate() {
for &c in key.to_bytes_with_nul().iter() {
c_char::try_from(c).expect("invalid character in key");
}

2 changes: 1 addition & 1 deletion llama-cpp-2/src/model/params/kv_overrides.rs
@@ -78,7 +78,7 @@ pub struct KvOverrides<'a> {
}

impl KvOverrides<'_> {
pub(super) fn new(model_params: &LlamaModelParams) -> KvOverrides {
pub(super) fn new<'a>(model_params: &'a LlamaModelParams) -> KvOverrides<'a> {
KvOverrides { model_params }
}
}
24 changes: 12 additions & 12 deletions llama-cpp-2/src/sampling.rs
@@ -385,7 +385,7 @@ impl LlamaSampler {

/// Penalizes tokens for being present in the context.
///
/// Parameters:
/// Parameters:
/// - ``penalty_last_n``: last n tokens to penalize (0 = disable penalty, -1 = context size)
/// - ``penalty_repeat``: 1.0 = disabled
/// - ``penalty_freq``: 0.0 = disabled
@@ -415,15 +415,15 @@
/// - ``n_vocab``: [`LlamaModel::n_vocab`]
/// - ``seed``: Seed to initialize random generation with.
/// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
/// generated text. A higher value corresponds to more surprising or less predictable text,
/// while a lower value corresponds to less surprising or more predictable text.
/// generated text. A higher value corresponds to more surprising or less predictable text,
/// while a lower value corresponds to less surprising or more predictable text.
/// - ``eta``: The learning rate used to update `mu` based on the error between the target and
/// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
/// updated more quickly, while a smaller learning rate will result in slower updates.
/// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
/// updated more quickly, while a smaller learning rate will result in slower updates.
/// - ``m``: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary
/// value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
/// In the paper, they use `m = 100`, but you can experiment with different values to see how
/// it affects the performance of the algorithm.
/// value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
/// In the paper, they use `m = 100`, but you can experiment with different values to see how
/// it affects the performance of the algorithm.
#[must_use]
pub fn mirostat(n_vocab: i32, seed: u32, tau: f32, eta: f32, m: i32) -> Self {
let sampler =
@@ -436,11 +436,11 @@
/// # Parameters:
/// - ``seed``: Seed to initialize random generation with.
/// - ``tau``: The target cross-entropy (or surprise) value you want to achieve for the
/// generated text. A higher value corresponds to more surprising or less predictable text,
/// while a lower value corresponds to less surprising or more predictable text.
/// generated text. A higher value corresponds to more surprising or less predictable text,
/// while a lower value corresponds to less surprising or more predictable text.
/// - ``eta``: The learning rate used to update `mu` based on the error between the target and
/// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
/// updated more quickly, while a smaller learning rate will result in slower updates.
/// observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
/// updated more quickly, while a smaller learning rate will result in slower updates.
#[must_use]
pub fn mirostat_v2(seed: u32, tau: f32, eta: f32) -> Self {
let sampler = unsafe { llama_cpp_sys_2::llama_sampler_init_mirostat_v2(seed, tau, eta) };
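The signatures shown in this file are enough to sketch how the two mirostat samplers are constructed. This is a hedged illustration only: the `llama_cpp_2::sampling` module path is assumed from the file layout, the `tau`/`eta`/`m` values are arbitrary, and wiring the samplers into an actual decode loop is not shown.

```rust
use llama_cpp_2::sampling::LlamaSampler;

// Illustrative helper: builds both mirostat variants with the parameter
// meanings documented above (tau = target surprise, eta = learning rate,
// m = number of tokens used to estimate s_hat). Values are placeholders.
fn build_samplers(n_vocab: i32) -> (LlamaSampler, LlamaSampler) {
    let seed = 1234;
    let mirostat_v1 = LlamaSampler::mirostat(n_vocab, seed, 5.0, 0.1, 100);
    let mirostat_v2 = LlamaSampler::mirostat_v2(seed, 5.0, 0.1);
    (mirostat_v1, mirostat_v2)
}
```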
2 changes: 1 addition & 1 deletion llama-cpp-sys-2/build.rs
@@ -627,7 +627,7 @@ fn main() {

if matches!(target_os, TargetOs::Linux)
&& target_triple.contains("aarch64")
&& !env::var(format!("CARGO_FEATURE_{}", "native".to_uppercase())).is_ok()
&& env::var(format!("CARGO_FEATURE_{}", "native".to_uppercase())).is_err()
{
// If the native feature is not enabled, we take off the native ARM64 support.
// It is useful in docker environments where the native feature is not enabled.
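For context on the `is_err()` change: Cargo exposes every enabled feature to build scripts as a `CARGO_FEATURE_<NAME>` environment variable, so the condition above asks whether the `native` feature is absent. A minimal sketch of the same probe in a build script:

```rust
use std::env;

fn main() {
    // Cargo sets CARGO_FEATURE_NATIVE only when the `native` feature is enabled,
    // so an Err from env::var means the feature is off.
    let native_enabled = env::var("CARGO_FEATURE_NATIVE").is_ok();
    if !native_enabled {
        println!("cargo:warning=building without the `native` feature");
    }
}
```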
1 change: 1 addition & 0 deletions llama-cpp-sys-2/src/lib.rs
@@ -3,5 +3,6 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(unpredictable_function_pointer_comparisons)]

include!(concat!(env!("OUT_DIR"), "/bindings.rs"));