Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not use serde and bincode crates #10

Merged
merged 6 commits into from
Mar 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@ categories = ["cryptography", "data-structures"]

[dependencies]
turboshake = "=0.4.1"
rayon = "=1.10.0"
rand = "=0.9.0"
rand_chacha = "=0.9.0"
serde = { version = "=1.0.218", features = ["derive"] }
bincode = "=1.3.3"
rayon = "=1.10.0"

[dev-dependencies]
test-case = "=3.3.1"
divan = "=0.1.17"
unicode-xid = "=0.2.6"
test-case = "=3.3.1"

[[bench]]
name = "offline_phase"
Expand Down
4 changes: 2 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl Client {
query_vec_b[(0, h2 as usize)] = added_val;
}

let query_bytes = query_vec_b.to_bytes()?;
let query_bytes = query_vec_b.to_bytes();
self.pending_queries.insert(key.to_vec(), Query { vec_c: secret_vec_c });

Ok(query_bytes)
Expand Down Expand Up @@ -189,7 +189,7 @@ impl Client {
query_vec_b[(0, h3 as usize)] = added_val;
}

let query_bytes = query_vec_b.to_bytes()?;
let query_bytes = query_vec_b.to_bytes();
self.pending_queries.insert(key.to_vec(), Query { vec_c: secret_vec_c });

Ok(query_bytes)
Expand Down
54 changes: 49 additions & 5 deletions src/pir_internals/binary_fuse_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ use super::{error::ChalametPIRError, params};
use crate::pir_internals::branch_opt_util;
use rand::prelude::*;
use rand_chacha::ChaCha20Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use turboshake::TurboShake128;

