Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{
BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
PrimitiveBuilder, StringArray, StructArray,
PrimitiveBuilder, StringArray, StructArray, TimestampMicrosecondBuilder,
};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
Expand Down Expand Up @@ -1100,6 +1100,7 @@ fn cast_array(
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
}
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
(Date32, Timestamp(_, tz)) => Ok(cast_date_to_timestamp(&array, cast_options, tz)?),
_ if cast_options.is_adapting_schema
|| is_datafusion_spark_compatible(from_type, to_type) =>
{
Expand All @@ -1118,6 +1119,50 @@ fn cast_array(
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}

fn cast_date_to_timestamp(
array_ref: &ArrayRef,
cast_options: &SparkCastOptions,
target_tz: &Option<Arc<str>>,
) -> SparkResult<ArrayRef> {
let tz_str = if cast_options.timezone.is_empty() {
"UTC"
} else {
cast_options.timezone.as_str()
};
// safe to unwrap since we are falling back to UTC above
let tz = timezone::Tz::from_str(tz_str)?;
let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
let date_array = array_ref.as_primitive::<Date32Type>();

let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len());

for date in date_array.iter() {
match date {
Some(date) => {
// safe to unwrap since chrono's range ( 262,143 yrs) is higher than
// number of years possible with days as i32 (~ 6 mil yrs)
// convert date in session timezone to timestamp in UTC
let naive_date = epoch + chrono::Duration::days(date as i64);
let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap();
let local_midnight_in_microsec = tz
.from_local_datetime(&local_midnight)
// return earliest possible time (edge case with spring / fall DST changes)
.earliest()
.map(|dt| dt.timestamp_micros())
// in case there is an issue with DST and returns None , we fall back to UTC
.unwrap_or((date as i64) * 86_400 * 1_000_000);
builder.append_value(local_midnight_in_microsec);
}
None => {
builder.append_null();
}
}
}
Ok(Arc::new(
builder.finish().with_timezone_opt(target_tz.clone()),
))
}

fn cast_string_to_float(
array: &ArrayRef,
to_type: &DataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
}
}
Compatible()
case (DataTypes.DateType, toType) => canCastFromDate(toType)
case _ => unsupported(fromType, toType)
}
}
Expand Down Expand Up @@ -344,6 +345,12 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
case _ => Unsupported(Some(s"Cast from DecimalType to $toType is not supported"))
}

/**
 * Reports whether a cast from DateType to `toType` is supported: only TimestampType is
 * reported as compatible; every other target type is reported as unsupported with a reason.
 */
private def canCastFromDate(toType: DataType): SupportLevel =
  if (toType == DataTypes.TimestampType) {
    Compatible()
  } else {
    Unsupported(Some(s"Cast from DateType to $toType is not supported"))
  }

/** Builds the `Unsupported` result for a cast pair, naming both the source and target type. */
private def unsupported(fromType: DataType, toType: DataType): Unsupported = {
  val reason = s"Cast from $fromType to $toType is not supported"
  Unsupported(Some(reason))
}
Expand Down
3 changes: 1 addition & 2 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -989,8 +989,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateDates(), DataTypes.StringType)
}

ignore("cast DateType to TimestampType") {
// Arrow error: Cast error: Casting from Date32 to Timestamp(Microsecond, Some("UTC")) not supported
// Runs the shared cast check over the generated date values with TimestampType as the target.
test("cast DateType to TimestampType") {
castTest(generateDates(), DataTypes.TimestampType)
}

Expand Down
Loading