This repository was archived by the owner on Jun 24, 2024. It is now read-only.

fix #149 - load tensors by type, ignoring filetype #152

Merged · 4 commits · Apr 25, 2023
2 changes: 1 addition & 1 deletion ggml-loader/Cargo.toml
@@ -7,4 +7,4 @@ edition = "2021"

[dependencies]
ggml = { path = "../ggml" }
thiserror = "*"
thiserror = "1.0"
28 changes: 18 additions & 10 deletions ggml-loader/src/lib.rs
@@ -10,7 +10,7 @@ use util::*;

pub type ElementType = ggml::Type;

/// file type containing the model
/// the format of the file containing the model
#[derive(Debug, PartialEq, Clone, Copy)]
#[allow(clippy::upper_case_acronyms)]
pub enum ContainerType {
@@ -21,7 +21,6 @@ pub enum ContainerType {
/// mmap-able format
GGJT,
}

impl ContainerType {
pub fn support_mmap(&self) -> bool {
match self {
@@ -64,10 +63,19 @@ pub struct TensorInfo {
pub n_dims: usize,
pub dims: [usize; 2],
pub n_elements: usize,
pub ftype: ElementType,
pub element_type: ElementType,
/// start of tensor - start of file
pub start_offset: u64,
}
impl TensorInfo {
pub fn calc_size(&self) -> usize {
let mut size = ggml::type_size(self.element_type);
for &dim in &self.dims[0..self.n_dims] {
size *= dim;
}
size / ggml::blck_size(self.element_type)
}
}

/// Info in hyperparameter used for later loading tasks. Used in callback.
/// see [`LoadHandler::load_hyper_parameters`]
@@ -78,10 +86,7 @@ pub struct PartialHyperparameters {

pub enum TensorDataTreatment<'a> {
CopyInto(&'a mut [u8]),
SeekPast {
/// should be `tensor.nbytes`
n_bytes: usize,
},
Skip,
}

#[allow(unused_variables)]
@@ -173,7 +178,9 @@ pub fn load_weights<T, R: BufRead + Seek>(
// load tensor header
let n_dims: usize = read_i32(reader)?.try_into()?;
let name_len = read_i32(reader)?;
let ftype = decode_element_type_res(read_i32(reader)?)?;
let ftype = read_i32(reader)?;
let ftype =
ggml::Type::try_from(ftype).map_err(|_| LoadError::UnsupportedElementType(ftype))?;

let mut n_elements: usize = 1;
let mut dims = [1usize, 1];
@@ -214,9 +221,10 @@ pub fn load_weights<T, R: BufRead + Seek>(
dims,
n_dims,
n_elements,
ftype,
element_type: ftype,
start_offset: offset_aligned,
};
let n_bytes = tensor_info.calc_size();

match controlflow_to_result(handler.tensor_buffer(tensor_info))? {
TensorDataTreatment::CopyInto(buf) => {
@@ -225,7 +233,7 @@ pub fn load_weights<T, R: BufRead + Seek>(
}
reader.read_exact(buf)?;
}
TensorDataTreatment::SeekPast { n_bytes } => {
TensorDataTreatment::Skip => {
// skip if no buffer is given
reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?;
}
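Note: the two changes above work together. `TensorInfo::calc_size` lets the loader compute a tensor's byte count from its own header, so `TensorDataTreatment::Skip` no longer needs the caller to supply `n_bytes`. A minimal standalone sketch of the size formula — the Q4_0 block parameters (32 elements per block, 20 bytes per block) are assumptions about ggml's layout at the time, not values read from the crate:

```rust
// Standalone sketch of `TensorInfo::calc_size`, with the type/block sizes
// passed in rather than looked up via `ggml::type_size`/`ggml::blck_size`.
fn calc_size(n_dims: usize, dims: [usize; 2], type_size: usize, blck_size: usize) -> usize {
    let mut size = type_size;
    for &dim in &dims[0..n_dims] {
        size *= dim;
    }
    size / blck_size
}

fn main() {
    // A 4096 x 4096 Q4_0 tensor at an assumed 20 bytes per 32-element block.
    let n_bytes = calc_size(2, [4096, 4096], 20, 32);
    assert_eq!(n_bytes, 10_485_760);
    println!("{n_bytes} bytes");
}
```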
29 changes: 1 addition & 28 deletions ggml-loader/src/util.rs
@@ -1,7 +1,7 @@
pub use std::io::{BufRead, Seek, SeekFrom};
use std::ops::ControlFlow;

use crate::{ElementType, LoadError};
use crate::LoadError;

pub fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> {
let mut bytes = [0u8; N];
@@ -35,33 +35,6 @@ pub fn has_data_left(reader: &mut impl BufRead) -> Result<bool, std::io::Error>
reader.fill_buf().map(|b| !b.is_empty())
}

pub fn decode_element_type(ftype: i32) -> Option<ElementType> {
match ftype {
0 => Some(ggml::Type::F32),
1 => Some(ggml::Type::F16),
2 => Some(ggml::Type::Q4_0),
3 => Some(ggml::Type::Q4_1),
_ => None,
}
}

pub fn encode_element_type(element_type: ElementType) -> Option<i32> {
match element_type {
ggml::Type::F32 => Some(0),
ggml::Type::F16 => Some(1),
ggml::Type::Q4_0 => Some(2),
ggml::Type::Q4_1 => Some(3),
_ => None,
}
}

pub fn decode_element_type_res<T>(ftype: i32) -> Result<ElementType, LoadError<T>> {
match decode_element_type(ftype) {
Some(x) => Ok(x),
None => Err(LoadError::UnsupportedElementType(ftype)),
}
}

pub fn controlflow_to_result<A, B>(x: ControlFlow<A, B>) -> Result<B, LoadError<A>> {
match x {
ControlFlow::Continue(x) => Ok(x),
23 changes: 23 additions & 0 deletions ggml/src/lib.rs
@@ -24,6 +24,9 @@ pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c;
/// The currently-supported format version for `ggml` files.
pub const FORMAT_VERSION: u32 = 1;

/// The size of a `ggml` object.
pub const OBJECT_SIZE: usize = ggml_sys::GGML_OBJECT_SIZE;

#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
/// The type of a value in `ggml`.
pub enum Type {
@@ -32,6 +35,12 @@ pub enum Type {
Q4_0,
/// Quantized 4-bit (type 1); used by GPTQ.
Q4_1,
/// Quantized 4-bit (type 2).
Q4_2,
/// Quantized 4-bit (type 3).
Q4_3,
/// Quantized 8-bit (type 0).
Q8_0,
/// Integer 32-bit.
I32,
/// Float 16-bit.
@@ -44,6 +53,9 @@ impl From<Type> for ggml_sys::ggml_type {
match t {
Type::Q4_0 => ggml_sys::ggml_type_GGML_TYPE_Q4_0,
Type::Q4_1 => ggml_sys::ggml_type_GGML_TYPE_Q4_1,
Type::Q4_2 => ggml_sys::ggml_type_GGML_TYPE_Q4_2,
Type::Q4_3 => ggml_sys::ggml_type_GGML_TYPE_Q4_3,
Type::Q8_0 => ggml_sys::ggml_type_GGML_TYPE_Q8_0,
Type::I32 => ggml_sys::ggml_type_GGML_TYPE_I32,
Type::F16 => ggml_sys::ggml_type_GGML_TYPE_F16,
Type::F32 => ggml_sys::ggml_type_GGML_TYPE_F32,
@@ -56,6 +68,9 @@ impl TryFrom<ggml_sys::ggml_type> for Type {
match t {
ggml_sys::ggml_type_GGML_TYPE_Q4_0 => Ok(Type::Q4_0),
ggml_sys::ggml_type_GGML_TYPE_Q4_1 => Ok(Type::Q4_1),
ggml_sys::ggml_type_GGML_TYPE_Q4_2 => Ok(Type::Q4_2),
ggml_sys::ggml_type_GGML_TYPE_Q4_3 => Ok(Type::Q4_3),
ggml_sys::ggml_type_GGML_TYPE_Q8_0 => Ok(Type::Q8_0),
ggml_sys::ggml_type_GGML_TYPE_I32 => Ok(Type::I32),
ggml_sys::ggml_type_GGML_TYPE_F16 => Ok(Type::F16),
ggml_sys::ggml_type_GGML_TYPE_F32 => Ok(Type::F32),
@@ -68,6 +83,9 @@ impl std::fmt::Display for Type {
match self {
Type::Q4_0 => write!(f, "q4_0"),
Type::Q4_1 => write!(f, "q4_1"),
Type::Q4_2 => write!(f, "q4_2"),
Type::Q4_3 => write!(f, "q4_3"),
Type::Q8_0 => write!(f, "q8_0"),
Type::I32 => write!(f, "i32"),
Type::F16 => write!(f, "f16"),
Type::F32 => write!(f, "f32"),
@@ -510,6 +528,11 @@ pub struct Tensor {
}

impl Tensor {
/// Size of the `ggml_tensor` struct in bytes.
///
/// Exposed for purposes of determining context size.
pub const C_TYPE_SIZE: usize = std::mem::size_of::<ggml_sys::ggml_tensor>();

/// Creates a shared copy of this tensor pointer.
pub fn share(&self) -> Self {
Tensor {
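Note: `OBJECT_SIZE` and `Tensor::C_TYPE_SIZE` are exposed so callers can budget a `ggml` context before allocating it. A hedged sketch of how such an estimate might look — the per-tensor overhead formula and the constant values are stand-ins for illustration, not taken from this PR:

```rust
// Hypothetical context budget: each tensor costs one object header plus one
// `ggml_tensor` struct, on top of its data. In the real crate the constants
// would be `ggml::OBJECT_SIZE` and `ggml::Tensor::C_TYPE_SIZE`.
const OBJECT_SIZE: usize = 32;
const TENSOR_STRUCT_SIZE: usize = 224;

fn estimate_context_size(n_tensors: usize, tensor_data_bytes: usize) -> usize {
    n_tensors * (OBJECT_SIZE + TENSOR_STRUCT_SIZE) + tensor_data_bytes
}

fn main() {
    // e.g. a model with 291 tensors and ~4 GiB of quantized weights.
    println!("{} bytes", estimate_context_size(291, 4 << 30));
}
```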
18 changes: 9 additions & 9 deletions llama-cli/src/cli_args.rs
@@ -373,12 +373,12 @@ pub struct Convert {
pub directory: PathBuf,

/// File type to convert to
#[arg(long, short = 't', value_enum, default_value_t = ElementType::Q4_0)]
pub element_type: ElementType,
#[arg(long, short = 't', value_enum, default_value_t = FileType::Q4_0)]
pub file_type: FileType,
}

#[derive(Parser, Debug, ValueEnum, Clone, Copy)]
pub enum ElementType {
pub enum FileType {
/// Quantized 4-bit (type 0).
Q4_0,
/// Quantized 4-bit (type 1); used by GPTQ.
@@ -388,13 +388,13 @@ pub enum ElementType {
/// Float 32-bit.
F32,
}
impl From<ElementType> for llama_rs::ElementType {
fn from(t: ElementType) -> Self {
impl From<FileType> for llama_rs::FileType {
fn from(t: FileType) -> Self {
match t {
ElementType::Q4_0 => llama_rs::ElementType::Q4_0,
ElementType::Q4_1 => llama_rs::ElementType::Q4_1,
ElementType::F16 => llama_rs::ElementType::F16,
ElementType::F32 => llama_rs::ElementType::F32,
FileType::Q4_0 => llama_rs::FileType::MostlyQ4_0,
FileType::Q4_1 => llama_rs::FileType::MostlyQ4_1,
FileType::F16 => llama_rs::FileType::MostlyF16,
FileType::F32 => llama_rs::FileType::F32,
}
}
}
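
Note: the rename from `ElementType` to `FileType` tracks this PR's core distinction. The file-level type records what the file *mostly* contains, while each tensor header carries its own element type, which is what the loader now trusts when reading. An illustrative sketch (enum names mirror the diff; the pairing logic is mine):

```rust
#[derive(Clone, Copy, Debug)]
enum FileType { F32, MostlyF16, MostlyQ4_0, MostlyQ4_1 }

#[derive(Clone, Copy, Debug)]
enum ElementType { F32, F16, Q4_0, Q4_1 }

// The file-level type is advisory; the per-tensor type decides how each
// tensor is sized and read ("load tensors by type, ignoring filetype").
fn effective_type(_file: FileType, per_tensor: ElementType) -> ElementType {
    per_tensor
}

fn main() {
    // In a "mostly Q4_0" file, small tensors such as norms may still be f32.
    let t = effective_type(FileType::MostlyQ4_0, ElementType::F32);
    println!("{t:?}");
}
```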
2 changes: 1 addition & 1 deletion llama-cli/src/main.rs
@@ -22,7 +22,7 @@ fn main() -> Result<()> {
Args::DumpTokens(args) => dump_tokens(&args)?,
Args::Repl(args) => interactive(&args, false)?,
Args::ChatExperimental(args) => interactive(&args, true)?,
Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.element_type.into()),
Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.file_type.into()),
}

Ok(())
19 changes: 7 additions & 12 deletions llama-rs/src/convert.rs
@@ -16,20 +16,19 @@ use std::{
vec,
};

use crate::{util, Hyperparameters, Vocabulary};
use ggml_loader::util::encode_element_type;
use crate::{loader_common::FileType, util, Hyperparameters, Vocabulary};

/// Converts a `pth` file to a `ggml` file.
pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) {
pub fn convert_pth_to_ggml(model_directory: &Path, file_type: FileType) {
let tokenizer_path = model_directory.parent().unwrap().join("tokenizer.model");
let vocab = load_vocabulary(tokenizer_path.as_path());

let hparams = load_hyperparameters(model_directory, element_type, &vocab);
let hparams = load_hyperparameters(model_directory, file_type, &vocab);

let model_files = util::find_all_model_files(model_directory).unwrap();

for (i, _file) in model_files.iter().enumerate() {
let fname_out = model_directory.join(format!("rust-model-{element_type}.bin"));
let fname_out = model_directory.join(format!("rust-model-{file_type}.bin"));
let mut file = File::create(fname_out).expect("Unable to create file");
write_header(file.borrow_mut(), &hparams).unwrap();
write_tokens(file.borrow_mut(), &vocab).unwrap();
@@ -66,11 +65,7 @@ fn load_vocabulary(path: &Path) -> Vocabulary {
}
}

fn load_hyperparameters(
path: &Path,
element_type: ggml::Type,
vocab: &Vocabulary,
) -> Hyperparameters {
fn load_hyperparameters(path: &Path, file_type: FileType, vocab: &Vocabulary) -> Hyperparameters {
#[derive(Deserialize)]
struct HyperParametersJson {
dim: usize,
@@ -83,7 +78,7 @@ fn load_hyperparameters(
let json = read_to_string(path.join("params.json")).expect("Unable to read file");
let json: HyperParametersJson = serde_json::from_str(&json).expect("Unable to parse json");
Hyperparameters {
element_type,
file_type,
n_ctx: 0,
n_embd: json.dim,
n_head: json.n_heads,
@@ -107,7 +102,7 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String
i32::try_from(hparams.n_head).unwrap(),
i32::try_from(hparams.n_layer).unwrap(),
i32::try_from(hparams.n_embd / hparams.n_head).unwrap(),
encode_element_type(hparams.element_type).unwrap(),
hparams.file_type.into(),
];
let mut packed_values: Vec<u8> = vec![];

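Note: `write_header` now encodes the file type through `From<FileType> for i32` instead of the removed `encode_element_type` helper. A sketch of what that mapping plausibly looks like, assuming it follows llama.cpp's `ftype` convention (0 = F32, 1 = mostly F16, 2 = mostly Q4_0, 3 = mostly Q4_1) — the real impl lives in `loader_common.rs`, which this diff does not show:

```rust
#[derive(Clone, Copy)]
enum FileType { F32, MostlyF16, MostlyQ4_0, MostlyQ4_1 }

impl From<FileType> for i32 {
    fn from(t: FileType) -> i32 {
        match t {
            FileType::F32 => 0,
            FileType::MostlyF16 => 1,
            FileType::MostlyQ4_0 => 2,
            FileType::MostlyQ4_1 => 3,
        }
    }
}

fn main() {
    // The header packs seven i32 fields; the file type is the last of them.
    let ftype: i32 = FileType::MostlyQ4_0.into();
    assert_eq!(ftype, 2);
}
```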
4 changes: 2 additions & 2 deletions llama-rs/src/inference_session.rs
@@ -68,7 +68,7 @@ impl InferenceSession {
.map(|(_, tok)| *tok)
.collect();

if self.n_past + prompt_tokens.len() >= model.hparams.n_ctx {
if self.n_past + prompt_tokens.len() >= model.n_ctx() {
return Err(InferenceError::ContextFull);
}

@@ -96,7 +96,7 @@ impl InferenceSession {
params: &InferenceParameters,
rng: &mut impl rand::Rng,
) -> Result<&'v [u8], InferenceError> {
if self.n_past + 1 >= model.hparams.n_ctx {
if self.n_past + 1 >= model.n_ctx() {
return Err(InferenceError::ContextFull);
}

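Note: both call sites swap direct `model.hparams.n_ctx` access for an `n_ctx()` accessor. A minimal sketch of the guard itself, assuming `n_ctx()` returns the model's context length:

```rust
// Context-window guard, mirroring the checks above: refuse input that would
// push the session past the model's context length.
fn ensure_room(n_past: usize, incoming: usize, n_ctx: usize) -> Result<(), &'static str> {
    if n_past + incoming >= n_ctx {
        Err("ContextFull")
    } else {
        Ok(())
    }
}

fn main() {
    assert!(ensure_room(2000, 100, 2048).is_err()); // would overflow a 2048 window
    assert!(ensure_room(100, 100, 2048).is_ok());
}
```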
2 changes: 1 addition & 1 deletion llama-rs/src/lib.rs
@@ -19,7 +19,7 @@ pub use inference_session::{
InferenceSession, InferenceSessionParameters, InferenceSnapshot, ModelKVMemoryType,
SnapshotError,
};
pub use loader_common::{LoadError, LoadProgress};
pub use loader_common::{FileType, LoadError, LoadProgress};
pub use model::{Hyperparameters, Model};
pub use util::TokenUtf8Buffer;
pub use vocabulary::{TokenBias, TokenId, Vocabulary};
17 changes: 12 additions & 5 deletions llama-rs/src/loader.rs
@@ -7,6 +7,7 @@ use std::{
};

use crate::{
loader_common::FileType,
util::{self, mulf},
LoadError, LoadProgress, Model, TokenId, Vocabulary,
};
@@ -69,9 +70,9 @@ pub(crate) fn load(
n_head: read_i32(&mut reader)?.try_into()?,
n_layer: read_i32(&mut reader)?.try_into()?,
n_rot: read_i32(&mut reader)?.try_into()?,
element_type: {
file_type: {
let ftype = read_i32(&mut reader)?;
decode_element_type(ftype).ok_or_else(|| LoadError::UnsupportedElementType(ftype))
FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))
}?,
};

@@ -108,7 +109,13 @@ pub(crate) fn load(
// for the big tensors, we have the option to store the data in 16-bit
// floats or quantized in order to save memory and also to speed up the
// computation
let wtype = hparams.element_type;
let wtype = match hparams.file_type {
FileType::F32 => ggml::Type::F32,
FileType::MostlyF16 => ggml::Type::F16,
FileType::MostlyQ4_0 => ggml::Type::Q4_0,
FileType::MostlyQ4_1 => ggml::Type::Q4_1,
_ => unimplemented!(),
};

let n_embd = hparams.n_embd;
let n_layer = hparams.n_layer;
@@ -159,7 +166,7 @@ pub(crate) fn load(
(None, None)
};

let mut model = Model::new(context, hparams, vocabulary, n_ff, wtype, model_type, mmap);
let mut model = Model::new_loader1(context, hparams, vocabulary, n_ff, wtype, mmap);
match model_type {
ContainerType::GGMF | ContainerType::GGML => {
let file_offset = reader.stream_position()?;
@@ -421,7 +428,7 @@ fn load_tensor_header_ggmf<'a>(
}

fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option<usize> {
let ftype = decode_element_type(ftype)?;
let ftype = ggml::Type::try_from(ftype).ok()?;
match ftype {
ElementType::Q4_0 | ElementType::Q4_1 => {
assert_eq!(ne[0] % 64, 0);