Skip to content

Commit

Permalink
Merge pull request #186 from krojew/vector-support-fixed
Browse files Browse the repository at this point in the history
Vector support fixed
  • Loading branch information
krojew authored Jul 31, 2024
2 parents d565d83 + 1235053 commit 628a5ad
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 20 deletions.
143 changes: 125 additions & 18 deletions cassandra-protocol/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,24 @@ macro_rules! query_values {
};
}

macro_rules! vector_as_rust {
(f32) => {
impl AsRustType<Vec<f32>> for Vector {
fn as_rust_type(&self) -> Result<Option<Vec<f32>>> {
let mut result: Vec<f32> = Vec::new();
for data_value in &self.data {
let float = decode_float(data_value.as_slice().unwrap_or(Err(
Error::General(format!("Failed to convert {:?} into float", data_value)),
)?))?;
result.push(float);
}

Ok(Some(result))
}
}
};
}

macro_rules! list_as_rust {
(List) => (
impl AsRustType<Vec<List>> for List {
Expand Down Expand Up @@ -177,10 +195,10 @@ macro_rules! list_as_cassandra_type {
| Some(ColTypeOptionValue::CSet(ref type_option)) => {
let type_option_ref = type_option.deref().clone();
let wrapper = wrapper_fn(&type_option_ref.id);
let convert = self.map(|bytes| {
wrapper(bytes, &type_option_ref, protocol_version).unwrap()
});
Ok(Some(CassandraType::List(convert)))
let convert = self
.try_map(|bytes| wrapper(bytes, &type_option_ref, protocol_version));

convert.map(|convert| Some(CassandraType::List(convert)))
}
_ => Err(Error::General(format!(
"Invalid conversion. \
Expand All @@ -193,6 +211,61 @@ macro_rules! list_as_cassandra_type {
};
}

macro_rules! vector_as_cassandra_type {
() => {
impl crate::types::AsCassandraType for Vector {
fn as_cassandra_type(
&self,
) -> Result<Option<crate::types::cassandra_type::CassandraType>> {
use crate::error::Error;
use crate::types::cassandra_type::wrapper_fn;
use crate::types::cassandra_type::CassandraType;

let protocol_version = self.protocol_version;

match &self.metadata {
ColTypeOption {
id: ColType::Custom,
value,
} => {
if let Some(value) = value {
let VectorInfo { internal_type, .. } = get_vector_type_info(value)?;

if internal_type == "FloatType" {
let internal_type_option = ColTypeOption {
id: ColType::Float,
value: None,
};

let wrapper = wrapper_fn(&ColType::Float);

let convert = self.try_map(|bytes| {
wrapper(bytes, &internal_type_option, protocol_version)
});

return convert.map(|convert| Some(CassandraType::Vector(convert)));
} else {
return Err(Error::General(format!(
"Invalid conversion. \
Cannot convert Vector<{:?}> into Vector (valid types: Vector<FloatType>",
internal_type
)));
}
} else {
return Err(Error::General("Custom type string is none".to_string()));
}
}
_ => Err(Error::General(format!(
"Invalid conversion. \
Cannot convert {:?} into Vector (valid types: Custom).",
self.metadata.value
))),
}
}
}
};
}

macro_rules! map_as_cassandra_type {
() => {
impl crate::types::AsCassandraType for Map {
Expand All @@ -201,6 +274,7 @@ macro_rules! map_as_cassandra_type {
) -> Result<Option<crate::types::cassandra_type::CassandraType>> {
use crate::types::cassandra_type::wrapper_fn;
use crate::types::cassandra_type::CassandraType;
use itertools::Itertools;
use std::ops::Deref;

if let Some(ColTypeOptionValue::CMap(
Expand All @@ -217,21 +291,21 @@ macro_rules! map_as_cassandra_type {

let protocol_version = self.protocol_version;

let map = self
return self
.data
.iter()
.map(|(key, value)| {
(
key_wrapper(key, &key_col_type_option, protocol_version).unwrap(),
value_wrapper(value, &value_col_type_option, protocol_version)
.unwrap(),
key_wrapper(key, &key_col_type_option, protocol_version).and_then(
|key| {
value_wrapper(value, &value_col_type_option, protocol_version)
.map(|value| (key, value))
},
)
})
.collect::<Vec<(CassandraType, CassandraType)>>();

return Ok(Some(CassandraType::Map(map)));
.try_collect()
.map(|map| Some(CassandraType::Map(map)));
} else {
panic!("not amap")
panic!("not a map")
}
}
}
Expand All @@ -246,16 +320,17 @@ macro_rules! tuple_as_cassandra_type {
) -> Result<Option<crate::types::cassandra_type::CassandraType>> {
use crate::types::cassandra_type::wrapper_fn;
use crate::types::cassandra_type::CassandraType;
use itertools::Itertools;

let protocol_version = self.protocol_version;
let values = self
.data
.iter()
.map(|(col_type, bytes)| {
let wrapper = wrapper_fn(&col_type.id);
wrapper(&bytes, col_type, protocol_version).unwrap()
wrapper(&bytes, col_type, protocol_version)
})
.collect();
.try_collect()?;

Ok(Some(CassandraType::Tuple(values)))
}
Expand All @@ -276,11 +351,11 @@ macro_rules! udt_as_cassandra_type {
let mut map = HashMap::with_capacity(self.data.len());
let protocol_version = self.protocol_version;

self.data.iter().for_each(|(key, (col_type, bytes))| {
for (key, (col_type, bytes)) in &self.data {
let wrapper = wrapper_fn(&col_type.id);
let value = wrapper(&bytes, col_type, protocol_version).unwrap();
let value = wrapper(&bytes, col_type, protocol_version)?;
map.insert(key.clone(), value);
});
}

Ok(Some(CassandraType::Udt(map)))
}
Expand Down Expand Up @@ -554,6 +629,19 @@ macro_rules! into_rust_by_name {
}
}
);
(Row, Vector) => (
impl IntoRustByName<Vector> for Row {
fn get_by_name(&self, name: &str) -> Result<Option<Vector>> {
let protocol_version = self.protocol_version;
self.col_spec_by_name(name)
.ok_or(column_is_empty_err(name))
.and_then(|(col_spec, cbytes)| {
let col_type = &col_spec.col_type;
as_rust_type!(col_type, cbytes, protocol_version, Vector)
})
}
}
);
(Row, Map) => (
impl IntoRustByName<Map> for Row {
fn get_by_name(&self, name: &str) -> Result<Option<Map>> {
Expand Down Expand Up @@ -1275,6 +1363,25 @@ macro_rules! as_rust_type {
))),
}
};
($data_type_option:ident, $data_value:ident, $version:ident, Vector) => {
match $data_type_option.id {
ColType::Custom => match $data_value.as_slice() {
Some(ref bytes) => {
let crate::types::vector::VectorInfo { internal_type: _, count } = crate::types::vector::get_vector_type_info($data_type_option.value.as_ref()?)?;

decode_float_vector(bytes, $version, count)
.map(|data| Some(Vector::new($data_type_option.clone(), data, $version)))
.map_err(Into::into)
},
None => Ok(None),
},
_ => Err(crate::error::Error::(format!(
"Invalid conversion. \
Cannot convert {:?} into Vector (valid types: Custom).",
$data_type_option.id
))),
}
};
($data_type_option:ident, $data_value:ident, $version:ident, Map) => {
match $data_type_option.id {
ColType::Map => match $data_value.as_slice() {
Expand Down
1 change: 1 addition & 0 deletions cassandra-protocol/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod rows;
pub mod tuple;
pub mod udt;
pub mod value;
pub mod vector;

pub mod prelude {
pub use crate::error::{Error, Result};
Expand Down
31 changes: 30 additions & 1 deletion cassandra-protocol/src/types/cassandra_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub enum CassandraType {
Set(Vec<CassandraType>),
Udt(HashMap<String, CassandraType>),
Tuple(Vec<CassandraType>),
Vector(Vec<CassandraType>),
Null,
}

Expand All @@ -48,7 +49,7 @@ pub fn wrapper_fn(
ColType::Ascii => &wrappers::ascii,
ColType::Int => &wrappers::int,
ColType::List => &wrappers::list,
ColType::Custom => &|_, _, _| Err("Conversion into custom types is not supported!".into()),
ColType::Custom => &wrappers::custom,
ColType::Bigint => &wrappers::bigint,
ColType::Boolean => &wrappers::bool,
ColType::Counter => &wrappers::counter,
Expand Down Expand Up @@ -80,10 +81,38 @@ pub mod wrappers {
use crate::frame::Version;
use crate::types::data_serialization_types::*;
use crate::types::list::List;
use crate::types::vector::{get_vector_type_info, Vector, VectorInfo};
use crate::types::AsCassandraType;
use crate::types::CBytes;
use crate::types::{map::Map, tuple::Tuple, udt::Udt};

pub fn custom(
bytes: &CBytes,
col_type: &ColTypeOption,
version: Version,
) -> CDRSResult<CassandraType> {
if let ColTypeOption {
id: ColType::Custom,
value: Some(value),
} = col_type
{
let VectorInfo {
internal_type: _,
count,
} = get_vector_type_info(value)?;

if let Some(actual_bytes) = bytes.as_slice() {
let vector = decode_float_vector(actual_bytes, version, count)
.map(|data| Vector::new(col_type.clone(), data, version))?
.as_cassandra_type()?
.unwrap_or(CassandraType::Null);
return Ok(vector);
}
}

Ok(CassandraType::Null)
}

pub fn map(
bytes: &CBytes,
col_type: &ColTypeOption,
Expand Down
15 changes: 15 additions & 0 deletions cassandra-protocol/src/types/data_serialization_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,21 @@ pub fn decode_list(bytes: &[u8], version: Version) -> Result<Vec<CBytes>, io::Er
Ok(list)
}

pub fn decode_float_vector(
bytes: &[u8],
_version: Version,
count: usize,
) -> Result<Vec<CBytes>, io::Error> {
let type_size = 4;

let mut vector = Vec::with_capacity(count);
for i in (0..(count * type_size)).step_by(4) {
vector.push(CBytes::new(bytes[i..i + type_size].to_vec()));
}

Ok(vector)
}

// Decodes Cassandra `set` data (bytes)
#[inline]
pub fn decode_set(bytes: &[u8], version: Version) -> Result<Vec<CBytes>, io::Error> {
Expand Down
8 changes: 8 additions & 0 deletions cassandra-protocol/src/types/list.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use derive_more::Constructor;
use itertools::Itertools;
use num_bigint::BigInt;
use std::net::IpAddr;
use uuid::Uuid;
Expand Down Expand Up @@ -31,6 +32,13 @@ impl List {
{
self.data.iter().map(f).collect()
}

fn try_map<T, F>(&self, f: F) -> Result<Vec<T>>
where
F: FnMut(&CBytes) -> Result<T>,
{
self.data.iter().map(f).try_collect()
}
}

impl AsRust for List {}
Expand Down
69 changes: 69 additions & 0 deletions cassandra-protocol/src/types/vector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use crate::error::{Error, Result};
use crate::frame::message_result::{ColType, ColTypeOption, ColTypeOptionValue};
use crate::frame::Version;
use crate::types::data_serialization_types::*;
use crate::types::{AsRust, AsRustType, CBytes};
use derive_more::Constructor;
use itertools::Itertools;

// TODO: consider using pointers to ColTypeOption and Vec<CBytes> instead of owning them.
#[derive(Debug, Constructor)]
pub struct Vector {
/// column spec of the list, i.e. id should be List as it's a list and value should contain
/// a type of list items.
metadata: ColTypeOption,
data: Vec<CBytes>,
protocol_version: Version,
}

impl Vector {
fn try_map<T, F>(&self, f: F) -> Result<Vec<T>>
where
F: FnMut(&CBytes) -> Result<T>,
{
self.data.iter().map(f).try_collect()
}
}

pub struct VectorInfo {
pub internal_type: String,
pub count: usize,
}

pub fn get_vector_type_info(option_value: &ColTypeOptionValue) -> Result<VectorInfo> {
let input = match option_value {
ColTypeOptionValue::CString(ref s) => s,
_ => return Err(Error::General("Option value must be a string!".into())),
};

let _custom_type = input.split('(').next().unwrap().rsplit('.').next().unwrap();

let vector_type = input
.split('(')
.nth(1)
.and_then(|s| s.split(',').next())
.and_then(|s| s.rsplit('.').next())
.map(|s| s.trim())
.ok_or_else(|| Error::General("Cannot parse vector type!".into()))?;

let count: usize = input
.split('(')
.nth(1)
.and_then(|s| s.rsplit(',').next())
.and_then(|s| s.split(')').next())
.map(|s| s.trim().parse())
.transpose()
.map_err(|_| Error::General("Cannot parse vector count!".to_string()))?
.ok_or_else(|| Error::General("Cannot parse vector count!".into()))?;

Ok(VectorInfo {
internal_type: vector_type.to_string(),
count,
})
}

impl AsRust for Vector {}

vector_as_rust!(f32);

vector_as_cassandra_type!();
Loading

0 comments on commit 628a5ad

Please sign in to comment.