Skip to content

Commit

Permalink
feat: wire up quantize for CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
philpax committed Apr 6, 2023
1 parent 4e90696 commit 600de36
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"rust-analyzer.cargo.features": ["convert"]
"rust-analyzer.cargo.features": ["convert", "quantize"]
}
2 changes: 1 addition & 1 deletion llama-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
llama-rs = { path = "../llama-rs", features = ["convert"] }
llama-rs = { path = "../llama-rs", features = ["convert", "quantize"] }

rand = { workspace = true }

Expand Down
16 changes: 15 additions & 1 deletion llama-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ pub enum Args {
///
/// For reference, see [the PR](https://github.com/rustformers/llama-rs/pull/83).
Convert(Box<Convert>),

/// Quantize a GGML model to 4-bit.
Quantize(Box<Quantize>),
}

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -244,7 +247,7 @@ fn parse_bias(s: &str) -> Result<TokenBias, String> {
pub struct ModelLoad {
/// Where to load the model path from
#[arg(long, short = 'm')]
pub model_path: String,
pub model_path: PathBuf,

/// Sets the size of the context (in tokens). Allows feeding longer prompts.
/// Note that this affects memory.
Expand Down Expand Up @@ -367,6 +370,17 @@ pub struct Convert {
pub element_type: ElementType,
}

#[derive(Parser, Debug)]
pub struct Quantize {
/// The path to the model to quantize
#[arg()]
pub source: PathBuf,

/// The path to save the quantized model to
#[arg()]
pub destination: PathBuf,
}

#[derive(Parser, Debug, ValueEnum, Clone, Copy)]
pub enum ElementType {
/// Quantized 4-bit (type 0).
Expand Down
13 changes: 13 additions & 0 deletions llama-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ fn main() {
Args::Repl(args) => interactive(&args, false),
Args::ChatExperimental(args) => interactive(&args, true),
Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.element_type.into()),
Args::Quantize(args) => quantize(&args),
}
}

Expand Down Expand Up @@ -191,6 +192,18 @@ fn interactive(
}
}

/// Runs the `quantize` subcommand: reads the GGML model at `args.source`,
/// quantizes it, and writes the result to `args.destination`, debug-printing
/// each progress report the quantizer emits.
///
/// Panics (with context) if quantization fails — acceptable for a CLI
/// entry point, but `.expect` gives a far more useful message than the
/// bare `.unwrap()` it replaces.
fn quantize(args: &cli_args::Quantize) {
    llama_rs::quantize::quantize(
        &args.source,
        &args.destination,
        // Hard-coded to 4-bit type 0 for now; the CLI does not yet expose
        // a choice between Q4_0 and Q4_1.
        llama_rs::ElementType::Q4_0,
        // Progress callback: dump each report via its Debug impl.
        |progress| println!("{progress:?}"),
    )
    .expect("failed to quantize model");
}

fn load_prompt_file_with_prompt(
prompt_file: &cli_args::PromptFile,
prompt: Option<&str>,
Expand Down
7 changes: 5 additions & 2 deletions llama-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ rust-version = "1.65"
ggml = { path = "../ggml" }

bytemuck = "1.13.1"
half = "2.2.1"
partial_sort = "0.2.0"
thiserror = "1.0"
rand = { workspace = true }
Expand All @@ -23,5 +22,9 @@ serde_json = { version = "1.0.94", optional = true }
protobuf = { version = "= 2.14.0", optional = true }
rust_tokenizers = { version = "3.1.2", optional = true }

# Used for the `quantize` feature
half = { version = "2.2.1", optional = true }

[features]
convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"]
convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"]
quantize = ["dep:half"]
2 changes: 2 additions & 0 deletions llama-rs/src/file.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(dead_code)]

use crate::LoadError;
pub use std::fs::File;
pub use std::io::{BufRead, BufReader, BufWriter, Read, Write};
Expand Down
4 changes: 1 addition & 3 deletions llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub use ggml::Type as ElementType;

#[cfg(feature = "convert")]
pub mod convert;
#[cfg(feature = "quantize")]
pub mod quantize;

mod file;
Expand Down Expand Up @@ -523,9 +524,6 @@ pub enum LoadError {
/// The path that failed.
path: PathBuf,
},
/// An invalid `itype` was encountered.
#[error("itype supplied was invalid: {0}")]
InvalidItype(u8),
}

#[derive(Error, Debug)]
Expand Down
10 changes: 5 additions & 5 deletions llama-rs/src/quantize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ pub enum QuantizeProgress<'a> {
pub fn quantize(
file_name_in: impl AsRef<Path>,
file_name_out: impl AsRef<Path>,
itype: u8,
ty: crate::ElementType,
progress_callback: impl Fn(QuantizeProgress),
) -> Result<(), LoadError> {
use crate::file::*;

if itype != 2 && itype != 3 {
return Err(LoadError::InvalidItype(itype));
if !matches!(ty, crate::ElementType::Q4_0 | crate::ElementType::Q4_1) {
todo!("Unsupported quantization format. This should be an error.")
}

let file_in = file_name_in.as_ref();
Expand Down Expand Up @@ -218,7 +218,7 @@ pub fn quantize(
}
}

ftype = itype as u32;
ftype = ty.into();
} else {
// Determines the total bytes we're dealing with
let bpe = (nelements * if ftype == 0 { 4 } else { 2 }) as usize;
Expand All @@ -243,7 +243,7 @@ pub fn quantize(

let mut hist_cur = vec![0; 16];

let curr_size = if itype == 2 {
let curr_size = if matches!(ty, crate::ElementType::Q4_0) {
unsafe { quantize_q4_0(&data_f32, &mut work, nelements, ne[0], &mut hist_cur) }
} else {
unsafe { quantize_q4_1(&data_f32, &mut work, nelements, ne[0], &mut hist_cur) }
Expand Down

0 comments on commit 600de36

Please sign in to comment.