Patch summary (free text ahead of the first "diff --git" header; patch tools skip it).

Files touched:
  * src/common/data_type.rs - map_from_arrow_type now returns a DataTypeMap for Timestamp,
    Time32, Time64, Interval, Decimal128 and Decimal256 instead of NotImplemented; adds
    map_from_scalar_value and map_from_scalar_to_arrow.
  * src/expr.rs - adds PyExpr::types, python_value, rex_call_operands, rex_call_operator,
    column_name, plus the helpers _column_name, expr_to_field and _types.
  * src/expr/projection.rs - makes PyProjection.projection public and adds
    projected_expressions, which unwraps nested Expr::Alias.
  * src/sql/logical.rs - maps LogicalPlan::Distinct to PyDistinct.

Reconstruction note: this text was recovered from a capture that had lost its line breaks,
its indentation and its generic type arguments (for example "-> Result {"). Line breaks were
rebuilt so that every hunk matches the counts in its "@@" header; indentation follows rustfmt
defaults; generic arguments were restored from how each value is used. Compare against the
blobs named on the "index" lines before relying on exact whitespace.

Notes for review (the patch content itself is not altered):
  * Timestamp, Time32 and Time64 are all mapped to SqlType::DATE, the same as Date32/Date64.
    Confirm that this is intended.
  * IntervalUnit::MonthDayNano is mapped to SqlType::INTERVAL_MONTH. Confirm.
  * python_value uses todo!() for Null, Decimal128, FixedSizeBinary, List, Timestamp*, Struct
    and Dictionary literals, so those literals panic instead of returning an error. Its
    fallback error text says "encountered in types" although it is raised by python_value.
  * _types reports Int64 for Plus/Minus/Multiply/Modulo, Float64 for Divide and Binary for the
    bitwise operators without looking at the operand types.
  * src/sql/logical.rs imports crate::expr::distinct::PyDistinct, which this patch does not
    add; presumably it already exists - verify.

diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index d55a0e86..622e1aa4 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -15,8 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use datafusion::arrow::datatypes::DataType;
-use datafusion_common::DataFusionError;
+use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
+use datafusion_common::{DataFusionError, ScalarValue};
 use pyo3::prelude::*;
 
 use crate::errors::py_datafusion_err;
@@ -130,9 +130,11 @@ impl DataTypeMap {
                 PythonType::Float,
                 SqlType::FLOAT,
             )),
-            DataType::Timestamp(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
+            DataType::Timestamp(unit, tz) => Ok(DataTypeMap::new(
+                DataType::Timestamp(unit.clone(), tz.clone()),
+                PythonType::Datetime,
+                SqlType::DATE,
+            )),
             DataType::Date32 => Ok(DataTypeMap::new(
                 DataType::Date32,
                 PythonType::Datetime,
@@ -143,18 +145,28 @@ impl DataTypeMap {
                 PythonType::Datetime,
                 SqlType::DATE,
             )),
-            DataType::Time32(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
-            DataType::Time64(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
+            DataType::Time32(unit) => Ok(DataTypeMap::new(
+                DataType::Time32(unit.clone()),
+                PythonType::Datetime,
+                SqlType::DATE,
+            )),
+            DataType::Time64(unit) => Ok(DataTypeMap::new(
+                DataType::Time64(unit.clone()),
+                PythonType::Datetime,
+                SqlType::DATE,
+            )),
             DataType::Duration(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
                 format!("{:?}", arrow_type),
             ))),
-            DataType::Interval(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
+            DataType::Interval(interval_unit) => Ok(DataTypeMap::new(
+                DataType::Interval(interval_unit.clone()),
+                PythonType::Datetime,
+                match interval_unit {
+                    IntervalUnit::DayTime => SqlType::INTERVAL_DAY,
+                    IntervalUnit::MonthDayNano => SqlType::INTERVAL_MONTH,
+                    IntervalUnit::YearMonth => SqlType::INTERVAL_YEAR_MONTH,
+                },
+            )),
             DataType::Binary => Ok(DataTypeMap::new(
                 DataType::Binary,
                 PythonType::Bytes,
@@ -197,12 +209,16 @@ impl DataTypeMap {
             DataType::Dictionary(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
                 format!("{:?}", arrow_type),
             ))),
-            DataType::Decimal128(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
-            DataType::Decimal256(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
+            DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new(
+                DataType::Decimal128(*precision, *scale),
+                PythonType::Float,
+                SqlType::DECIMAL,
+            )),
+            DataType::Decimal256(precision, scale) => Ok(DataTypeMap::new(
+                DataType::Decimal256(*precision, *scale),
+                PythonType::Float,
+                SqlType::DECIMAL,
+            )),
             DataType::Map(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
                 format!("{:?}", arrow_type),
             ))),
@@ -211,6 +227,69 @@ impl DataTypeMap {
             )),
         }
     }
