Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 77 additions & 11 deletions vortex-dtype/src/struct_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

use std::fmt::{Display, Formatter};
use std::hash::Hash;
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use itertools::Itertools;
use vortex_error::{
VortexExpect, VortexResult, VortexUnwrap, vortex_bail, vortex_err, vortex_panic,
};
use vortex_utils::aliases::hash_map::HashMap;

use crate::flatbuffers::ViewedDType;
use crate::{DType, FieldName, FieldNames, PType};
Expand Down Expand Up @@ -215,11 +216,79 @@ impl Display for StructFields {
}
}

#[derive(PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Default)]
struct StructFieldsInner {
names: FieldNames,
dtypes: Arc<[FieldDType]>,
// Derived from names, maps from field name to first index.
indices: OnceLock<HashMap<FieldName, usize>>,
}

impl StructFieldsInner {
fn from_fields(names: FieldNames, dtypes: Arc<[FieldDType]>) -> Self {
Self {
names,
dtypes,
indices: OnceLock::new(),
}
}

fn indices(&self) -> &HashMap<FieldName, usize> {
self.indices.get_or_init(|| {
let mut map = HashMap::with_capacity(self.names.len());
for (idx, name) in self.names.iter().enumerate() {
map.entry(name.clone()).or_insert(idx);
}
map
})
}
}

impl PartialEq for StructFieldsInner {
fn eq(&self, other: &Self) -> bool {
self.names == other.names && self.dtypes == other.dtypes
}
}

impl Eq for StructFieldsInner {}

impl Hash for StructFieldsInner {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.names.hash(state);
self.dtypes.hash(state);
}
}

#[cfg(feature = "serde")]
impl serde::Serialize for StructFieldsInner {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeStruct;

let mut state = serializer.serialize_struct("StructFieldsInner", 2)?;
state.serialize_field("names", &self.names)?;
state.serialize_field("dtypes", &self.dtypes)?;
state.end()
}
}

#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for StructFieldsInner {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
struct StructFieldsInnerHelper {
names: FieldNames,
dtypes: Arc<[FieldDType]>,
}

let helper = StructFieldsInnerHelper::deserialize(deserializer)?;
Ok(StructFieldsInner::from_fields(helper.names, helper.dtypes))
}
}

impl Default for StructFields {
Expand All @@ -234,6 +303,7 @@ impl StructFields {
Self(Arc::new(StructFieldsInner {
names: FieldNames::default(),
dtypes: Arc::from([]),
indices: OnceLock::new(),
}))
}

Expand Down Expand Up @@ -266,13 +336,10 @@ impl StructFields {
dtypes.len()
);
}

let inner = Arc::new(StructFieldsInner {
Self(Arc::new(StructFieldsInner::from_fields(
names,
dtypes: dtypes.into(),
});

Self(inner)
dtypes.into(),
)))
}

/// Get the names of the fields in the struct
Expand All @@ -293,8 +360,7 @@ impl StructFields {
/// Find the index of a field by name
/// Returns `None` if the field is not found
pub fn find(&self, name: impl AsRef<str>) -> Option<usize> {
let name = name.as_ref();
self.0.names.iter().position(|n| n.as_ref() == name)
self.0.indices().get(name.as_ref()).copied()
}

/// Get the [`DType`] of a field.
Expand Down
Loading