Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zstd support #129

Merged
merged 4 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ log = "0.4.18"
env_logger = "0.10.0"
rustc-hash = "1.1.0"
half = "2.3.1"
zstd = "0.13.1"

[build-dependencies]
cbindgen = "0.23.0"
Expand Down
55 changes: 29 additions & 26 deletions src/block_ffm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use crate::model_instance;
use crate::optimizer;
use crate::port_buffer;
use crate::port_buffer::PortBuffer;
use crate::regressor;
use crate::quantization;
use crate::regressor;
use crate::regressor::{BlockCache, FFM_CONTRA_BUF_LEN};

const FFM_STACK_BUF_LEN: usize = 170393;
Expand Down Expand Up @@ -458,8 +458,11 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
contra_fields,
features_present,
ffm,
} = next_cache else {
log::warn!("Unable to downcast cache to BlockFFMCache, executing forward pass without cache");
} = next_cache
else {
log::warn!(
"Unable to downcast cache to BlockFFMCache, executing forward pass without cache"
);
self.forward(further_blocks, fb, pb);
return;
};
Expand Down Expand Up @@ -667,15 +670,18 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
caches: &mut [BlockCache],
) {
let Some((next_cache, further_caches)) = caches.split_first_mut() else {
log::warn!("Expected BlockFFMCache caches, but non available, skipping cache preparation");
log::warn!(
"Expected BlockFFMCache caches, but non available, skipping cache preparation"
);
return;
};

let BlockCache::FFM {
contra_fields,
features_present,
ffm,
} = next_cache else {
} = next_cache
else {
log::warn!("Unable to downcast cache to BlockFFMCache, skipping cache preparation");
return;
};
Expand Down Expand Up @@ -829,32 +835,29 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
fn write_weights_to_buf(
&self,
output_bufwriter: &mut dyn io::Write,
use_quantization: bool
use_quantization: bool,
) -> Result<(), Box<dyn Error>> {

if use_quantization {

let quantized_weights = quantization::quantize_ffm_weights(&self.weights);
block_helpers::write_weights_to_buf(&quantized_weights, output_bufwriter, false)?;
} else {
if use_quantization {
let quantized_weights = quantization::quantize_ffm_weights(&self.weights);
block_helpers::write_weights_to_buf(&quantized_weights, output_bufwriter, false)?;
} else {
block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)?;
}
}
block_helpers::write_weights_to_buf(&self.optimizer, output_bufwriter, false)?;
Ok(())
}

fn read_weights_from_buf(
&mut self,
input_bufreader: &mut dyn io::Read,
use_quantization: bool
use_quantization: bool,
) -> Result<(), Box<dyn Error>> {

if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut self.weights);
} else {
if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut self.weights);
} else {
block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?;
}
}

