-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: collect to list (non-windowed) (primitive/strings/booleans) (#569)
Adds the `collect` function for non-windowed, primitive/string/boolean types. This doesn't support `LargeUtf8` because the other string aggregations that use the `create_typed_evaluator` macro don't support generic offset size, so that will come in a follow up. --------- Co-authored-by: Kevin J Nguyen <kevin.nguyen@datastax.com>
- Loading branch information
1 parent
f58306e
commit 0f33802
Showing
13 changed files
with
816 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
crates/sparrow-instructions/src/evaluators/aggregation/token/collect_token.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
use serde::de::DeserializeOwned; | ||
use serde::Serialize; | ||
use std::collections::VecDeque; | ||
|
||
use crate::{ComputeStore, StateToken, StoreKey}; | ||
|
||
/// State token used for the lag operator. | ||
#[derive(Default, Debug)] | ||
pub struct CollectToken<T> | ||
where | ||
T: Clone, | ||
T: Serialize + DeserializeOwned, | ||
Vec<VecDeque<Option<T>>>: Serialize + DeserializeOwned, | ||
{ | ||
state: Vec<VecDeque<Option<T>>>, | ||
} | ||
|
||
impl<T> CollectToken<T> | ||
where | ||
T: Clone, | ||
T: Serialize + DeserializeOwned, | ||
Vec<VecDeque<Option<T>>>: Serialize + DeserializeOwned, | ||
{ | ||
pub fn resize(&mut self, len: usize) { | ||
if len >= self.state.len() { | ||
self.state.resize(len + 1, VecDeque::new()); | ||
} | ||
} | ||
|
||
pub fn add_value(&mut self, max: usize, index: usize, input: Option<T>) { | ||
self.state[index].push_back(input); | ||
if self.state[index].len() > max { | ||
self.state[index].pop_front(); | ||
} | ||
} | ||
|
||
pub fn state(&self, index: usize) -> &VecDeque<Option<T>> { | ||
&self.state[index] | ||
} | ||
} | ||
|
||
impl<T> StateToken for CollectToken<T> | ||
where | ||
T: Clone, | ||
T: Serialize + DeserializeOwned, | ||
Vec<VecDeque<Option<T>>>: Serialize + DeserializeOwned, | ||
{ | ||
fn restore(&mut self, key: &StoreKey, store: &ComputeStore) -> anyhow::Result<()> { | ||
if let Some(state) = store.get(key)? { | ||
self.state = state; | ||
} else { | ||
self.state.clear(); | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
fn store(&self, key: &StoreKey, store: &ComputeStore) -> anyhow::Result<()> { | ||
store.put(key, &self.state) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,11 @@ | ||
mod collect_boolean; | ||
mod collect_map; | ||
mod collect_primitive; | ||
mod collect_string; | ||
mod index; | ||
|
||
pub(super) use collect_boolean::*; | ||
pub(super) use collect_map::*; | ||
pub(super) use collect_primitive::*; | ||
pub(super) use collect_string::*; | ||
pub(super) use index::*; |
108 changes: 108 additions & 0 deletions
108
crates/sparrow-instructions/src/evaluators/list/collect_boolean.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
use crate::{CollectToken, Evaluator, EvaluatorFactory, RuntimeInfo, StateToken, StaticInfo}; | ||
use arrow::array::{ArrayRef, AsArray, BooleanBuilder, ListBuilder}; | ||
use arrow::datatypes::DataType; | ||
use itertools::izip; | ||
use sparrow_arrow::scalar_value::ScalarValue; | ||
use sparrow_plan::ValueRef; | ||
use std::sync::Arc; | ||
|
||
/// Evaluator for the `collect` instruction. | ||
/// | ||
/// Collects a stream of values into a List. A list is produced | ||
/// for each input value received, growing up to a maximum size. | ||
/// | ||
/// If the list is empty, an empty list is returned (rather than `null`). | ||
#[derive(Debug)] | ||
pub struct CollectBooleanEvaluator { | ||
/// The max size of the buffer. | ||
/// | ||
/// Once the max size is reached, the front will be popped and the new | ||
/// value pushed to the back. | ||
max: usize, | ||
input: ValueRef, | ||
tick: ValueRef, | ||
duration: ValueRef, | ||
/// Contains the buffer of values for each entity | ||
token: CollectToken<bool>, | ||
} | ||
|
||
impl EvaluatorFactory for CollectBooleanEvaluator { | ||
fn try_new(info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> { | ||
let input_type = info.args[1].data_type(); | ||
let result_type = info.result_type; | ||
match result_type { | ||
DataType::List(t) => anyhow::ensure!(t.data_type() == input_type), | ||
other => anyhow::bail!("expected list result type, saw {:?}", other), | ||
}; | ||
|
||
let max = match info.args[0].value_ref.literal_value() { | ||
Some(ScalarValue::Int64(Some(v))) if *v <= 0 => { | ||
anyhow::bail!("unexpected value of `max` -- must be > 0") | ||
} | ||
Some(ScalarValue::Int64(Some(v))) => *v as usize, | ||
// If a user specifies `max = null`, we use usize::MAX value as a way | ||
// to have an "unlimited" buffer. | ||
Some(ScalarValue::Int64(None)) => usize::MAX, | ||
Some(other) => anyhow::bail!("expected i64 for max parameter, saw {:?}", other), | ||
None => anyhow::bail!("expected literal value for max parameter"), | ||
}; | ||
|
||
let (_, input, tick, duration) = info.unpack_arguments()?; | ||
Ok(Box::new(Self { | ||
max, | ||
input, | ||
tick, | ||
duration, | ||
token: CollectToken::default(), | ||
})) | ||
} | ||
} | ||
|
||
impl Evaluator for CollectBooleanEvaluator { | ||
fn evaluate(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result<ArrayRef> { | ||
match (self.tick.is_literal_null(), self.duration.is_literal_null()) { | ||
(true, true) => self.evaluate_non_windowed(info), | ||
(true, false) => unimplemented!("since window aggregation unsupported"), | ||
(false, false) => panic!("sliding window aggregation should use other evaluator"), | ||
(_, _) => anyhow::bail!("saw invalid combination of tick and duration"), | ||
} | ||
} | ||
|
||
fn state_token(&self) -> Option<&dyn StateToken> { | ||
Some(&self.token) | ||
} | ||
|
||
fn state_token_mut(&mut self) -> Option<&mut dyn StateToken> { | ||
Some(&mut self.token) | ||
} | ||
} | ||
|
||
impl CollectBooleanEvaluator { | ||
fn ensure_entity_capacity(&mut self, len: usize) { | ||
self.token.resize(len) | ||
} | ||
|
||
fn evaluate_non_windowed(&mut self, info: &dyn RuntimeInfo) -> anyhow::Result<ArrayRef> { | ||
let input = info.value(&self.input)?.array_ref()?; | ||
let key_capacity = info.grouping().num_groups(); | ||
let entity_indices = info.grouping().group_indices(); | ||
assert_eq!(entity_indices.len(), input.len()); | ||
|
||
self.ensure_entity_capacity(key_capacity); | ||
|
||
let input = input.as_boolean(); | ||
let builder = BooleanBuilder::new(); | ||
let mut list_builder = ListBuilder::new(builder); | ||
|
||
izip!(entity_indices.values(), input).for_each(|(entity_index, input)| { | ||
let entity_index = *entity_index as usize; | ||
|
||
self.token.add_value(self.max, entity_index, input); | ||
let cur_list = self.token.state(entity_index); | ||
|
||
list_builder.append_value(cur_list.iter().copied()); | ||
}); | ||
|
||
Ok(Arc::new(list_builder.finish())) | ||
} | ||
} |
31 changes: 31 additions & 0 deletions
31
crates/sparrow-instructions/src/evaluators/list/collect_map.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
use crate::{Evaluator, EvaluatorFactory, RuntimeInfo, StaticInfo}; | ||
use arrow::array::ArrayRef; | ||
use sparrow_plan::ValueRef; | ||
|
||
/// Evaluator for the `collect` instruction. | ||
/// | ||
/// Collect collects a stream of values into a List<T>. A list is produced | ||
/// for each input value received, growing up to a maximum size. | ||
#[derive(Debug)] | ||
pub struct CollectMapEvaluator { | ||
/// The max size of the buffer. | ||
/// | ||
/// Once the max size is reached, the front will be popped and the new | ||
/// value pushed to the back. | ||
_max: i64, | ||
_input: ValueRef, | ||
_tick: ValueRef, | ||
_duration: ValueRef, | ||
} | ||
|
||
impl EvaluatorFactory for CollectMapEvaluator { | ||
fn try_new(_info: StaticInfo<'_>) -> anyhow::Result<Box<dyn Evaluator>> { | ||
unimplemented!("map collect evaluator is unsupported") | ||
} | ||
} | ||
|
||
impl Evaluator for CollectMapEvaluator { | ||
fn evaluate(&mut self, _info: &dyn RuntimeInfo) -> anyhow::Result<ArrayRef> { | ||
unimplemented!("map collect evaluator is unsupported") | ||
} | ||
} |
Oops, something went wrong.