Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion encodings/alp/src/alp/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ impl EncodeVTable<ALPVTable> for ALPVTable {
impl VisitorVTable<ALPVTable> for ALPVTable {
fn visit_buffers(_array: &ALPArray, _visitor: &mut dyn ArrayBufferVisitor) {}

fn visit_children(array: &ALPArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a ALPArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_child("encoded", array.encoded());
if let Some(patches) = array.patches() {
visitor.visit_patches(patches);
Expand Down
2 changes: 1 addition & 1 deletion encodings/alp/src/alp_rd/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ impl EncodeVTable<ALPRDVTable> for ALPRDVTable {
impl VisitorVTable<ALPRDVTable> for ALPRDVTable {
fn visit_buffers(_array: &ALPRDArray, _visitor: &mut dyn ArrayBufferVisitor) {}

fn visit_children(array: &ALPRDArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a ALPRDArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_child("left_parts", array.left_parts());
visitor.visit_child("right_parts", array.right_parts());
if let Some(patches) = array.left_parts_patches() {
Expand Down
2 changes: 1 addition & 1 deletion encodings/bytebool/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ impl VisitorVTable<ByteBoolVTable> for ByteBoolVTable {
visitor.visit_buffer(array.buffer());
}

fn visit_children(array: &ByteBoolArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a ByteBoolArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_validity(array.validity(), array.len());
}
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/datetime-parts/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ impl EncodeVTable<DateTimePartsVTable> for DateTimePartsVTable {
impl VisitorVTable<DateTimePartsVTable> for DateTimePartsVTable {
fn visit_buffers(_array: &DateTimePartsArray, _visitor: &mut dyn ArrayBufferVisitor) {}

fn visit_children(array: &DateTimePartsArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a DateTimePartsArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_child("days", array.days());
visitor.visit_child("seconds", array.seconds());
visitor.visit_child("subseconds", array.subseconds());
Expand Down
5 changes: 4 additions & 1 deletion encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ impl ValidityChild<DecimalBytePartsVTable> for DecimalBytePartsVTable {
impl VisitorVTable<DecimalBytePartsVTable> for DecimalBytePartsVTable {
fn visit_buffers(_array: &DecimalBytePartsArray, _visitor: &mut dyn ArrayBufferVisitor) {}

fn visit_children(array: &DecimalBytePartsArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(
array: &'a DecimalBytePartsArray,
visitor: &mut dyn ArrayChildVisitor<'a>,
) {
visitor.visit_child("msp", &array.msp);
}
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/bitpacking/vtable/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ impl VisitorVTable<BitPackedVTable> for BitPackedVTable {
visitor.visit_buffer(array.packed());
}

fn visit_children(array: &BitPackedArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a BitPackedArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
if let Some(patches) = array.patches() {
visitor.visit_patches(patches);
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/delta/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ impl CanonicalVTable<DeltaVTable> for DeltaVTable {
impl VisitorVTable<DeltaVTable> for DeltaVTable {
fn visit_buffers(_array: &DeltaArray, _visitor: &mut dyn ArrayBufferVisitor) {}

fn visit_children(array: &DeltaArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a DeltaArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_child("bases", array.bases());
visitor.visit_child("deltas", array.deltas());
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/for/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ impl EncodeVTable<FoRVTable> for FoRVTable {
impl VisitorVTable<FoRVTable> for FoRVTable {
fn visit_buffers(_array: &FoRArray, _visitor: &mut dyn ArrayBufferVisitor) {}

fn visit_children(array: &FoRArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a FoRArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_child("encoded", array.encoded())
}
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/rle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ impl VisitorVTable<RLEVTable> for RLEVTable {
// RLE stores all data in child arrays, no direct buffers
}

fn visit_children(array: &RLEArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a RLEArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_child("values", array.values());
visitor.visit_child("indices", array.indices());
visitor.visit_child("values_idx_offsets", array.values_idx_offsets());
Expand Down
2 changes: 1 addition & 1 deletion encodings/fsst/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ impl VisitorVTable<FSSTVTable> for FSSTVTable {
visitor.visit_buffer(&array.symbol_lengths().clone().into_byte_buffer());
}

fn visit_children(array: &FSSTArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a FSSTArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_child("codes", array.codes().as_ref());
visitor.visit_child("uncompressed_lengths", array.uncompressed_lengths());
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/pco/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ impl VisitorVTable<PcoVTable> for PcoVTable {
}
}

fn visit_children(array: &PcoArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a PcoArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_validity(&array.unsliced_validity, array.unsliced_n_rows());
}
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/runend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl EncodeVTable<RunEndVTable> for RunEndVTable {
impl VisitorVTable<RunEndVTable> for RunEndVTable {
fn visit_buffers(_array: &RunEndArray, _visitor: &mut dyn ArrayBufferVisitor) {}

fn visit_children(array: &RunEndArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a RunEndArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_child("ends", array.ends());
visitor.visit_child("values", array.values());
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/sequence/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ impl VisitorVTable<SequenceVTable> for SequenceVTable {
// TODO(joe): expose scalar values
}

fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
fn visit_children<'a>(_array: &'a SequenceArray, _visitor: &mut dyn ArrayChildVisitor<'a>) {}
}

#[derive(Clone, Debug)]
Expand Down
2 changes: 1 addition & 1 deletion encodings/sparse/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ impl VisitorVTable<SparseVTable> for SparseVTable {
visitor.visit_buffer(&fill_value_buffer);
}

fn visit_children(array: &SparseArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a SparseArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_patches(array.patches())
}
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/zigzag/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ impl EncodeVTable<ZigZagVTable> for ZigZagVTable {
impl VisitorVTable<ZigZagVTable> for ZigZagVTable {
fn visit_buffers(_array: &ZigZagArray, _visitor: &mut dyn ArrayBufferVisitor) {}

fn visit_children(array: &ZigZagArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a ZigZagArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_child("encoded", array.encoded())
}
}
Expand Down
2 changes: 1 addition & 1 deletion encodings/zstd/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ impl VisitorVTable<ZstdVTable> for ZstdVTable {
}
}

fn visit_children(array: &ZstdArray, visitor: &mut dyn ArrayChildVisitor) {
fn visit_children<'a>(array: &'a ZstdArray, visitor: &mut dyn ArrayChildVisitor<'a>) {
visitor.visit_validity(&array.unsliced_validity, array.unsliced_n_rows());
}
}
2 changes: 1 addition & 1 deletion vortex-array/src/array/display/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl<'a, 'b: 'a> TreeFormatter<'a, 'b> {
.into_iter()
.zip(array.children().into_iter())
{
i.format(&name, child)?;
i.format(&name, child.to_array())?;
}
Ok(())
})?;
Expand Down
27 changes: 17 additions & 10 deletions vortex-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod display;
mod operator;
pub mod session;
pub mod transform;
mod traversal;
mod visitor;

use std::any::Any;
Expand Down Expand Up @@ -647,24 +648,30 @@ impl<V: VTable> ArrayEq for ArrayAdapter<V> {
}

impl<V: VTable> ArrayVisitor for ArrayAdapter<V> {
fn children(&self) -> Vec<ArrayRef> {
struct ChildrenCollector {
children: Vec<ArrayRef>,
fn children(&self) -> Vec<&dyn Array> {
struct ChildrenCollector<'a> {
children: Vec<&'a dyn Array>,
}

impl ArrayChildVisitor for ChildrenCollector {
fn visit_child(&mut self, _name: &str, array: &dyn Array) {
self.children.push(array.to_array());
impl<'a> ArrayChildVisitor<'a> for ChildrenCollector<'a> {
fn visit_child(&mut self, _name: &str, array: &'a dyn Array) {
self.children.push(array);
}
}

let mut collector = ChildrenCollector {
children: Vec::new(),
};
<V::VisitorVTable as VisitorVTable<V>>::visit_children(&self.0, &mut collector);

collector.children
}

fn visit_children<'a>(&'a self, visitor: &mut dyn ArrayChildVisitor<'a>) {
// Directly call the vtable's visit_children - zero allocation!
<V::VisitorVTable as VisitorVTable<V>>::visit_children(&self.0, visitor)
}

fn nchildren(&self) -> usize {
<V::VisitorVTable as VisitorVTable<V>>::nchildren(&self.0)
}
Expand All @@ -674,8 +681,8 @@ impl<V: VTable> ArrayVisitor for ArrayAdapter<V> {
names: Vec<String>,
}

impl ArrayChildVisitor for ChildNameCollector {
fn visit_child(&mut self, name: &str, _array: &dyn Array) {
impl<'a> ArrayChildVisitor<'a> for ChildNameCollector {
fn visit_child(&mut self, name: &str, _array: &'a dyn Array) {
self.names.push(name.to_string());
}
}
Expand All @@ -690,8 +697,8 @@ impl<V: VTable> ArrayVisitor for ArrayAdapter<V> {
children: Vec<(String, ArrayRef)>,
}

impl ArrayChildVisitor for NamedChildrenCollector {
fn visit_child(&mut self, name: &str, array: &dyn Array) {
impl<'a> ArrayChildVisitor<'a> for NamedChildrenCollector {
fn visit_child(&mut self, name: &str, array: &'a dyn Array) {
self.children.push((name.to_string(), array.to_array()));
}
}
Expand Down
6 changes: 4 additions & 2 deletions vortex-array/src/array/transform/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ impl ArrayOptimizer {
let mut children_changed = false;

for child in children.iter() {
let child = child.to_array();
let optimized_child = self.apply_parent_rules(child.clone(), ctx)?;
children_changed |= !std::sync::Arc::ptr_eq(&optimized_child, child);
children_changed |= !std::sync::Arc::ptr_eq(&optimized_child, &child);
optimized_children.push(optimized_child);
}

Expand Down Expand Up @@ -128,8 +129,9 @@ impl ArrayOptimizer {
let mut changed = false;

for child in children.iter() {
let child = child.to_array();
let optimized_child = self.apply_reduce_rules(child.clone(), ctx)?;
changed |= !std::sync::Arc::ptr_eq(&optimized_child, child);
changed |= !std::sync::Arc::ptr_eq(&optimized_child, &child);
new_children.push(optimized_child);
}

Expand Down
106 changes: 106 additions & 0 deletions vortex-array/src/array/traversal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Array tree traversal using the Node trait infrastructure.
//!
//! This module provides implementations of [`Node`] and [`NodeContainer`] for [`ArrayRef`],
//! enabling powerful tree transformations on array structures similar to expression trees.

use itertools::Itertools;
use vortex_error::VortexResult;

use crate::array::ArrayVisitor;
use crate::expr::traversal::{Node, NodeContainer, Transformed, TraversalOrder};
use crate::{Array, ArrayRef};

/// A single [`ArrayRef`] behaves as a container holding exactly one element: itself.
impl<'a> NodeContainer<'a, Self> for ArrayRef {
    /// Calls `visit` once with a reference to this array and returns whatever it returns.
    fn apply_elements<F>(&'a self, mut visit: F) -> VortexResult<TraversalOrder>
    where
        F: FnMut(&'a Self) -> VortexResult<TraversalOrder>,
    {
        visit(self)
    }

    /// Calls `transform` once, handing it ownership of this array, and returns its result.
    fn map_elements<F>(self, mut transform: F) -> VortexResult<Transformed<Self>>
    where
        F: FnMut(Self) -> VortexResult<Transformed<Self>>,
    {
        transform(self)
    }
}

impl Node for ArrayRef {
    /// Children are handed out as borrowed trait objects, which is what `children()` yields.
    type Child = dyn Array;

    /// Calls `f` on each direct child in order.
    ///
    /// `Continue` and `Skip` both advance to the next child; any other order ends the loop
    /// early and is returned as-is. Errors from `f` are propagated immediately.
    fn apply_children<'a, F: FnMut(&'a Self::Child) -> VortexResult<TraversalOrder>>(
        &'a self,
        mut f: F,
    ) -> VortexResult<TraversalOrder> {
        for child in self.children() {
            match f(child)? {
                TraversalOrder::Continue | TraversalOrder::Skip => {}
                halt => return Ok(halt),
            }
        }

        Ok(TraversalOrder::Continue)
    }

    /// Runs `f` over owned handles of the direct children and, if any of them changed,
    /// rebuilds this array around the new children via `with_children`.
    ///
    /// When nothing changed the original array is returned wrapped in `Transformed::no`.
    fn map_children<F: FnMut(Self) -> VortexResult<Transformed<Self>>>(
        self,
        f: F,
    ) -> VortexResult<Transformed<Self>> {
        // `children()` yields borrows; `map_elements` needs owned `ArrayRef`s.
        let owned_children = self
            .children()
            .into_iter()
            .map(|c| c.to_array())
            .collect_vec();
        let mapped = owned_children.map_elements(f)?;

        if !mapped.changed {
            return Ok(Transformed::no(self));
        }

        Ok(Transformed {
            value: self.with_children(mapped.value.as_ref())?,
            order: mapped.order,
            changed: true,
        })
    }

    /// Passes `f` an iterator over the direct children as `&ArrayRef` and returns its result.
    fn iter_children<T>(&self, f: impl FnOnce(&mut dyn Iterator<Item = &Self>) -> T) -> T {
        // The iterator must yield `&ArrayRef`, so the borrowed children are first turned
        // into owned handles that live for the duration of the call.
        let owned: Vec<ArrayRef> = self
            .children()
            .into_iter()
            .map(|c| c.to_array())
            .collect();
        let mut children = owned.iter();
        f(&mut children)
    }

    /// Number of direct children, as reported by `nchildren()`.
    fn children_count(&self) -> usize {
        self.nchildren()
    }

    /// Hand-written down / recurse / up traversal of `child` with `visitor`.
    ///
    /// The original author implemented this manually for `ArrayRef` to avoid a visitor type
    /// mismatch.
    fn accept_on_child<'a, V: crate::expr::traversal::NodeVisitor<'a, NodeTy = Self>>(
        child: &'a Self::Child,
        visitor: &mut V,
    ) -> VortexResult<TraversalOrder> {
        // Pre-order step: let the visitor see the child before its descendants.
        let on_enter = visitor.visit_down(child)?;

        // Recurse into the grandchildren, subject to what `visit_down` returned.
        let after_descendants = on_enter.visit_children(|| {
            for grandchild in child.children() {
                match Self::accept_on_child(grandchild, visitor)? {
                    TraversalOrder::Continue | TraversalOrder::Skip => {}
                    halt => return Ok(halt),
                }
            }
            Ok(TraversalOrder::Continue)
        })?;

        // Post-order step: let the visitor see the child again on the way back up.
        after_descendants.visit_parent(|| visitor.visit_up(child))
    }
}
Loading
Loading