-
Notifications
You must be signed in to change notification settings - Fork 130
Scalar functions #5561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Scalar functions #5561
Changes from all commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
713d527
Scalar Functions
gatesn e1ac853
Scalar Functions
gatesn 86a243e
Scalar Functions
gatesn b9ce093
Scalar Functions
gatesn 2c3da95
Scalar Functions
gatesn 1c0301e
Scalar Functions
gatesn b6b1d37
Scalar Functions
gatesn 253b115
Scalar Functions
gatesn a46cbda
Scalar Functions
gatesn 2906737
Merge branch 'develop' into ngates/scalar-functions
gatesn a557525
Scalar Functions
gatesn 86f3fe2
Scalar Functions
gatesn 1ce9988
Scalar Functions
gatesn 48aa6c1
Scalar Functions
gatesn b2977d1
Scalar Functions
gatesn 6d98782
Scalar Functions
gatesn 96ded23
Scalar Functions
gatesn eff3231
Scalar Functions
gatesn 948d425
Merge branch 'develop' into ngates/scalar-functions
gatesn c9c1c60
Scalar Functions
gatesn d69af9b
Scalar Functions
gatesn 7acb92a
Scalar Functions
gatesn ff1e9e1
Scalar Functions
gatesn 2e2c687
Scalar Functions
gatesn cba44e3
Move stats to expr
gatesn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| use vortex_dtype::DType; | ||
|
|
||
| use crate::ArrayRef; | ||
| use crate::expr::functions::scalar::ScalarFn; | ||
| use crate::stats::ArrayStats; | ||
| use crate::vtable::ArrayVTable; | ||
|
|
||
| #[derive(Clone, Debug)] | ||
| pub struct ScalarFnArray { | ||
| // NOTE(ngates): we should fix vtables so we don't have to hold this | ||
| pub(super) vtable: ArrayVTable, | ||
| pub(super) scalar_fn: ScalarFn, | ||
| pub(super) dtype: DType, | ||
| pub(super) len: usize, | ||
| pub(super) children: Vec<ArrayRef>, | ||
| pub(super) stats: ArrayStats, | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| use vortex_dtype::DType; | ||
|
|
||
| use crate::expr::functions::scalar::ScalarFn; | ||
|
|
||
| #[derive(Clone, Debug)] | ||
| pub struct ScalarFnMetadata { | ||
| pub(super) scalar_fn: ScalarFn, | ||
| pub(super) child_dtypes: Vec<DType>, | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| mod array; | ||
| mod metadata; | ||
| mod vtable; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| use std::hash::Hash; | ||
| use std::hash::Hasher; | ||
|
|
||
| use vortex_dtype::DType; | ||
|
|
||
| use crate::ArrayEq; | ||
| use crate::ArrayHash; | ||
| use crate::Precision; | ||
| use crate::arrays::scalar_fn::array::ScalarFnArray; | ||
| use crate::arrays::scalar_fn::vtable::ScalarFnVTable; | ||
| use crate::stats::StatsSetRef; | ||
| use crate::vtable::BaseArrayVTable; | ||
|
|
||
| impl BaseArrayVTable<ScalarFnVTable> for ScalarFnVTable { | ||
| fn len(array: &ScalarFnArray) -> usize { | ||
| array.len | ||
| } | ||
|
|
||
| fn dtype(array: &ScalarFnArray) -> &DType { | ||
| &array.dtype | ||
| } | ||
|
|
||
| fn stats(array: &ScalarFnArray) -> StatsSetRef<'_> { | ||
| array.stats.to_ref(array.as_ref()) | ||
| } | ||
|
|
||
| fn array_hash<H: Hasher>(array: &ScalarFnArray, state: &mut H, precision: Precision) { | ||
| array.len.hash(state); | ||
| array.dtype.hash(state); | ||
| array.scalar_fn.hash(state); | ||
| for child in &array.children { | ||
| child.array_hash(state, precision); | ||
| } | ||
| } | ||
|
|
||
| fn array_eq(array: &ScalarFnArray, other: &ScalarFnArray, precision: Precision) -> bool { | ||
| if array.len != other.len { | ||
| return false; | ||
| } | ||
| if array.dtype != other.dtype { | ||
| return false; | ||
| } | ||
| if array.scalar_fn != other.scalar_fn { | ||
| return false; | ||
| } | ||
| for (child, other_child) in array.children.iter().zip(other.children.iter()) { | ||
| if !child.array_eq(other_child, precision) { | ||
| return false; | ||
| } | ||
| } | ||
| true | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| use itertools::Itertools; | ||
| use vortex_error::VortexExpect; | ||
| use vortex_vector::Datum; | ||
|
|
||
| use crate::Array; | ||
| use crate::Canonical; | ||
| use crate::arrays::scalar_fn::array::ScalarFnArray; | ||
| use crate::arrays::scalar_fn::vtable::ScalarFnVTable; | ||
| use crate::expr::functions::ExecutionCtx; | ||
| use crate::vectors::VectorIntoArray; | ||
| use crate::vtable::CanonicalVTable; | ||
|
|
||
| impl CanonicalVTable<ScalarFnVTable> for ScalarFnVTable { | ||
| fn canonicalize(array: &ScalarFnArray) -> Canonical { | ||
| let child_dtypes: Vec<_> = array.children.iter().map(|c| c.dtype().clone()).collect(); | ||
| let child_datums: Vec<_> = array | ||
| .children() | ||
| .iter() | ||
| // TODO(ngates): we could make all execution operate over datums | ||
| .map(|child| child.execute().map(Datum::Vector)) | ||
| .try_collect() | ||
| // FIXME(ngates): canonicalizing really ought to be fallible | ||
| .vortex_expect( | ||
| "Failed to execute child array during canonicalization of ScalarFnArray", | ||
| ); | ||
|
|
||
| let ctx = ExecutionCtx::new(array.len, array.dtype.clone(), child_dtypes, child_datums); | ||
|
|
||
| let result_vector = array | ||
| .scalar_fn | ||
| .execute(&ctx) | ||
| .vortex_expect("Canonicalize should be fallible") | ||
| .into_vector() | ||
| .vortex_expect("Canonicalize should return a vector"); | ||
|
|
||
| result_vector.into_array(&array.dtype).to_canonical() | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| mod array; | ||
| mod canonical; | ||
| mod operations; | ||
| mod validity; | ||
| mod visitor; | ||
|
|
||
| use itertools::Itertools; | ||
| use vortex_buffer::BufferHandle; | ||
| use vortex_dtype::DType; | ||
| use vortex_error::VortexExpect; | ||
| use vortex_error::VortexResult; | ||
| use vortex_error::vortex_bail; | ||
| use vortex_vector::Vector; | ||
|
|
||
| use crate::Array; | ||
| use crate::arrays::scalar_fn::array::ScalarFnArray; | ||
| use crate::arrays::scalar_fn::metadata::ScalarFnMetadata; | ||
| use crate::execution::ExecutionCtx; | ||
| use crate::expr::functions; | ||
| use crate::serde::ArrayChildren; | ||
| use crate::vtable; | ||
| use crate::vtable::ArrayId; | ||
| use crate::vtable::ArrayVTable; | ||
| use crate::vtable::ArrayVTableExt; | ||
| use crate::vtable::NotSupported; | ||
| use crate::vtable::VTable; | ||
|
|
||
| vtable!(ScalarFn); | ||
|
|
||
| #[derive(Clone, Debug)] | ||
| pub struct ScalarFnVTable { | ||
| vtable: functions::ScalarFnVTable, | ||
| } | ||
|
|
||
| impl VTable for ScalarFnVTable { | ||
| type Array = ScalarFnArray; | ||
| type Metadata = ScalarFnMetadata; | ||
| type ArrayVTable = Self; | ||
| type CanonicalVTable = Self; | ||
| type OperationsVTable = NotSupported; | ||
| type ValidityVTable = Self; | ||
| type VisitorVTable = Self; | ||
| type ComputeVTable = NotSupported; | ||
| type EncodeVTable = NotSupported; | ||
| type OperatorVTable = NotSupported; | ||
|
|
||
| fn id(&self) -> ArrayId { | ||
| self.vtable.id() | ||
| } | ||
|
|
||
| fn encoding(array: &Self::Array) -> ArrayVTable { | ||
| array.vtable.clone() | ||
| } | ||
|
|
||
| fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> { | ||
| let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect(); | ||
| Ok(ScalarFnMetadata { | ||
| scalar_fn: array.scalar_fn.clone(), | ||
| child_dtypes, | ||
| }) | ||
| } | ||
|
|
||
| fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> { | ||
| // Not supported | ||
| Ok(None) | ||
| } | ||
|
|
||
| fn deserialize(_bytes: &[u8]) -> VortexResult<Self::Metadata> { | ||
| vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported"); | ||
| } | ||
|
|
||
| fn build( | ||
| &self, | ||
| dtype: &DType, | ||
| len: usize, | ||
| metadata: &ScalarFnMetadata, | ||
| _buffers: &[BufferHandle], | ||
| children: &dyn ArrayChildren, | ||
| ) -> VortexResult<Self::Array> { | ||
| let children: Vec<_> = metadata | ||
| .child_dtypes | ||
| .iter() | ||
| .enumerate() | ||
| .map(|(idx, child_dtype)| children.get(idx, child_dtype, len)) | ||
| .try_collect()?; | ||
|
|
||
| #[cfg(debug_assertions)] | ||
| { | ||
| let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect(); | ||
| vortex_error::vortex_ensure!( | ||
| &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype, | ||
| "Return dtype mismatch when building ScalarFnArray" | ||
| ); | ||
| } | ||
|
|
||
| Ok(ScalarFnArray { | ||
| // This requires a new Arc, but we plan to remove this later anyway. | ||
| vtable: self.to_vtable(), | ||
| scalar_fn: metadata.scalar_fn.clone(), | ||
| dtype: dtype.clone(), | ||
| len, | ||
| children, | ||
| stats: Default::default(), | ||
| }) | ||
| } | ||
|
|
||
| fn execute(array: &Self::Array, _ctx: &mut dyn ExecutionCtx) -> VortexResult<Vector> { | ||
| let input_dtypes: Vec<_> = array.children().iter().map(|c| c.dtype().clone()).collect(); | ||
| let input_datums = array | ||
| .children() | ||
| .iter() | ||
| .map(|child| child.execute()) | ||
| .try_collect()?; | ||
| let ctx = functions::ExecutionCtx::new( | ||
| array.len(), | ||
| array.dtype.clone(), | ||
| input_dtypes, | ||
| input_datums, | ||
| ); | ||
| Ok(array | ||
| .scalar_fn | ||
| .execute(&ctx)? | ||
| .into_vector() | ||
| .vortex_expect("Vector inputs should return vector outputs")) | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| use std::ops::Range; | ||
|
|
||
| use vortex_error::VortexExpect; | ||
| use vortex_scalar::Scalar; | ||
| use vortex_vector::Datum; | ||
|
|
||
| use crate::ArrayRef; | ||
| use crate::IntoArray; | ||
| use crate::arrays::scalar_fn::array::ScalarFnArray; | ||
| use crate::arrays::scalar_fn::vtable::ScalarFnVTable; | ||
| use crate::expr::functions::ExecutionCtx; | ||
| use crate::vtable::OperationsVTable; | ||
|
|
||
| impl OperationsVTable<ScalarFnVTable> for ScalarFnVTable { | ||
| fn slice(array: &ScalarFnArray, range: Range<usize>) -> ArrayRef { | ||
| let children: Vec<_> = array | ||
| .children() | ||
| .iter() | ||
| .map(|c| c.slice(range.clone())) | ||
| .collect(); | ||
|
|
||
| ScalarFnArray { | ||
| vtable: array.vtable.clone(), | ||
| scalar_fn: array.scalar_fn.clone(), | ||
| dtype: array.dtype.clone(), | ||
| len: range.len(), | ||
| children, | ||
| stats: Default::default(), | ||
| } | ||
| .into_array() | ||
| } | ||
|
|
||
| fn scalar_at(array: &ScalarFnArray, index: usize) -> Scalar { | ||
| // TODO(ngates): we should evaluate the scalar function over the scalar inputs. | ||
| let input_datums: Vec<_> = array | ||
| .children() | ||
| .iter() | ||
| .map(|c| c.scalar_at(index)) | ||
| .map(|scalar| Datum::from(scalar.to_vector_scalar())) | ||
| .collect(); | ||
|
|
||
| let ctx = ExecutionCtx::new( | ||
| 1, | ||
| array.dtype.clone(), | ||
| array.children().iter().map(|s| s.dtype().clone()).collect(), | ||
| input_datums, | ||
| ); | ||
|
|
||
| let _result = array | ||
| .scalar_fn | ||
| .execute(&ctx) | ||
| .vortex_expect("Scalar function execution should be fallible") | ||
| .into_scalar() | ||
| .vortex_expect("Scalar function execution should return scalar"); | ||
|
|
||
| // Convert the vector scalar back into a legacy Scalar for now. | ||
| todo!("Implement legacy scalar conversion") | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| use vortex_error::VortexExpect; | ||
| use vortex_mask::Mask; | ||
|
|
||
| use crate::Array; | ||
| use crate::arrays::scalar_fn::array::ScalarFnArray; | ||
| use crate::arrays::scalar_fn::vtable::ScalarFnVTable; | ||
| use crate::expr::functions::NullHandling; | ||
| use crate::vtable::ValidityVTable; | ||
|
|
||
| impl ValidityVTable<ScalarFnVTable> for ScalarFnVTable { | ||
| fn is_valid(array: &ScalarFnArray, index: usize) -> bool { | ||
| array.scalar_at(index).is_valid() | ||
| } | ||
|
|
||
| fn all_valid(array: &ScalarFnArray) -> bool { | ||
| match array.scalar_fn.signature().null_handling() { | ||
| NullHandling::Propagate | NullHandling::AbsorbsNull => { | ||
| // Requires all children to guarantee all_valid | ||
| array.children().iter().all(|child| child.all_valid()) | ||
| } | ||
| NullHandling::Custom => { | ||
| // We cannot guarantee that the array is all valid without evaluating the function | ||
| false | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn all_invalid(array: &ScalarFnArray) -> bool { | ||
| match array.scalar_fn.signature().null_handling() { | ||
| NullHandling::Propagate => { | ||
| // All null if any child is all null | ||
| array.children().iter().any(|child| child.all_invalid()) | ||
| } | ||
| NullHandling::AbsorbsNull | NullHandling::Custom => { | ||
| // We cannot guarantee that the array is all valid without evaluating the function | ||
| false | ||
| } | ||
| } | ||
| } | ||
|
|
||
| fn validity_mask(array: &ScalarFnArray) -> Mask { | ||
| let vector = array | ||
| .execute() | ||
| .vortex_expect("Validity mask computation should be fallible"); | ||
| Mask::from_buffer(vector.into_bool().into_bits()) | ||
| } | ||
| } | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is approximate which I am not sure was the original behaviour
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, but its only used for approximate short-circuits and we have to move the Array trait towards this definition anyway since we want to defer compute.