Update llama.cpp and integrate graph "planning"
LLukas22 committed Jul 12, 2023
1 parent 7c0a29a commit 53095b1
Showing 24 changed files with 251 additions and 212 deletions.
binaries/llm-cli/src/cli_args.rs: 2 changes (1 addition, 1 deletion)
@@ -336,6 +336,7 @@ impl Generate {
memory_v_type: mem_typ,
use_gpu: self.use_gpu,
n_batch: self.batch_size,
n_threads: self.num_threads(),
}
}

@@ -349,7 +350,6 @@ impl Generate {

pub fn inference_parameters(&self, eot: llm::TokenId) -> InferenceParameters {
InferenceParameters {
n_threads: self.num_threads(),
sampler: Arc::new(llm::samplers::TopPTopK {
top_k: self.top_k,
top_p: self.top_p,
binaries/llm-cli/src/main.rs: 13 changes (3 additions, 10 deletions)
@@ -111,16 +111,10 @@ fn perplexity(args: &cli_args::Perplexity) -> Result<()> {
let model = args.model_load.load(args.generate.use_gpu)?;
let (mut session, _) =
snapshot::read_or_create_session(model.as_ref(), None, None, inference_session_config);
let parameters = args.generate.inference_parameters(model.eot_token_id());

session.perplexity(
model.as_ref(),
&parameters,
prompt.as_str(),
|chunk, perplexity| {
println!("Perplexity[{chunk}]: {perplexity}");
},
)?;
session.perplexity(model.as_ref(), prompt.as_str(), |chunk, perplexity| {
println!("Perplexity[{chunk}]: {perplexity}");
})?;

Ok(())
}
@@ -273,7 +267,6 @@ fn interactive(
let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None);
if let Err(InferenceError::ContextFull) = session.feed_prompt(
model.as_ref(),
&parameters,
&prompt,
// OutputRequest
&mut Default::default(),
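
The llm-cli hunks above reflect the API change in this commit: n_threads moves out of InferenceParameters and onto the session configuration, so feed_prompt and perplexity no longer take a &parameters argument. Below is a minimal sketch of the new call shape, built only from signatures that appear elsewhere in this diff; model construction is elided, and the thread count, prompt, and the always-continue callback are placeholders.

use std::convert::Infallible;

use llm::{InferenceFeedback, InferenceSessionConfig, Model};

fn feed(model: &dyn Model, prompt: &str) -> Result<(), llm::InferenceError> {
    // Thread count now lives on the session config rather than InferenceParameters.
    let mut session = model.start_session(InferenceSessionConfig {
        n_threads: 8,
        ..Default::default()
    });

    // feed_prompt no longer takes an &InferenceParameters argument.
    session.feed_prompt(
        model,
        prompt,
        // OutputRequest
        &mut Default::default(),
        |_: &[u8]| -> Result<InferenceFeedback, Infallible> { Ok(InferenceFeedback::Continue) },
    )
}
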
binaries/llm-test/src/delete.rs: 2 changes (1 addition, 1 deletion)
@@ -64,7 +64,7 @@ fn feed_prompt(
model: &impl Model,
output: &mut OutputRequest,
) -> Result<(), llm::InferenceError> {
session.feed_prompt(model, &Default::default(), prompt, output, always_continue)
session.feed_prompt(model, prompt, output, always_continue)
}

fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
binaries/llm-test/src/inference.rs: 17 changes (6 additions, 11 deletions)
@@ -4,7 +4,7 @@
use std::{convert::Infallible, sync::Arc};

use llm::InferenceStats;
use llm::{InferenceSessionConfig, InferenceStats};

use crate::{ModelConfig, TestCaseReport, TestCaseReportInner, TestCaseReportMeta};

@@ -15,14 +15,11 @@ pub(crate) fn can_infer(
expected_output: Option<&str>,
maximum_token_count: usize,
) -> anyhow::Result<TestCaseReport> {
let mut session = model.start_session(Default::default());
let (actual_output, res) = run_inference(
model,
model_config,
&mut session,
input,
maximum_token_count,
);
let mut session = model.start_session(InferenceSessionConfig {
n_threads: model_config.threads,
..Default::default()
});
let (actual_output, res) = run_inference(model, &mut session, input, maximum_token_count);

// Process the results
Ok(TestCaseReport {
@@ -58,7 +55,6 @@ pub(crate) fn can_infer(

fn run_inference(
model: &dyn llm::Model,
model_config: &ModelConfig,
session: &mut llm::InferenceSession,
input: &str,
maximum_token_count: usize,
@@ -70,7 +66,6 @@ fn run_inference(
&llm::InferenceRequest {
prompt: input.into(),
parameters: &llm::InferenceParameters {
n_threads: model_config.threads,
sampler: Arc::new(DeterministicSampler),
},
play_back_previous_tokens: false,
binaries/llm-test/src/tokens.rs: 4 changes (1 addition, 3 deletions)
@@ -65,9 +65,7 @@ fn feed_prompt(
model: &impl Model,
output: &mut OutputRequest,
) -> Result<(), llm::InferenceError> {
session.feed_prompt(model, &Default::default(), prompt, output, |x| {
always_continue(x)
})
session.feed_prompt(model, prompt, output, always_continue)
}

fn always_continue(_: &[u8]) -> Result<InferenceFeedback, Infallible> {
crates/ggml/src/context.rs: 9 changes (1 addition, 8 deletions)
@@ -2,7 +2,7 @@ use std::{os::raw::c_int, ptr::NonNull, sync::Arc};

use memmap2::Mmap;

use crate::{sys, usize_to_i32, usize_to_i64, Buffer, ComputationGraph, Tensor, Type};
use crate::{sys, usize_to_i32, usize_to_i64, Buffer, Tensor, Type};

/// Acts as a RAII-guard over a `sys::ggml_context`, allocating via
/// `ggml_init` and dropping via `ggml_free`.
@@ -442,13 +442,6 @@ impl Context {
self.new_tensor_raw(tensor)
}

/// Computes the specified graph. Must be run in order to evaluate the graph.
pub fn graph_compute(&self, graph: &mut ComputationGraph) {
unsafe {
sys::ggml_graph_compute(self.ptr.as_ptr(), &mut graph.inner);
}
}

/// Retrieves the memory used by this [Context].
pub fn used_mem(&self) -> usize {
unsafe { sys::ggml_used_mem(self.ptr.as_ptr()) }
crates/ggml/src/lib.rs: 59 changes (56 additions, 3 deletions)
@@ -3,7 +3,7 @@
//! It exposes a subset of operations (currently used to implement the [llm](https://crates.io/crates/llm) library).
//! Note that it does not expose a fully-idiomatic safe Rust interface; operations that could be potentially unsafe are marked as such.
//!
//! `ggml` operates on a computational graph; no values will be computed until [Context::graph_compute] is executed.
//! `ggml` operates on a computational graph; no values will be computed until the graph is executed in the [Context] via a [GraphExecutionPlan].
//! All [Tensor]s are nodes in this computational graph, and values cannot be retrieved until computation is completed.
#![deny(missing_docs)]

@@ -221,6 +221,8 @@ pub enum Type {
F16,
/// Float 32-bit.
F32,
/// Integer 8-bit.
I8,
}
impl From<Type> for sys::ggml_type {
fn from(t: Type) -> Self {
@@ -239,6 +241,7 @@ impl From<Type> for sys::ggml_type {
Type::I32 => sys::ggml_type_GGML_TYPE_I32,
Type::F16 => sys::ggml_type_GGML_TYPE_F16,
Type::F32 => sys::ggml_type_GGML_TYPE_F32,
Type::I8 => sys::ggml_type_GGML_TYPE_I8,
}
}
}
@@ -260,6 +263,7 @@ impl TryFrom<sys::ggml_type> for Type {
sys::ggml_type_GGML_TYPE_I32 => Ok(Type::I32),
sys::ggml_type_GGML_TYPE_F16 => Ok(Type::F16),
sys::ggml_type_GGML_TYPE_F32 => Ok(Type::F32),
sys::ggml_type_GGML_TYPE_I8 => Ok(Type::I8),

_ => Err(()),
}
@@ -282,6 +286,7 @@ impl std::fmt::Display for Type {
Type::I32 => write!(f, "i32"),
Type::F16 => write!(f, "f16"),
Type::F32 => write!(f, "f32"),
Type::I8 => write!(f, "i8"),
}
}
}
@@ -303,6 +308,7 @@ impl Type {
Type::I32 => false,
Type::F16 => false,
Type::F32 => false,
Type::I8 => false,
}
}
}
@@ -351,10 +357,9 @@ pub struct ComputationGraph {

impl ComputationGraph {
/// Create a new [ComputationGraph] with the specified `n_threads`.
pub fn new(n_threads: usize) -> Self {
pub fn new() -> Self {
Self {
inner: sys::ggml_cgraph {
n_threads: usize_to_i32(n_threads),
// SAFETY: This should be safe to zero. The original C++ impl
// just leaves it uninitialized
..unsafe { std::mem::zeroed::<sys::ggml_cgraph>() }
@@ -368,6 +373,54 @@ impl ComputationGraph {
}
}

impl Default for ComputationGraph {
fn default() -> Self {
Self::new()
}
}

/// A `ggml` execution plan. Contains the information needed to execute a computation graph.
pub struct GraphExecutionPlan {
inner: sys::ggml_cplan,
inner_graph: sys::ggml_cgraph,
}

impl GraphExecutionPlan {
/// Create a new [GraphExecutionPlan] from a [ComputationGraph] and the number of threads to use.
pub fn new(graph: &mut ComputationGraph, n_threads: usize) -> Self {
Self {
inner: unsafe { sys::ggml_graph_plan(&mut graph.inner, usize_to_i32(n_threads)) },
inner_graph: graph.inner,
}
}

/// Creates a [Type::I8] work buffer with size `plan.work_size` for this [GraphExecutionPlan] in the given [Context].
fn create_work_buffer(&mut self, context: &Context) -> Tensor {
context.new_tensor_1d(Type::I8, self.inner.work_size)
}

/// Assign a work buffer to this [GraphExecutionPlan].
fn assign_work_buffer(&mut self, buffer: &mut Tensor) {
assert!(
buffer.get_type() == Type::I8,
"Work buffer must be of type i8"
);
unsafe {
self.inner.work_data = buffer.data().cast();
}
}

/// Execute this [GraphExecutionPlan] in the given [Context].
pub fn execute(&mut self, context: &Context) {
let mut work_buffer = self.create_work_buffer(context);
self.assign_work_buffer(&mut work_buffer);

unsafe {
sys::ggml_graph_compute(&mut self.inner_graph, &mut self.inner);
}
}
}

/// The size of `t` as bytes.
pub fn type_size(t: Type) -> usize {
unsafe { sys::ggml_type_size(t.into()) }
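
Putting the lib.rs changes together: the graph is now built without a thread count, and threading is supplied when the execution plan is created. Below is a minimal end-to-end sketch of the new flow. The ComputationGraph and GraphExecutionPlan calls come from this diff; Context::init, op_add, and build_forward_expand are assumptions about the surrounding ggml crate API rather than part of this commit, and the sizes and thread count are placeholders.

use ggml::{ComputationGraph, Context, GraphExecutionPlan, Type};

fn main() {
    // Scratch context; the size is an arbitrary placeholder (Context::init is assumed API).
    let context = Context::init(16 * 1024 * 1024, true);

    // Build a trivial graph: c = a + b (op_add is assumed from the wider ggml bindings).
    let a = context.new_tensor_1d(Type::F32, 4);
    let b = context.new_tensor_1d(Type::F32, 4);
    let c = context.op_add(&a, &b);

    // The graph no longer carries n_threads.
    let mut graph = ComputationGraph::new();
    graph.build_forward_expand(&c);

    // Threading is specified on the execution plan, which allocates its own
    // i8 work buffer in the context and then runs ggml_graph_compute.
    let mut plan = GraphExecutionPlan::new(&mut graph, 4);
    plan.execute(&context);
}
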