Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions vortex-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

pub mod display;
mod operator;
pub mod session;
pub mod transform;
mod visitor;

use std::any::Any;
Expand All @@ -15,7 +17,7 @@ pub use operator::*;
pub use visitor::*;
use vortex_buffer::ByteBuffer;
use vortex_dtype::{DType, Nullability};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err, vortex_panic};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
use vortex_mask::Mask;
use vortex_scalar::Scalar;

Expand Down Expand Up @@ -617,18 +619,9 @@ impl<V: VTable> Array for ArrayAdapter<V> {
}
}

let metadata = self.metadata()?.ok_or_else(|| {
vortex_err!("Cannot replace children for arrays that do not support serialization")
})?;

// Replace the children of the array by re-building the array from parts.
self.encoding().build(
self.dtype(),
self.len(),
&metadata,
&self.buffers(),
&ReplacementChildren { children },
)
self.encoding()
.with_children(self, &ReplacementChildren { children })
}

fn invoke(
Expand Down
22 changes: 0 additions & 22 deletions vortex-array/src/array/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ pub trait ArrayOperator: 'static + Send + Sync {
/// If the array's implementation returns an invalid vector (wrong length, wrong type, etc.).
fn execute_batch(&self, ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector>;

/// Optimize the array by running the optimization rules.
fn reduce(&self) -> VortexResult<Option<ArrayRef>>;

/// Optimize the array by pushing down a parent array.
fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;

/// Returns the array as a pipeline node, if supported.
fn as_pipelined(&self) -> Option<&dyn PipelinedNode>;

Expand All @@ -49,14 +43,6 @@ impl ArrayOperator for Arc<dyn Array> {
self.as_ref().execute_batch(ctx)
}

fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
self.as_ref().reduce()
}

fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
self.as_ref().reduce_parent(parent, child_idx)
}

fn as_pipelined(&self) -> Option<&dyn PipelinedNode> {
self.as_ref().as_pipelined()
}
Expand Down Expand Up @@ -88,14 +74,6 @@ impl<V: VTable> ArrayOperator for ArrayAdapter<V> {
Ok(vector)
}

fn reduce(&self) -> VortexResult<Option<ArrayRef>> {
<V::OperatorVTable as OperatorVTable<V>>::reduce(&self.0)
}