+
+    /// Generate the `DataTypeMap` from a `ScalarValue` instance
+    pub fn map_from_scalar_value(scalar_val: &ScalarValue) -> Result<DataTypeMap, PyErr> {
+        DataTypeMap::map_from_arrow_type(&DataTypeMap::map_from_scalar_to_arrow(scalar_val)?)
+    }
+
+    /// Maps a `ScalarValue` to an Arrow `DataType`
+    pub fn map_from_scalar_to_arrow(scalar_val: &ScalarValue) -> Result<DataType, PyErr> {
+        match scalar_val {
+            ScalarValue::Boolean(_) => Ok(DataType::Boolean),
+            ScalarValue::Float32(_) => Ok(DataType::Float32),
+            ScalarValue::Float64(_) => Ok(DataType::Float64),
+            ScalarValue::Decimal128(_, precision, scale) => {
+                Ok(DataType::Decimal128(*precision, *scale))
+            }
+            ScalarValue::Dictionary(data_type, scalar_type) => {
+                // Call this function again to map the dictionary scalar_value to an Arrow type
+                Ok(DataType::Dictionary(
+                    Box::new(*data_type.clone()),
+                    Box::new(DataTypeMap::map_from_scalar_to_arrow(scalar_type)?),
+                ))
+            }
+            ScalarValue::Int8(_) => Ok(DataType::Int8),
+            ScalarValue::Int16(_) => Ok(DataType::Int16),
+            ScalarValue::Int32(_) => Ok(DataType::Int32),
+            ScalarValue::Int64(_) => Ok(DataType::Int64),
+            ScalarValue::UInt8(_) => Ok(DataType::UInt8),
+            ScalarValue::UInt16(_) => Ok(DataType::UInt16),
+            ScalarValue::UInt32(_) => Ok(DataType::UInt32),
+            ScalarValue::UInt64(_) => Ok(DataType::UInt64),
+            ScalarValue::Utf8(_) => Ok(DataType::Utf8),
+            ScalarValue::LargeUtf8(_) => Ok(DataType::LargeUtf8),
+            ScalarValue::Binary(_) => Ok(DataType::Binary),
+            ScalarValue::LargeBinary(_) => Ok(DataType::LargeBinary),
+            ScalarValue::Date32(_) => Ok(DataType::Date32),
+            ScalarValue::Date64(_) => Ok(DataType::Date64),
+            ScalarValue::Time32Second(_) => Ok(DataType::Time32(TimeUnit::Second)),
+            ScalarValue::Time32Millisecond(_) => Ok(DataType::Time32(TimeUnit::Millisecond)),
+            ScalarValue::Time64Microsecond(_) => Ok(DataType::Time64(TimeUnit::Microsecond)),
+            ScalarValue::Time64Nanosecond(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)),
+            ScalarValue::Null => Ok(DataType::Null),
+            ScalarValue::TimestampSecond(_, tz) => {
+                Ok(DataType::Timestamp(TimeUnit::Second, tz.to_owned()))
+            }
+            ScalarValue::TimestampMillisecond(_, tz) => {
+                Ok(DataType::Timestamp(TimeUnit::Millisecond, tz.to_owned()))
+            }
+            ScalarValue::TimestampMicrosecond(_, tz) => {
+                Ok(DataType::Timestamp(TimeUnit::Microsecond, tz.to_owned()))
+            }
+            ScalarValue::TimestampNanosecond(_, tz) => {
+                Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()))
+            }
+            ScalarValue::IntervalYearMonth(..) => Ok(DataType::Interval(IntervalUnit::YearMonth)),
+            ScalarValue::IntervalDayTime(..) => Ok(DataType::Interval(IntervalUnit::DayTime)),
+            ScalarValue::IntervalMonthDayNano(..) => {
+                Ok(DataType::Interval(IntervalUnit::MonthDayNano))
+            }
+            ScalarValue::List(_val, field_ref) => Ok(DataType::List(field_ref.to_owned())),
+            ScalarValue::Struct(_, fields) => Ok(DataType::Struct(fields.to_owned())),
+            ScalarValue::FixedSizeBinary(size, _) => Ok(DataType::FixedSizeBinary(*size)),
+        }
+    }
 }
 
 #[pymethods]