block_helpers::read_weights_from_buf(&mut self.optimizer, input_bufreader, false)?;
Ok(())
}
Expand All @@ -877,18 +880,18 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockFFM<L> {
&self,
input_bufreader: &mut dyn io::Read,
forward: &mut Box<dyn BlockTrait>,
use_quantization: bool
use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
let forward = forward
.as_any()
.downcast_mut::<BlockFFM<optimizer::OptimizerSGD>>()
.unwrap();

if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut forward.weights);
} else {
if use_quantization {
quantization::dequantize_ffm_weights(input_bufreader, &mut forward.weights);
} else {
block_helpers::read_weights_from_buf(&mut forward.weights, input_bufreader, false)?;
}
}
block_helpers::skip_weights_from_buf::<OptimizerData<L>>(
self.ffm_weights_len as usize,
input_bufreader,
Expand Down Expand Up @@ -1937,7 +1940,7 @@ mod tests {
contra_field_index: mi.ffm_k,
}]);
assert_eq!(spredict2(&mut bg, &fb, &mut pb), 0.5);
assert_eq!(slearn2(&mut bg, &fb, &mut pb, true), 0.5);
assert_eq!(slearn2(&mut bg, &fb, &mut pb, true), 0.62245935);
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions src/block_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ macro_rules! assert_epsilon {
pub fn read_weights_from_buf<L>(
weights: &mut Vec<L>,
input_bufreader: &mut dyn io::Read,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
if weights.is_empty() {
return Err("Loading weights to unallocated weighs buffer".to_string())?;
Expand Down Expand Up @@ -75,7 +75,7 @@ pub fn skip_weights_from_buf<L>(
pub fn write_weights_to_buf<L>(
weights: &Vec<L>,
output_bufwriter: &mut dyn io::Write,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
if weights.is_empty() {
assert!(false);
Expand Down
24 changes: 11 additions & 13 deletions src/block_lr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,10 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
return;
};

let BlockCache::LR {
lr,
combo_indexes,
} = next_cache else {
log::warn!("Unable to downcast cache to BlockLRCache, executing forward pass without cache");
let BlockCache::LR { lr, combo_indexes } = next_cache else {
log::warn!(
"Unable to downcast cache to BlockLRCache, executing forward pass without cache"
);
self.forward(further_blocks, fb, pb);
return;
};
Expand Down Expand Up @@ -222,14 +221,13 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
caches: &mut [BlockCache],
) {
let Some((next_cache, further_caches)) = caches.split_first_mut() else {
log::warn!("Expected BlockLRCache caches, but non available, skipping cache preparation");
log::warn!(
"Expected BlockLRCache caches, but non available, skipping cache preparation"
);
return;
};

let BlockCache::LR {
lr,
combo_indexes
} = next_cache else {
let BlockCache::LR { lr, combo_indexes } = next_cache else {
log::warn!("Unable to downcast cache to BlockLRCache, skipping cache preparation");
return;
};
Expand Down Expand Up @@ -263,15 +261,15 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
fn read_weights_from_buf(
&mut self,
input_bufreader: &mut dyn io::Read,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)
}

fn write_weights_to_buf(
&self,
output_bufwriter: &mut dyn io::Write,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)
}
Expand All @@ -280,7 +278,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockLR<L> {
&self,
input_bufreader: &mut dyn io::Read,
forward: &mut Box<dyn BlockTrait>,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
let forward = forward
.as_any()
Expand Down
6 changes: 3 additions & 3 deletions src/block_neural.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockNeuronLayer<L> {
fn write_weights_to_buf(
&self,
output_bufwriter: &mut dyn io::Write,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::write_weights_to_buf(&self.weights, output_bufwriter, false)?;
block_helpers::write_weights_to_buf(&self.weights_optimizer, output_bufwriter, false)?;
Expand All @@ -440,7 +440,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockNeuronLayer<L> {
fn read_weights_from_buf(
&mut self,
input_bufreader: &mut dyn io::Read,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false)?;
block_helpers::read_weights_from_buf(&mut self.weights_optimizer, input_bufreader, false)?;
Expand All @@ -466,7 +466,7 @@ impl<L: OptimizerTrait + 'static> BlockTrait for BlockNeuronLayer<L> {
&self,
input_bufreader: &mut dyn io::Read,
forward: &mut Box<dyn BlockTrait>,
_use_quantization: bool
_use_quantization: bool,
) -> Result<(), Box<dyn Error>> {
let forward = forward
.as_any()
Expand Down
138 changes: 138 additions & 0 deletions src/buffer_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
use flate2::read::MultiGzDecoder;
use std::fs::File;
use std::io;
use std::io::BufRead;
use std::path::Path;
use zstd::stream::read::Decoder as ZstdDecoder;

pub fn create_buffered_input(input_filename: &str) -> Box<dyn BufRead> {
// Handler for different (or no) compression types

let input = File::open(input_filename).expect("Could not open the input file.");

let input_format = Path::new(&input_filename)
.extension()
.and_then(|ext| ext.to_str())
.expect("Failed to get the file extension.");

match input_format {
"gz" => {
let gz_decoder = MultiGzDecoder::new(input);
let reader = io::BufReader::new(gz_decoder);
Box::new(reader)
}
"zst" => {
let zstd_decoder = ZstdDecoder::new(input).unwrap();
let reader = io::BufReader::new(zstd_decoder);
Box::new(reader)
}
"vw" => {
let reader = io::BufReader::new(input);
Box::new(reader)
}
_ => {
panic!("Please specify a valid input format (.vw, .zst, .gz)");
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use flate2::write::GzEncoder;
use flate2::Compression;
use std::io::{self, Read, Write};
use tempfile::Builder as TempFileBuilder;
use tempfile::NamedTempFile;
use zstd::stream::Encoder as ZstdEncoder;

fn create_temp_file_with_contents(
extension: &str,
contents: &[u8],
) -> io::Result<NamedTempFile> {
let temp_file = TempFileBuilder::new()
.suffix(&format!(".{}", extension))
.tempfile()?;
temp_file.as_file().write_all(contents)?;
Ok(temp_file)
}

fn create_gzipped_temp_file(contents: &[u8]) -> io::Result<NamedTempFile> {
let temp_file = TempFileBuilder::new().suffix(".gz").tempfile()?;
let gz = GzEncoder::new(Vec::new(), Compression::default());
let mut gz_writer = io::BufWriter::new(gz);
gz_writer.write_all(contents)?;
let gz = gz_writer.into_inner()?.finish()?;
temp_file.as_file().write_all(&gz)?;
Ok(temp_file)
}

fn create_zstd_temp_file(contents: &[u8]) -> io::Result<NamedTempFile> {
let temp_file = TempFileBuilder::new().suffix(".zst").tempfile()?;
let mut zstd_encoder = ZstdEncoder::new(Vec::new(), 1)?;
zstd_encoder.write_all(contents)?;
let encoded_data = zstd_encoder.finish()?;
temp_file.as_file().write_all(&encoded_data)?;
Ok(temp_file)
}

// Test for uncompressed file ("vw" extension)
#[test]
fn test_uncompressed_file() {
let contents = b"Sample text for uncompressed file.";
let temp_file =
create_temp_file_with_contents("vw", contents).expect("Failed to create temp file");
let mut reader = create_buffered_input(temp_file.path().to_str().unwrap());

let mut buffer = Vec::new();
reader
.read_to_end(&mut buffer)
.expect("Failed to read from the reader");
assert_eq!(
buffer, contents,
"Contents did not match for uncompressed file."
);
}

// Test for gzipped files ("gz" extension)
#[test]
fn test_gz_compressed_file() {
let contents = b"Sample text for gzipped file.";
let temp_file =
create_gzipped_temp_file(contents).expect("Failed to create gzipped temp file");
let mut reader = create_buffered_input(temp_file.path().to_str().unwrap());

let mut buffer = Vec::new();
reader
.read_to_end(&mut buffer)
.expect("Failed to read from the reader");
assert_eq!(buffer, contents, "Contents did not match for gzipped file.");
}

// Test for zstd compressed files ("zst" extension)
#[test]
fn test_zstd_compressed_file() {
let contents = b"Sample text for zstd compressed file.";
let temp_file = create_zstd_temp_file(contents).expect("Failed to create zstd temp file");
let mut reader = create_buffered_input(temp_file.path().to_str().unwrap());

let mut buffer = Vec::new();
reader
.read_to_end(&mut buffer)
.expect("Failed to read from the reader");
assert_eq!(
buffer, contents,
"Contents did not match for zstd compressed file."
);
}

// Test for unsupported file format
#[test]
#[should_panic(expected = "Please specify a valid input format (.vw, .zst, .gz)")]
fn test_unsupported_file_format() {
let contents = b"Some content";
let temp_file =
create_temp_file_with_contents("txt", contents).expect("Failed to create temp file");
let _reader = create_buffered_input(temp_file.path().to_str().unwrap());
}
}
1 change: 0 additions & 1 deletion src/feature_transform_implementations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::feature_transform_executor::{
use crate::feature_transform_parser;
use crate::vwmap::{NamespaceDescriptor, NamespaceFormat, NamespaceType};


// -------------------------------------------------------------------
// TransformerBinner - A basic binner
// It can take any function as a binning function f32 -> f32. Then output is rounded to integer
Expand Down
Loading
Loading