fn reduce_parent(&self, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>> {
<V::OperatorVTable as OperatorVTable<V>>::reduce_parent(&self.0, parent, child_idx)
}

fn as_pipelined(&self) -> Option<&dyn PipelinedNode> {
<V::OperatorVTable as OperatorVTable<V>>::pipeline_node(&self.0)
}
Expand Down
6 changes: 6 additions & 0 deletions vortex-array/src/array/session/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

pub mod rewrite;

pub use rewrite::ArrayRewriteRuleRegistry;
236 changes: 236 additions & 0 deletions vortex-array/src/array/session/rewrite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::marker::PhantomData;
use std::sync::Arc;

use vortex_error::VortexResult;
use vortex_utils::aliases::dash_map::DashMap;

use crate::EncodingId;
use crate::array::ArrayRef;
use crate::array::transform::context::ArrayRuleContext;
use crate::array::transform::rules::{
AnyParent, ArrayParentMatcher, ArrayParentReduceRule, ArrayReduceRule,
};
use crate::vtable::VTable;

/// Object-safe form of an array reduce rule.
///
/// This is the type under which reduce rules are stored in the registry
/// (`Arc<dyn DynArrayReduceRule>`), so that rules written against different
/// encodings can be kept in the same collection.
pub trait DynArrayReduceRule: Send + Sync {
/// Attempts to rewrite `array`, with `ctx` supplying the rule context.
///
/// The adapter in this file returns `Ok(None)` when `array` is not of the
/// encoding the wrapped rule targets; otherwise it returns whatever the
/// wrapped rule produces. Presumably `Ok(Some(_))` carries the replacement
/// array and `Ok(None)` means "no rewrite" — TODO confirm against the optimizer.
fn reduce(&self, array: &ArrayRef, ctx: &ArrayRuleContext) -> VortexResult<Option<ArrayRef>>;
}

/// Object-safe form of an array parent reduce rule.
///
/// This is the type under which parent rules are stored in the registry
/// (`Arc<dyn DynArrayParentReduceRule>`), for both the parent-specific and
/// the wildcard ("any parent") tables.
pub trait DynArrayParentReduceRule: Send + Sync {
/// Attempts to rewrite `array` in the context of `parent`.
///
/// The adapter in this file returns `Ok(None)` when `array` is not of the
/// rule's child encoding or when `parent` is not accepted by the rule's
/// parent matcher; otherwise it returns whatever the wrapped rule produces.
///
/// `child_idx` is passed through to the wrapped rule unchanged; presumably
/// it is the position of `array` among `parent`'s children — TODO confirm.
fn reduce_parent(
&self,
array: &ArrayRef,
parent: &ArrayRef,
child_idx: usize,
ctx: &ArrayRuleContext,
) -> VortexResult<Option<ArrayRef>>;
}

/// Adapter that exposes a typed `ArrayReduceRule<V>` as a `DynArrayReduceRule`.
///
/// No value of type `V` is stored; the type parameter only records which
/// encoding the wrapped rule is written against.
struct ArrayReduceRuleAdapter<V: VTable, R> {
/// The typed rule that calls are forwarded to.
rule: R,
/// Zero-sized marker tying the adapter to the encoding `V`.
_phantom: PhantomData<V>,
}

/// Adapter that exposes a typed `ArrayParentReduceRule<Child, Parent>` as a
/// `DynArrayParentReduceRule`.
///
/// Neither `Child` nor `Parent` is stored; the type parameters only record
/// which child encoding and which parent matcher the wrapped rule targets.
struct ArrayParentReduceRuleAdapter<Child: VTable, Parent: ArrayParentMatcher, R> {
/// The typed rule that calls are forwarded to.
rule: R,
/// Zero-sized marker tying the adapter to the `(Child, Parent)` pair.
_phantom: PhantomData<(Child, Parent)>,
}

impl<V, R> DynArrayReduceRule for ArrayReduceRuleAdapter<V, R>
where
    V: VTable,
    R: ArrayReduceRule<V>,
{
    /// Downcasts `array` to the encoding `V` and forwards to the typed rule.
    ///
    /// Yields `Ok(None)` without consulting the rule when `array` is not a `V`.
    fn reduce(&self, array: &ArrayRef, ctx: &ArrayRuleContext) -> VortexResult<Option<ArrayRef>> {
        match array.as_opt::<V>() {
            Some(typed) => self.rule.reduce(typed, ctx),
            // Not the encoding this rule was written for.
            None => Ok(None),
        }
    }
}

impl<Child, Parent, R> DynArrayParentReduceRule for ArrayParentReduceRuleAdapter<Child, Parent, R>
where
    Child: VTable,
    Parent: ArrayParentMatcher,
    R: ArrayParentReduceRule<Child, Parent>,
{
    /// Downcasts `array` to `Child`, matches `parent` with `Parent`, and
    /// forwards to the typed rule when both succeed.
    ///
    /// The child is checked first; the parent matcher is only invoked once the
    /// child downcast has succeeded. Either failure yields `Ok(None)`.
    fn reduce_parent(
        &self,
        array: &ArrayRef,
        parent: &ArrayRef,
        child_idx: usize,
        ctx: &ArrayRuleContext,
    ) -> VortexResult<Option<ArrayRef>> {
        match array.as_opt::<Child>() {
            // Not the child encoding this rule was written for.
            None => Ok(None),
            Some(child) => match Parent::try_match(parent) {
                // The parent is not one this rule applies to.
                None => Ok(None),
                Some(matched_parent) => {
                    self.rule
                        .reduce_parent(child, matched_parent, child_idx, ctx)
                }
            },
        }
    }
}

/// Shared state behind `ArrayRewriteRuleRegistry`: the three rule tables.
///
/// The registry holds this struct in a single `Arc`, so cloning the registry
/// clones one pointer rather than each table. The tables are concurrent maps,
/// which is what allows the registration methods to add rules through `&self`.
#[derive(Default)]
struct ArrayRewriteRuleRegistryInner {
/// Reduce rules, keyed by the ID of the encoding they apply to.
reduce_rules: DashMap<EncodingId, Vec<Arc<dyn DynArrayReduceRule>>>,
/// Parent reduce rules for a specific parent encoding, keyed by `(child_id, parent_id)`.
parent_rules: DashMap<(EncodingId, EncodingId), Vec<Arc<dyn DynArrayParentReduceRule>>>,
/// Wildcard parent reduce rules (registered against `AnyParent`), keyed by `child_id` only.
any_parent_rules: DashMap<EncodingId, Vec<Arc<dyn DynArrayParentReduceRule>>>,
}

impl std::fmt::Debug for ArrayRewriteRuleRegistryInner {
    /// Prints a summary of each rule table (its number of keys) instead of the
    /// rules themselves, which are trait objects without a `Debug` bound.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("ArrayRewriteRuleRegistryInner");

        let reduce_count = format!("{} encodings", self.reduce_rules.len());
        summary.field("reduce_rules", &reduce_count);

        let parent_count = format!("{} pairs", self.parent_rules.len());
        summary.field("parent_rules", &parent_count);

        let any_parent_count = format!("{} encodings", self.any_parent_rules.len());
        summary.field("any_parent_rules", &any_parent_count);

        summary.finish()
    }
}