diff --git a/src/expr.rs b/src/expr.rs
index 4ada4c16..c002b329 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -15,19 +15,26 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use datafusion_common::DFField;
+use datafusion_expr::expr::{AggregateFunction, Sort, WindowFunction};
+use datafusion_expr::utils::exprlist_to_fields;
 use pyo3::{basic::CompareOp, prelude::*};
 use std::convert::{From, Into};
 
 use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::PyArrowType;
-use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField};
+use datafusion_expr::{
+    col, lit, Between, BinaryExpr, Case, Cast, Expr, GetIndexedField, Like, LogicalPlan, Operator,
+    TryCast,
+};
 
-use crate::common::data_type::RexType;
-use crate::errors::py_runtime_err;
+use crate::common::data_type::{DataTypeMap, RexType};
+use crate::errors::{py_runtime_err, py_type_err, DataFusionError};
 use crate::expr::aggregate_expr::PyAggregateFunction;
 use crate::expr::binary_expr::PyBinaryExpr;
 use crate::expr::column::PyColumn;
 use crate::expr::literal::PyLiteral;
+use crate::sql::logical::PyLogicalPlan;
 use datafusion::scalar::ScalarValue;
 
 use self::alias::PyAlias;
@@ -274,11 +281,296 @@ impl PyExpr {
             Expr::ScalarSubquery(..) => RexType::ScalarSubquery,
         })
     }
