Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 9 additions & 29 deletions fuzz/src/array/mask.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::ops::Not;
use std::sync::Arc;

use vortex_array::arrays::{
BoolArray, DecimalArray, ExtensionArray, FixedSizeListArray, ListViewArray, PrimitiveArray,
StructArray, VarBinViewArray,
};
use vortex_array::validity::Validity;
use vortex_array::vtable::ValidityHelper;
use vortex_array::{ArrayRef, Canonical, IntoArray, ToCanonical};
use vortex_array::{ArrayRef, Canonical, IntoArray};
use vortex_dtype::{ExtDType, match_each_decimal_value_type};
use vortex_error::{VortexResult, VortexUnwrap};
use vortex_mask::{AllOr, Mask};
use vortex_mask::Mask;

/// Apply mask on the canonical form of the array to get a consistent baseline.
/// This implementation manually applies the mask to each canonical type
Expand All @@ -25,11 +23,11 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult<A
array.into_array()
}
Canonical::Bool(array) => {
let new_validity = apply_mask_to_validity(array.validity(), mask);
let new_validity = array.validity().mask(mask);
BoolArray::from_bit_buffer(array.bit_buffer().clone(), new_validity).into_array()
}
Canonical::Primitive(array) => {
let new_validity = apply_mask_to_validity(array.validity(), mask);
let new_validity = array.validity().mask(mask);
PrimitiveArray::from_byte_buffer(
array.byte_buffer().clone(),
array.ptype(),
Expand All @@ -38,14 +36,14 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult<A
.into_array()
}
Canonical::Decimal(array) => {
let new_validity = apply_mask_to_validity(array.validity(), mask);
let new_validity = array.validity().mask(mask);
match_each_decimal_value_type!(array.values_type(), |D| {
DecimalArray::new(array.buffer::<D>(), array.decimal_dtype(), new_validity)
.into_array()
})
}
Canonical::VarBinView(array) => {
let new_validity = apply_mask_to_validity(array.validity(), mask);
let new_validity = array.validity().mask(mask);
VarBinViewArray::new(
array.views().clone(),
array.buffers().clone(),
Expand All @@ -55,7 +53,7 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult<A
.into_array()
}
Canonical::List(array) => {
let new_validity = apply_mask_to_validity(array.validity(), mask);
let new_validity = array.validity().mask(mask);

// SAFETY: Since we are only masking the validity and everything else comes from an
// already valid `ListViewArray`, all of the invariants are still upheld.
Expand All @@ -71,7 +69,7 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult<A
.into_array()
}
Canonical::FixedSizeList(array) => {
let new_validity = apply_mask_to_validity(array.validity(), mask);
let new_validity = array.validity().mask(mask);
FixedSizeListArray::new(
array.elements().clone(),
array.list_size(),
Expand All @@ -81,7 +79,7 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult<A
.into_array()
}
Canonical::Struct(array) => {
let new_validity = apply_mask_to_validity(array.validity(), mask);
let new_validity = array.validity().mask(mask);
StructArray::try_new_with_dtype(
array.fields().clone(),
array.struct_fields().clone(),
Expand Down Expand Up @@ -113,24 +111,6 @@ pub fn mask_canonical_array(canonical: Canonical, mask: &Mask) -> VortexResult<A
})
}

fn apply_mask_to_validity(validity: &Validity, mask: &Mask) -> Validity {
match mask.bit_buffer() {
AllOr::All => Validity::AllInvalid,
AllOr::None => validity.clone(),
AllOr::Some(make_invalid) => match validity {
Validity::NonNullable | Validity::AllValid => {
Validity::Array(BoolArray::from(make_invalid.not()).into_array())
}
Validity::AllInvalid => Validity::AllInvalid,
Validity::Array(is_valid) => {
let is_valid = is_valid.to_bool();
let keep_valid = make_invalid.not();
Validity::from(is_valid.bit_buffer() & &keep_valid)
}
},
}
}

#[cfg(test)]
mod tests {
use vortex_array::arrays::{
Expand Down
Loading