/// Registry of array rewrite rules.
///
/// Stores rewrite rules indexed by the encoding ID they apply to: reduce
/// rules by encoding, parent rules by `(child, parent)` pair, and wildcard
/// parent rules by child encoding.
///
/// The state lives behind an `Arc`, so a clone of the registry refers to the
/// same rule tables as the original rather than to a copy of them.
#[derive(Clone, Debug)]
pub struct ArrayRewriteRuleRegistry {
/// Shared rule tables.
inner: Arc<ArrayRewriteRuleRegistryInner>,
}

impl Default for ArrayRewriteRuleRegistry {
fn default() -> Self {
Self {
inner: Arc::new(ArrayRewriteRuleRegistryInner::default()),
}
}
}

impl ArrayRewriteRuleRegistry {
    /// Creates a registry with no rules registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a reduce rule for the encoding `V`.
    ///
    /// The rule is appended to the list kept under `V::id(encoding)`, so rules
    /// for the same encoding are stored in registration order.
    pub fn register_reduce_rule<V, R>(&self, encoding: &V::Encoding, rule: R)
    where
        V: VTable,
        R: ArrayReduceRule<V> + 'static,
    {
        let key = V::id(encoding);
        self.inner
            .reduce_rules
            .entry(key)
            .or_default()
            .push(Arc::new(ArrayReduceRuleAdapter {
                rule,
                _phantom: PhantomData::<V>,
            }));
    }

    /// Registers a parent rule that applies to `Child` arrays under a
    /// specific `Parent` encoding.
    ///
    /// The rule is appended to the list kept under the
    /// `(Child::id(..), Parent::id(..))` pair.
    pub fn register_parent_rule<Child, Parent, R>(
        &self,
        child_encoding: &Child::Encoding,
        parent_encoding: &Parent::Encoding,
        rule: R,
    ) where
        Child: VTable,
        Parent: VTable,
        R: ArrayParentReduceRule<Child, Parent> + 'static,
    {
        let key = (Child::id(child_encoding), Parent::id(parent_encoding));
        self.inner
            .parent_rules
            .entry(key)
            .or_default()
            .push(Arc::new(ArrayParentReduceRuleAdapter {
                rule,
                _phantom: PhantomData::<(Child, Parent)>,
            }));
    }