+
+    /// Given the current `Expr` return the DataTypeMap which represents the
+    /// PythonType, Arrow DataType, and SqlType Enum which represents
+    pub fn types(&self) -> PyResult<DataTypeMap> {
+        Self::_types(&self.expr)
+    }
+
+    /// Extracts the Expr value into a PyObject that can be shared with Python
+    pub fn python_value(&self, py: Python) -> PyResult<PyObject> {
+        match &self.expr {
+            Expr::Literal(scalar_value) => Ok(match scalar_value {
+                ScalarValue::Null => todo!(),
+                ScalarValue::Boolean(v) => v.into_py(py),
+                ScalarValue::Float32(v) => v.into_py(py),
+                ScalarValue::Float64(v) => v.into_py(py),
+                ScalarValue::Decimal128(_, _, _) => todo!(),
+                ScalarValue::Int8(v) => v.into_py(py),
+                ScalarValue::Int16(v) => v.into_py(py),
+                ScalarValue::Int32(v) => v.into_py(py),
+                ScalarValue::Int64(v) => v.into_py(py),
+                ScalarValue::UInt8(v) => v.into_py(py),
+                ScalarValue::UInt16(v) => v.into_py(py),
+                ScalarValue::UInt32(v) => v.into_py(py),
+                ScalarValue::UInt64(v) => v.into_py(py),
+                ScalarValue::Utf8(v) => v.clone().into_py(py),
+                ScalarValue::LargeUtf8(v) => v.clone().into_py(py),
+                ScalarValue::Binary(v) => v.clone().into_py(py),
+                ScalarValue::FixedSizeBinary(_, _) => todo!(),
+                ScalarValue::LargeBinary(v) => v.clone().into_py(py),
+                ScalarValue::List(_, _) => todo!(),
+                ScalarValue::Date32(v) => v.into_py(py),
+                ScalarValue::Date64(v) => v.into_py(py),
+                ScalarValue::Time32Second(v) => v.into_py(py),
+                ScalarValue::Time32Millisecond(v) => v.into_py(py),
+                ScalarValue::Time64Microsecond(v) => v.into_py(py),
+                ScalarValue::Time64Nanosecond(v) => v.into_py(py),
+                ScalarValue::TimestampSecond(_, _) => todo!(),
+                ScalarValue::TimestampMillisecond(_, _) => todo!(),
+                ScalarValue::TimestampMicrosecond(_, _) => todo!(),
+                ScalarValue::TimestampNanosecond(_, _) => todo!(),
+                ScalarValue::IntervalYearMonth(v) => v.into_py(py),
+                ScalarValue::IntervalDayTime(v) => v.into_py(py),
+                ScalarValue::IntervalMonthDayNano(v) => v.into_py(py),
+                ScalarValue::Struct(_, _) => todo!(),
+                ScalarValue::Dictionary(_, _) => todo!(),
+            }),
+            _ => Err(py_type_err(format!(
+                "Non Expr::Literal encountered in types: {:?}",
+                &self.expr
+            ))),
+        }
+    }
+
+    /// Row expressions, Rex(s), operate on the concept of operands. Different variants of Expressions, Expr(s),
+    /// store those operands in different datastructures. This function examines the Expr variant and returns
+    /// the operands to the calling logic as a Vec of PyExpr instances.
+    pub fn rex_call_operands(&self) -> PyResult<Vec<PyExpr>> {
+        match &self.expr {
+            // Expr variants that are themselves the operand to return
+            Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => {
+                Ok(vec![PyExpr::from(self.expr.clone())])
+            }
+
+            // Expr(s) that house the Expr instance to return in their bounded params
+            Expr::Alias(expr, ..)
+            | Expr::Not(expr)
+            | Expr::IsNull(expr)
+            | Expr::IsNotNull(expr)
+            | Expr::IsTrue(expr)
+            | Expr::IsFalse(expr)
+            | Expr::IsUnknown(expr)
+            | Expr::IsNotTrue(expr)
+            | Expr::IsNotFalse(expr)
+            | Expr::IsNotUnknown(expr)
+            | Expr::Negative(expr)
+            | Expr::GetIndexedField(GetIndexedField { expr, .. })
+            | Expr::Cast(Cast { expr, .. })
+            | Expr::TryCast(TryCast { expr, .. })
+            | Expr::Sort(Sort { expr, .. })
+            | Expr::InSubquery { expr, .. } => Ok(vec![PyExpr::from(*expr.clone())]),
+
+            // Expr variants containing a collection of Expr(s) for operands
+            Expr::AggregateFunction(AggregateFunction { args, .. })
+            | Expr::AggregateUDF { args, .. }
+            | Expr::ScalarFunction { args, .. }
+            | Expr::ScalarUDF { args, .. }
+            | Expr::WindowFunction(WindowFunction { args, .. }) => {
+                Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
+            }
+
+            // Expr(s) that require more specific processing
+            Expr::Case(Case {
+                expr,
+                when_then_expr,
+                else_expr,
+            }) => {
+                let mut operands: Vec<PyExpr> = Vec::new();
+
+                if let Some(e) = expr {
+                    operands.push(PyExpr::from(*e.clone()));
+                };
+
+                for (when, then) in when_then_expr {
+                    operands.push(PyExpr::from(*when.clone()));
+                    operands.push(PyExpr::from(*then.clone()));
+                }
+
+                if let Some(e) = else_expr {
+                    operands.push(PyExpr::from(*e.clone()));
+                };
+
+                Ok(operands)
+            }
+            Expr::InList { expr, list, .. } => {
+                let mut operands: Vec<PyExpr> = vec![PyExpr::from(*expr.clone())];
+                for list_elem in list {
+                    operands.push(PyExpr::from(list_elem.clone()));
+                }
+
+                Ok(operands)
+            }
+            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => Ok(vec![
+                PyExpr::from(*left.clone()),
+                PyExpr::from(*right.clone()),
+            ]),
+            Expr::Like(Like { expr, pattern, .. }) => Ok(vec![
+                PyExpr::from(*expr.clone()),
+                PyExpr::from(*pattern.clone()),
+            ]),
+            Expr::ILike(Like { expr, pattern, .. }) => Ok(vec![
+                PyExpr::from(*expr.clone()),
+                PyExpr::from(*pattern.clone()),
+            ]),
+            Expr::SimilarTo(Like { expr, pattern, .. }) => Ok(vec![
+                PyExpr::from(*expr.clone()),
+                PyExpr::from(*pattern.clone()),
+            ]),
+            Expr::Between(Between {
+                expr,
+                negated: _,
+                low,
+                high,
+            }) => Ok(vec![
+                PyExpr::from(*expr.clone()),
+                PyExpr::from(*low.clone()),
+                PyExpr::from(*high.clone()),
+            ]),
+
+            // Currently un-support/implemented Expr types for Rex Call operations
+            Expr::GroupingSet(..)
+            | Expr::OuterReferenceColumn(_, _)
+            | Expr::Wildcard
+            | Expr::QualifiedWildcard { .. }
+            | Expr::ScalarSubquery(..)
+            | Expr::Placeholder { .. }
+            | Expr::Exists { .. } => Err(py_runtime_err(format!(
+                "Unimplemented Expr type: {}",
+                self.expr
+            ))),
+        }
+    }
+
+    /// Extracts the operator associated with a RexType::Call
+    pub fn rex_call_operator(&self) -> PyResult<String> {
+        Ok(match &self.expr {
+            Expr::BinaryExpr(BinaryExpr {
+                left: _,
+                op,
+                right: _,
+            }) => format!("{op}"),
+            Expr::ScalarFunction { fun, args: _ } => format!("{fun}"),
+            Expr::ScalarUDF { fun, .. } => fun.name.clone(),
+            Expr::Cast { .. } => "cast".to_string(),
+            Expr::Between { .. } => "between".to_string(),
+            Expr::Case { .. } => "case".to_string(),
+            Expr::IsNull(..) => "is null".to_string(),
+            Expr::IsNotNull(..) => "is not null".to_string(),
+            Expr::IsTrue(_) => "is true".to_string(),
+            Expr::IsFalse(_) => "is false".to_string(),
+            Expr::IsUnknown(_) => "is unknown".to_string(),
+            Expr::IsNotTrue(_) => "is not true".to_string(),
+            Expr::IsNotFalse(_) => "is not false".to_string(),
+            Expr::IsNotUnknown(_) => "is not unknown".to_string(),
+            Expr::InList { .. } => "in list".to_string(),
+            Expr::Negative(..) => "negative".to_string(),
+            Expr::Not(..) => "not".to_string(),
+            Expr::Like(Like { negated, .. }) => {
+                if *negated {
+                    "not like".to_string()
+                } else {
+                    "like".to_string()
+                }
+            }
+            Expr::ILike(Like { negated, .. }) => {
+                if *negated {
+                    "not ilike".to_string()
+                } else {
+                    "ilike".to_string()
+                }
+            }
+            Expr::SimilarTo(Like { negated, .. }) => {
+                if *negated {
+                    "not similar to".to_string()
+                } else {
+                    "similar to".to_string()
+                }
+            }
+            _ => {
+                return Err(py_type_err(format!(
+                    "Catch all triggered in get_operator_name: {:?}",
+                    &self.expr
+                )))
+            }
+        })
+    }
+
+    pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult<String> {
+        self._column_name(&plan.plan()).map_err(py_runtime_err)
+    }
+}
+
+impl PyExpr {
+    pub fn _column_name(&self, plan: &LogicalPlan) -> Result<String, DataFusionError> {
+        let field = Self::expr_to_field(&self.expr, plan)?;
+        Ok(field.qualified_column().flat_name())
+    }
+
+    /// Create a [DFField] representing an [Expr], given an input [LogicalPlan] to resolve against
+    pub fn expr_to_field(
+        expr: &Expr,
+        input_plan: &LogicalPlan,
+    ) -> Result<DFField, DataFusionError> {
+        match expr {
+            Expr::Sort(Sort { expr, .. }) => {
+                // DataFusion does not support create_name for sort expressions (since they never
+                // appear in projections) so we just delegate to the contained expression instead
+                Self::expr_to_field(expr, input_plan)
+            }
+            _ => {
+                let fields =
+                    exprlist_to_fields(&[expr.clone()], input_plan).map_err(PyErr::from)?;
+                Ok(fields[0].clone())
+            }
+        }
+    }
+
+    fn _types(expr: &Expr) -> PyResult<DataTypeMap> {
+        match expr {
+            Expr::BinaryExpr(BinaryExpr {
+                left: _,
+                op,
+                right: _,
+            }) => match op {
+                Operator::Eq
+                | Operator::NotEq
+                | Operator::Lt
+                | Operator::LtEq
+                | Operator::Gt
+                | Operator::GtEq
+                | Operator::And
+                | Operator::Or
+                | Operator::IsDistinctFrom
+                | Operator::IsNotDistinctFrom
+                | Operator::RegexMatch
+                | Operator::RegexIMatch
+                | Operator::RegexNotMatch
+                | Operator::RegexNotIMatch => DataTypeMap::map_from_arrow_type(&DataType::Boolean),
+                Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => {
+                    DataTypeMap::map_from_arrow_type(&DataType::Int64)
+                }
+                Operator::Divide => DataTypeMap::map_from_arrow_type(&DataType::Float64),
+                Operator::StringConcat => DataTypeMap::map_from_arrow_type(&DataType::Utf8),
+                Operator::BitwiseShiftLeft
+                | Operator::BitwiseShiftRight
+                | Operator::BitwiseXor
+                | Operator::BitwiseAnd
+                | Operator::BitwiseOr => DataTypeMap::map_from_arrow_type(&DataType::Binary),
+            },
+            Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type),
+            Expr::Literal(scalar_value) => DataTypeMap::map_from_scalar_value(scalar_value),
+            _ => Err(py_type_err(format!(
+                "Non Expr::Literal encountered in types: {:?}",
+                expr
+            ))),
+        }
+    }
 }
 
 /// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
 pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
