Skip to content

Commit

Permalink
Expand Expr to include RexType basic support (#378)
Browse files Browse the repository at this point in the history
* Make expr member of PyExpr public

* Add RexType to Expr

* Add utility functions for mapping ScalarValue instances to DataTypeMap instances

* Add function to get python_value from Expr instance

* Fix syntax problems

* Add function to get the operands for a Rex::Call

* Add function to get operator for RexType::Call

* expand types function to include variant support for BinaryExpr

* Add variant coverage for Decimal128 and Decimal256

* add function for getting the column name of an Expr from a LogicalPlan

* Make PyProjection::projection member public

* Add projected_expressions to projection node

* Adjust function signature

* Add Distinct variant to to_variant function in PyLogicalPlan

* Fill in variants for DataType::Timestamp

* Address syntax issues

* Refactor types() function to extend support for CAST

* Update CAST variant handling

* Cargo fmt

* Cargo clippy

* Coverage for INTERVAL in DataType

* More cargo fmt changes
  • Loading branch information
jdye64 authored May 10, 2023
1 parent 43b3105 commit 433dbca
Show file tree
Hide file tree
Showing 4 changed files with 414 additions and 25 deletions.
119 changes: 99 additions & 20 deletions src/common/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::arrow::datatypes::DataType;
use datafusion_common::DataFusionError;
use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datafusion_common::{DataFusionError, ScalarValue};
use pyo3::prelude::*;

use crate::errors::py_datafusion_err;
Expand Down Expand Up @@ -130,9 +130,11 @@ impl DataTypeMap {
PythonType::Float,
SqlType::FLOAT,
)),
DataType::Timestamp(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
DataType::Timestamp(unit, tz) => Ok(DataTypeMap::new(
DataType::Timestamp(unit.clone(), tz.clone()),
PythonType::Datetime,
SqlType::DATE,
)),
DataType::Date32 => Ok(DataTypeMap::new(
DataType::Date32,
PythonType::Datetime,
Expand All @@ -143,18 +145,28 @@ impl DataTypeMap {
PythonType::Datetime,
SqlType::DATE,
)),
DataType::Time32(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
DataType::Time64(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
DataType::Time32(unit) => Ok(DataTypeMap::new(
DataType::Time32(unit.clone()),
PythonType::Datetime,
SqlType::DATE,
)),
DataType::Time64(unit) => Ok(DataTypeMap::new(
DataType::Time64(unit.clone()),
PythonType::Datetime,
SqlType::DATE,
)),
DataType::Duration(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
DataType::Interval(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
DataType::Interval(interval_unit) => Ok(DataTypeMap::new(
DataType::Interval(interval_unit.clone()),
PythonType::Datetime,
match interval_unit {
IntervalUnit::DayTime => SqlType::INTERVAL_DAY,
IntervalUnit::MonthDayNano => SqlType::INTERVAL_MONTH,
IntervalUnit::YearMonth => SqlType::INTERVAL_YEAR_MONTH,
},
)),
DataType::Binary => Ok(DataTypeMap::new(
DataType::Binary,
PythonType::Bytes,
Expand Down Expand Up @@ -197,12 +209,16 @@ impl DataTypeMap {
DataType::Dictionary(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
DataType::Decimal128(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
DataType::Decimal256(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new(
DataType::Decimal128(*precision, *scale),
PythonType::Float,
SqlType::DECIMAL,
)),
DataType::Decimal256(precision, scale) => Ok(DataTypeMap::new(
DataType::Decimal256(*precision, *scale),
PythonType::Float,
SqlType::DECIMAL,
)),
DataType::Map(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
Expand All @@ -211,6 +227,69 @@ impl DataTypeMap {
)),
}
}

/// Generate the `DataTypeMap` from a `ScalarValue` instance
pub fn map_from_scalar_value(scalar_val: &ScalarValue) -> Result<DataTypeMap, PyErr> {
DataTypeMap::map_from_arrow_type(&DataTypeMap::map_from_scalar_to_arrow(scalar_val)?)
}

/// Maps a `ScalarValue` to an Arrow `DataType`
pub fn map_from_scalar_to_arrow(scalar_val: &ScalarValue) -> Result<DataType, PyErr> {
match scalar_val {
ScalarValue::Boolean(_) => Ok(DataType::Boolean),
ScalarValue::Float32(_) => Ok(DataType::Float32),
ScalarValue::Float64(_) => Ok(DataType::Float64),
ScalarValue::Decimal128(_, precision, scale) => {
Ok(DataType::Decimal128(*precision, *scale))
}
ScalarValue::Dictionary(data_type, scalar_type) => {
// Call this function again to map the dictionary scalar_value to an Arrow type
Ok(DataType::Dictionary(
Box::new(*data_type.clone()),
Box::new(DataTypeMap::map_from_scalar_to_arrow(scalar_type)?),
))
}
ScalarValue::Int8(_) => Ok(DataType::Int8),
ScalarValue::Int16(_) => Ok(DataType::Int16),
ScalarValue::Int32(_) => Ok(DataType::Int32),
ScalarValue::Int64(_) => Ok(DataType::Int64),
ScalarValue::UInt8(_) => Ok(DataType::UInt8),
ScalarValue::UInt16(_) => Ok(DataType::UInt16),
ScalarValue::UInt32(_) => Ok(DataType::UInt32),
ScalarValue::UInt64(_) => Ok(DataType::UInt64),
ScalarValue::Utf8(_) => Ok(DataType::Utf8),
ScalarValue::LargeUtf8(_) => Ok(DataType::LargeUtf8),
ScalarValue::Binary(_) => Ok(DataType::Binary),
ScalarValue::LargeBinary(_) => Ok(DataType::LargeBinary),
ScalarValue::Date32(_) => Ok(DataType::Date32),
ScalarValue::Date64(_) => Ok(DataType::Date64),
ScalarValue::Time32Second(_) => Ok(DataType::Time32(TimeUnit::Second)),
ScalarValue::Time32Millisecond(_) => Ok(DataType::Time32(TimeUnit::Millisecond)),
ScalarValue::Time64Microsecond(_) => Ok(DataType::Time64(TimeUnit::Microsecond)),
ScalarValue::Time64Nanosecond(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)),
ScalarValue::Null => Ok(DataType::Null),
ScalarValue::TimestampSecond(_, tz) => {
Ok(DataType::Timestamp(TimeUnit::Second, tz.to_owned()))
}
ScalarValue::TimestampMillisecond(_, tz) => {
Ok(DataType::Timestamp(TimeUnit::Millisecond, tz.to_owned()))
}
ScalarValue::TimestampMicrosecond(_, tz) => {
Ok(DataType::Timestamp(TimeUnit::Microsecond, tz.to_owned()))
}
ScalarValue::TimestampNanosecond(_, tz) => {
Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()))
}
ScalarValue::IntervalYearMonth(..) => Ok(DataType::Interval(IntervalUnit::YearMonth)),
ScalarValue::IntervalDayTime(..) => Ok(DataType::Interval(IntervalUnit::DayTime)),
ScalarValue::IntervalMonthDayNano(..) => {
Ok(DataType::Interval(IntervalUnit::MonthDayNano))
}
ScalarValue::List(_val, field_ref) => Ok(DataType::List(field_ref.to_owned())),
ScalarValue::Struct(_, fields) => Ok(DataType::Struct(fields.to_owned())),
ScalarValue::FixedSizeBinary(size, _) => Ok(DataType::FixedSizeBinary(*size)),
}
}
}

#[pymethods]
Expand Down
Loading

0 comments on commit 433dbca

Please sign in to comment.