#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Clone, Debug)]
pub struct BinaryFuseFilter {
pub seed: [u8; 32],
pub arity: u32,
Expand Down Expand Up @@ -441,12 +440,57 @@ impl BinaryFuseFilter {
((self.num_fingerprints as f64) * (self.mat_elem_bit_len as f64)) / (self.filter_size as f64)
}

pub fn to_bytes(&self) -> Result<Vec<u8>, ChalametPIRError> {
bincode::serialize(&self).map_err(|err| ChalametPIRError::FailedToSerializeFilterToBytes(err.to_string()))
pub fn to_bytes(&self) -> Vec<u8> {
let offset0 = 0;
let offset1 = offset0 + self.seed.len();
let offset2 = offset1 + std::mem::size_of_val(&self.arity);
let offset3 = offset2 + std::mem::size_of_val(&self.segment_length);
let offset4 = offset3 + std::mem::size_of_val(&self.segment_count_length);
let offset5 = offset4 + std::mem::size_of_val(&self.num_fingerprints);
let offset6 = offset5 + std::mem::size_of_val(&self.filter_size);
let total_byte_len = offset6 + std::mem::size_of_val(&self.mat_elem_bit_len);

let mut bytes = vec![0u8; total_byte_len];

unsafe {
bytes.get_unchecked_mut(offset0..offset1).copy_from_slice(&self.seed);
bytes.get_unchecked_mut(offset1..offset2).copy_from_slice(&self.arity.to_le_bytes());
bytes.get_unchecked_mut(offset2..offset3).copy_from_slice(&self.segment_length.to_le_bytes());
#[rustfmt::skip]
bytes.get_unchecked_mut(offset3..offset4).copy_from_slice(&self.segment_count_length.to_le_bytes());
bytes.get_unchecked_mut(offset4..offset5).copy_from_slice(&self.num_fingerprints.to_le_bytes());
bytes.get_unchecked_mut(offset5..offset6).copy_from_slice(&self.filter_size.to_le_bytes());
bytes.get_unchecked_mut(offset6..).copy_from_slice(&self.mat_elem_bit_len.to_le_bytes());
}

bytes
}

pub fn from_bytes(bytes: &[u8]) -> Result<BinaryFuseFilter, ChalametPIRError> {
bincode::deserialize(bytes).map_err(|err| ChalametPIRError::FailedToDeserializeFilterFromBytes(err.to_string()))
const OFFSET0: usize = 0;
const OFFSET1: usize = OFFSET0 + std::mem::size_of::<[u8; 32]>();
const OFFSET2: usize = OFFSET1 + std::mem::size_of::<u32>();
const OFFSET3: usize = OFFSET2 + std::mem::size_of::<u32>();
const OFFSET4: usize = OFFSET3 + std::mem::size_of::<u32>();
const OFFSET5: usize = OFFSET4 + std::mem::size_of::<usize>();
const OFFSET6: usize = OFFSET5 + std::mem::size_of::<usize>();
const EXPECTED_BYTE_LEN: usize = OFFSET6 + std::mem::size_of::<usize>();

if branch_opt_util::unlikely(EXPECTED_BYTE_LEN != bytes.len()) {
return Err(ChalametPIRError::FailedToDeserializeFilterFromBytes);
}

Ok(unsafe {
BinaryFuseFilter {
seed: bytes.get_unchecked(OFFSET0..OFFSET1).try_into().unwrap(),
arity: u32::from_le_bytes(bytes.get_unchecked(OFFSET1..OFFSET2).try_into().unwrap()),
segment_length: u32::from_le_bytes(bytes.get_unchecked(OFFSET2..OFFSET3).try_into().unwrap()),
segment_count_length: u32::from_le_bytes(bytes.get_unchecked(OFFSET3..OFFSET4).try_into().unwrap()),
num_fingerprints: usize::from_le_bytes(bytes.get_unchecked(OFFSET4..OFFSET5).try_into().unwrap()),
filter_size: usize::from_le_bytes(bytes.get_unchecked(OFFSET5..OFFSET6).try_into().unwrap()),
mat_elem_bit_len: usize::from_le_bytes(bytes.get_unchecked(OFFSET6..).try_into().unwrap()),
}
})
}
}

Expand Down
12 changes: 4 additions & 8 deletions src/pir_internals/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,15 @@ pub enum ChalametPIRError {
InvalidNumberOfElementsInMatrix,
IncompatibleDimensionForRowVectorTransposedMatrixMultiplication,
InvalidDimensionForVector,
FailedToSerializeMatrixToBytes(String),
FailedToDeserializeMatrixFromBytes(String),
FailedToDeserializeMatrixFromBytes,

// Binary Fuse Filter
EmptyKVDatabase,
ExhaustedAllAttemptsToBuild3WiseXorFilter(usize),
ExhaustedAllAttemptsToBuild4WiseXorFilter(usize),
RowNotDecodable,
DecodedRowNotPrependedWithDigestOfKey,
FailedToSerializeFilterToBytes(String),
FailedToDeserializeFilterFromBytes(String),
FailedToDeserializeFilterFromBytes,

// PIR
KVDatabaseSizeTooLarge,
Expand All @@ -46,8 +44,7 @@ impl Display for ChalametPIRError {
write!(f, "The dimensions are incompatible for multiplication of a row vector and a transposed matrix.")
}
Self::InvalidDimensionForVector => write!(f, "A vector must have either one row or one column."),
Self::FailedToSerializeMatrixToBytes(e) => write!(f, "Matrix serialization failed with: {}", e),
Self::FailedToDeserializeMatrixFromBytes(e) => write!(f, "Matrix deserialization failed with: {}", e),
Self::FailedToDeserializeMatrixFromBytes => write!(f, "Matrix deserialization failed"),

Self::EmptyKVDatabase => write!(f, "Cannot encode empty key-value database."),
Self::ExhaustedAllAttemptsToBuild3WiseXorFilter(max_num_attempts) => {
Expand All @@ -58,8 +55,7 @@ impl Display for ChalametPIRError {
}
Self::RowNotDecodable => write!(f, "Encoded KV database matrix's row cannot be decoded."),
Self::DecodedRowNotPrependedWithDigestOfKey => write!(f, "Decoded row does not have the digest of the key prepended to it."),
Self::FailedToSerializeFilterToBytes(e) => write!(f, "Binary fuse filter serialization failed with: {}", e),
Self::FailedToDeserializeFilterFromBytes(e) => write!(f, "Binary fuse filter deserialization failed with: {}", e),
Self::FailedToDeserializeFilterFromBytes => write!(f, "Binary fuse filter deserialization failed"),

Self::KVDatabaseSizeTooLarge => write!(f, "The key-value database is too large; it can have a maximum of 2^42 entries."),
Self::InvalidHintMatrix => write!(f, "Unexpected number of rows in the hint matrix."),
Expand Down
80 changes: 62 additions & 18 deletions src/pir_internals/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use crate::pir_internals::{
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
ops::{Add, Index, IndexMut, Mul},
Expand All @@ -18,7 +17,7 @@ use std::ops::Neg;

use super::error::ChalametPIRError;

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq)]
pub struct Matrix {
rows: usize,
cols: usize,
Expand Down Expand Up @@ -567,24 +566,69 @@ impl Matrix {
}
}

pub fn to_bytes(&self) -> Result<Vec<u8>, ChalametPIRError> {
bincode::serialize(&self).map_err(|e| ChalametPIRError::FailedToSerializeMatrixToBytes(e.to_string()))
pub fn to_bytes(&self) -> Vec<u8> {
let encoded_elems_byte_len = std::mem::size_of::<u32>() * self.rows * self.cols;

let offset0 = 0;
let offset1 = offset0 + std::mem::size_of_val(&self.rows);
let offset2 = offset1 + std::mem::size_of_val(&self.cols);
let total_byte_len = offset2 + encoded_elems_byte_len;

let elems_as_bytes = unsafe {
let ptr_elems = self.elems.as_ptr();
let ptr_elem_bytes: *const u8 = ptr_elems.cast();

core::slice::from_raw_parts(ptr_elem_bytes, encoded_elems_byte_len)
};

let mut bytes = vec![0u8; total_byte_len];

unsafe {
bytes.get_unchecked_mut(offset0..offset1).copy_from_slice(&self.rows.to_le_bytes());
bytes.get_unchecked_mut(offset1..offset2).copy_from_slice(&self.cols.to_le_bytes());
bytes.get_unchecked_mut(offset2..).copy_from_slice(elems_as_bytes);
}

bytes
}

pub fn from_bytes(bytes: &[u8]) -> Result<Matrix, ChalametPIRError> {
bincode::deserialize(bytes).map_or_else(
|e| Err(ChalametPIRError::FailedToDeserializeMatrixFromBytes(e.to_string())),
|v: Matrix| {
let expected_num_elems = v.num_rows() * v.num_cols();
let actual_num_elems = v.num_elems();

if branch_opt_util::likely(expected_num_elems == actual_num_elems) {
Ok(v)
} else {
Err(ChalametPIRError::InvalidNumberOfElementsInMatrix)
}
},
)
const OFFSET0: usize = 0;
const OFFSET1: usize = OFFSET0 + std::mem::size_of::<usize>();
const OFFSET2: usize = OFFSET1 + std::mem::size_of::<usize>();

if branch_opt_util::unlikely(bytes.len() <= OFFSET2) {
return Err(ChalametPIRError::FailedToDeserializeMatrixFromBytes);
}

let (rows, cols) = unsafe {
(
usize::from_le_bytes(bytes.get_unchecked(OFFSET0..OFFSET1).try_into().unwrap()),
usize::from_le_bytes(bytes.get_unchecked(OFFSET1..OFFSET2).try_into().unwrap()),
)
};
let num_elems = rows * cols;

if branch_opt_util::unlikely(num_elems == 0) {
return Err(ChalametPIRError::FailedToDeserializeMatrixFromBytes);
}

let encoded_elems_byte_len = std::mem::size_of::<u32>() * num_elems;
let remaining_num_bytes = bytes.len() - OFFSET2;

if branch_opt_util::unlikely(encoded_elems_byte_len != remaining_num_bytes) {
return Err(ChalametPIRError::FailedToDeserializeMatrixFromBytes);
}

let elems = unsafe {
let ptr_elem_bytes = bytes[OFFSET2..].as_ptr();
let ptr_elems: *const u32 = ptr_elem_bytes.cast();

core::slice::from_raw_parts(ptr_elems, num_elems)
}
.to_vec();

Ok(Matrix { rows, cols, elems })
}
}

Expand Down Expand Up @@ -982,7 +1026,7 @@ pub mod test {
let num_cols = rng.random_range(MIN_MATRIX_DIM..=MAX_MATRIX_DIM);

let matrix_a = Matrix::generate_from_seed(num_rows, num_cols, &seed).expect("Matrix must be generated from seed");
let matrix_a_bytes = matrix_a.to_bytes().unwrap();
let matrix_a_bytes = matrix_a.to_bytes();
let matrix_b = Matrix::from_bytes(&matrix_a_bytes).unwrap();

assert_eq!(matrix_a, matrix_b);
Expand Down
6 changes: 3 additions & 3 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ impl Server {
let pub_mat_a = unsafe { Matrix::generate_from_seed(pub_mat_a_num_rows, pub_mat_a_num_cols, seed_μ).unwrap_unchecked() };

let hint_mat_m = unsafe { (&pub_mat_a * &parsed_db_mat_d).unwrap_unchecked() };
let hint_bytes = hint_mat_m.to_bytes()?;
let filter_param_bytes: Vec<u8> = filter.to_bytes()?;
let hint_bytes = hint_mat_m.to_bytes();
let filter_param_bytes: Vec<u8> = filter.to_bytes();
let transposed_parsed_db_mat_d = parsed_db_mat_d.transpose();

Ok((Server { transposed_parsed_db_mat_d }, hint_bytes, filter_param_bytes))
Expand All @@ -81,7 +81,7 @@ impl Server {
let query_vector = Matrix::from_bytes(query)?;
let response_vector = query_vector.row_vector_x_transposed_matrix(&self.transposed_parsed_db_mat_d)?;

response_vector.to_bytes()
Ok(response_vector.to_bytes())
}

/// This is required to ensure that LWE PIR protocol is correct. See eq. 8 in section 5.1 of the FrodoPIR paper @ https://ia.cr/2022/981.
Expand Down