    /// Registers a wildcard parent rule: one that applies to `Child` arrays
    /// regardless of the parent's encoding.
    ///
    /// The rule is appended to the list kept under `Child::id(..)` alone.
    pub fn register_any_parent_rule<Child, R>(&self, child_encoding: &Child::Encoding, rule: R)
    where
        Child: VTable,
        R: ArrayParentReduceRule<Child, AnyParent> + 'static,
    {
        let key = Child::id(child_encoding);
        self.inner
            .any_parent_rules
            .entry(key)
            .or_default()
            .push(Arc::new(ArrayParentReduceRuleAdapter {
                rule,
                _phantom: PhantomData::<(Child, AnyParent)>,
            }));
    }

    /// Runs `f` with an iterator over the reduce rules registered for `id`
    /// and returns whatever `f` returns.
    ///
    /// The iterator is empty when nothing is registered for `id`.
    ///
    /// NOTE(review): the map entry looked up here stays borrowed until `f`
    /// returns; registering rules from inside `f` may block — confirm against
    /// DashMap's locking semantics.
    pub(crate) fn with_reduce_rules<F, R>(&self, id: &EncodingId, f: F) -> R
    where
        F: FnOnce(&mut dyn Iterator<Item = &dyn DynArrayReduceRule>) -> R,
    {
        let entry = self.inner.reduce_rules.get(id);
        let mut rules = entry
            .iter()
            .flat_map(|guard| guard.value())
            .map(|rule| rule.as_ref());
        f(&mut rules)
    }

    /// Runs `f` with an iterator over the parent reduce rules that apply to
    /// the given child encoding, and returns whatever `f` returns.
    ///
    /// The iterator yields the rules registered for the exact
    /// `(child_id, parent_id)` pair first (only when `parent_id` is `Some`),
    /// followed by the wildcard "any parent" rules registered for `child_id`.
    ///
    /// NOTE(review): both map entries looked up here stay borrowed until `f`
    /// returns; registering rules from inside `f` may block — confirm against
    /// DashMap's locking semantics.
    pub(crate) fn with_parent_rules<F, R>(
        &self,
        child_id: &EncodingId,
        parent_id: Option<&EncodingId>,
        f: F,
    ) -> R
    where
        F: FnOnce(&mut dyn Iterator<Item = &dyn DynArrayParentReduceRule>) -> R,
    {
        // The pair-keyed table is only consulted when the parent's ID is known.
        let specific_entry = match parent_id {
            Some(pid) => self
                .inner
                .parent_rules
                .get(&(child_id.clone(), pid.clone())),
            None => None,
        };
        let wildcard_entry = self.inner.any_parent_rules.get(child_id);

        let specific = specific_entry.iter().flat_map(|guard| guard.value());
        let wildcard = wildcard_entry.iter().flat_map(|guard| guard.value());
        let mut rules = specific.chain(wildcard).map(|rule| rule.as_ref());
        f(&mut rules)
    }
}
24 changes: 24 additions & 0 deletions vortex-array/src/array/transform/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use crate::expr::transform::ExprOptimizer;

/// Rule context for array rewrite rules.
///
/// Provides access to the expression optimizer for optimizing expressions
/// embedded in arrays. Note that dtype is not included since arrays already
/// have a dtype that can be accessed directly.
#[derive(Debug, Clone)]
pub struct ArrayRuleContext {
/// Expression optimizer handed to rules through `expr_optimizer()`.
expr_optimizer: ExprOptimizer,
}

impl ArrayRuleContext {
    /// Builds a context that owns the given expression optimizer.
    pub fn new(expr_optimizer: ExprOptimizer) -> Self {
        Self { expr_optimizer }
    }

    /// Borrows the expression optimizer carried by this context.
    pub fn expr_optimizer(&self) -> &ExprOptimizer {
        let Self { expr_optimizer } = self;
        expr_optimizer
    }
}
12 changes: 12 additions & 0 deletions vortex-array/src/array/transform/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

pub mod context;
pub mod optimizer;
pub mod rules;
#[cfg(test)]
mod tests;

pub use context::ArrayRuleContext;
pub use optimizer::ArrayOptimizer;
pub use rules::{AnyParent, ArrayParentMatcher, ArrayParentReduceRule, ArrayReduceRule};
Loading
Loading