-    // expressions
     m.add_class::<PyExpr>()?;
     m.add_class::<PyColumn>()?;
     m.add_class::<PyLiteral>()?;
diff --git a/src/expr/projection.rs b/src/expr/projection.rs
index f5ba12db..b3296618 100644
--- a/src/expr/projection.rs
+++ b/src/expr/projection.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use datafusion_expr::logical_plan::Projection;
+use datafusion_expr::Expr;
 use pyo3::prelude::*;
 use std::fmt::{self, Display, Formatter};
 
@@ -27,7 +28,7 @@ use crate::sql::logical::PyLogicalPlan;
 #[pyclass(name = "Projection", module = "datafusion.expr", subclass)]
 #[derive(Clone)]
 pub struct PyProjection {
-    projection: Projection,
+    pub projection: Projection,
 }
 
 impl PyProjection {
@@ -92,6 +93,21 @@ impl PyProjection {
     }
 }
 
+impl PyProjection {
+    /// Projection: Gets the names of the fields that should be projected
+    pub fn projected_expressions(local_expr: &PyExpr) -> Vec<PyExpr> {
+        let mut projs: Vec<PyExpr> = Vec::new();
+        match &local_expr.expr {
+            Expr::Alias(expr, _name) => {
+                let py_expr: PyExpr = PyExpr::from(*expr.clone());
+                projs.extend_from_slice(Self::projected_expressions(&py_expr).as_slice());
+            }
+            _ => projs.push(local_expr.clone()),
+        }
+        projs
+    }
+}
+
 impl LogicalNode for PyProjection {
     fn inputs(&self) -> Vec<PyLogicalPlan> {
         vec![PyLogicalPlan::from((*self.projection.input).clone())]
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index a75315d3..07a3f65b 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;
 use crate::errors::py_unsupported_variant_err;
 use crate::expr::aggregate::PyAggregate;
 use crate::expr::analyze::PyAnalyze;
+use crate::expr::distinct::PyDistinct;
 use crate::expr::empty_relation::PyEmptyRelation;
 use crate::expr::explain::PyExplain;
 use crate::expr::extension::PyExtension;
@@ -62,6 +63,7 @@ impl PyLogicalPlan {
             LogicalPlan::EmptyRelation(plan) => PyEmptyRelation::from(plan.clone()).to_variant(py),
             LogicalPlan::Explain(plan) => PyExplain::from(plan.clone()).to_variant(py),
             LogicalPlan::Extension(plan) => PyExtension::from(plan.clone()).to_variant(py),
+            LogicalPlan::Distinct(plan) => PyDistinct::from(plan.clone()).to_variant(py),
             LogicalPlan::Filter(plan) => PyFilter::from(plan.clone()).to_variant(py),
             LogicalPlan::Limit(plan) => PyLimit::from(plan.clone()).to_variant(py),
             LogicalPlan::Projection(plan) => PyProjection::from(plan.clone()).to_variant(py),