Skip to content

Commit ad46821

Browse files
committed
Merge remote-tracking branch 'apache/main' into comet-parquet-exec
2 parents 16033d9 + 712658e commit ad46821

File tree

4 files changed

+298
-166
lines changed

4 files changed

+298
-166
lines changed

native/spark-expr/src/cast.rs

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use arrow::{
3434
};
3535
use arrow_array::builder::StringBuilder;
3636
use arrow_array::{DictionaryArray, StringArray, StructArray};
37-
use arrow_schema::{DataType, Schema};
37+
use arrow_schema::{DataType, Field, Schema};
3838
use datafusion_common::{
3939
cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
4040
};
@@ -714,6 +714,14 @@ fn cast_array(
714714
(DataType::Struct(_), DataType::Utf8) => {
715715
Ok(casts_struct_to_string(array.as_struct(), &timezone)?)
716716
}
717+
(DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct(
718+
array.as_struct(),
719+
from_type,
720+
to_type,
721+
eval_mode,
722+
timezone,
723+
allow_incompat,
724+
)?),
717725
_ if is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => {
718726
// use DataFusion cast only when we know that it is compatible with Spark
719727
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
@@ -811,6 +819,35 @@ fn is_datafusion_spark_compatible(
811819
}
812820
}
813821

822+
/// Cast between struct types based on logic in
823+
/// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`.
824+
fn cast_struct_to_struct(
825+
array: &StructArray,
826+
from_type: &DataType,
827+
to_type: &DataType,
828+
eval_mode: EvalMode,
829+
timezone: String,
830+
allow_incompat: bool,
831+
) -> DataFusionResult<ArrayRef> {
832+
match (from_type, to_type) {
833+
(DataType::Struct(_), DataType::Struct(to_fields)) => {
834+
let mut cast_fields: Vec<(Arc<Field>, ArrayRef)> = Vec::with_capacity(to_fields.len());
835+
for i in 0..to_fields.len() {
836+
let cast_field = cast_array(
837+
Arc::clone(array.column(i)),
838+
to_fields[i].data_type(),
839+
eval_mode,
840+
timezone.clone(),
841+
allow_incompat,
842+
)?;
843+
cast_fields.push((Arc::clone(&to_fields[i]), cast_field));
844+
}
845+
Ok(Arc::new(StructArray::from(cast_fields)))
846+
}
847+
_ => unreachable!(),
848+
}
849+
}
850+
814851
fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResult<ArrayRef> {
815852
// cast each field to a string
816853
let string_arrays: Vec<ArrayRef> = array
@@ -1929,7 +1966,7 @@ fn trim_end(s: &str) -> &str {
19291966
mod tests {
19301967
use arrow::datatypes::TimestampMicrosecondType;
19311968
use arrow_array::StringArray;
1932-
use arrow_schema::{Field, TimeUnit};
1969+
use arrow_schema::{Field, Fields, TimeUnit};
19331970
use std::str::FromStr;
19341971

19351972
use super::*;
@@ -2336,4 +2373,75 @@ mod tests {
23362373
assert_eq!(r#"{4, d}"#, string_array.value(3));
23372374
assert_eq!(r#"{5, e}"#, string_array.value(4));
23382375
}
2376+
2377+
#[test]
fn test_cast_struct_to_struct() {
    // Source struct column { a: Int32, b: Utf8 } with one null entry in "a".
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(1),
        Some(2),
        None,
        Some(4),
        Some(5),
    ]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
    let input: ArrayRef = Arc::new(StructArray::from(vec![
        (Arc::new(Field::new("a", DataType::Int32, true)), ints),
        (Arc::new(Field::new("b", DataType::Utf8, true)), strings),
    ]));
    // Target type changes "a" from Int32 to Utf8; "b" is unchanged.
    let target_fields = Fields::from(vec![
        Field::new("a", DataType::Utf8, true),
        Field::new("b", DataType::Utf8, true),
    ]);
    let result = spark_cast(
        ColumnarValue::Array(input),
        &DataType::Struct(target_fields),
        EvalMode::Legacy,
        "UTC",
        false,
    )
    .unwrap();
    match result {
        ColumnarValue::Array(cast_array) => {
            assert_eq!(5, cast_array.len());
            let a = cast_array.as_struct().column(0).as_string::<i32>();
            assert_eq!("1", a.value(0));
        }
        _ => unreachable!(),
    }
}
2412+
2413+
#[test]
fn test_cast_struct_to_struct_drop_column() {
    // Source struct column { a: Int32, b: Utf8 } with one null entry in "a".
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![
        Some(1),
        Some(2),
        None,
        Some(4),
        Some(5),
    ]));
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
    let input: ArrayRef = Arc::new(StructArray::from(vec![
        (Arc::new(Field::new("a", DataType::Int32, true)), ints),
        (Arc::new(Field::new("b", DataType::Utf8, true)), strings),
    ]));
    // Target type changes "a" from Int32 to Utf8 and drops trailing field "b".
    let target_fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
    let result = spark_cast(
        ColumnarValue::Array(input),
        &DataType::Struct(target_fields),
        EvalMode::Legacy,
        "UTC",
        false,
    )
    .unwrap();
    match result {
        ColumnarValue::Array(cast_array) => {
            assert_eq!(5, cast_array.len());
            let result_struct = cast_array.as_struct();
            assert_eq!(1, result_struct.columns().len());
            let a = result_struct.column(0).as_string::<i32>();
            assert_eq!("1", a.value(0));
        }
        _ => unreachable!(),
    }
}
23392447
}

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@ object CometCast {
9595
canCastFromFloat(toType)
9696
case (DataTypes.DoubleType, _) =>
9797
canCastFromDouble(toType)
98+
case (from_struct: StructType, to_struct: StructType) =>
99+
from_struct.fields.zip(to_struct.fields).foreach { case (a, b) =>
100+
isSupported(a.dataType, b.dataType, timeZoneId, evalMode) match {
101+
case Compatible(_) =>
102+
// all good
103+
case other =>
104+
return other
105+
}
106+
}
107+
Compatible()
98108
case _ => Unsupported
99109
}
100110
}

0 commit comments

Comments (0)