
Commit 457d9d1

fix: Optimize read_side_padding (apache#772)
## Which issue does this PR close?

## Rationale for this change

This PR improves `read_side_padding`, which is used for the CHAR() schema.

## What changes are included in this PR?

Optimized `spark_read_side_padding`.

## How are these changes tested?

Added tests.
1 parent c4bd3db commit 457d9d1
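
Before the per-file diffs, a minimal sketch (plain Rust, not the Comet implementation below, which operates on Arrow arrays) of the read-side padding semantics being optimized: pad a CHAR(n) value with trailing spaces up to n characters but, unlike `rpad`, never truncate a longer value.

```rust
// Sketch of read-side padding semantics. Spark's UTF8String length is
// char-based, i.e. it counts Unicode scalar values, not grapheme clusters.
fn read_side_pad(s: &str, length: usize) -> String {
    let char_len = s.chars().count();
    if char_len >= length {
        return s.to_string(); // already at least `length` chars: never truncate
    }
    let mut out = String::with_capacity(s.len() + (length - char_len));
    out.push_str(s);
    for _ in 0..(length - char_len) {
        out.push(' '); // pad with trailing spaces up to `length` chars
    }
    out
}

fn main() {
    assert_eq!(read_side_pad("ab", 4), "ab  "); // padded to 4 chars
    assert_eq!(read_side_pad("abcdef", 4), "abcdef"); // longer value unchanged
}
```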

9 files changed: 71 additions & 44 deletions

native/Cargo.lock

Lines changed: 0 additions & 1 deletion (generated file; diff not rendered)

native/core/src/execution/datafusion/expressions/comet_scalar_funcs.rs

Lines changed: 5 additions & 5 deletions
```diff
@@ -21,8 +21,8 @@ use datafusion_comet_spark_expr::scalar_funcs::hash_expressions::{
 };
 use datafusion_comet_spark_expr::scalar_funcs::{
     spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_murmur3_hash, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, spark_xxhash64,
-    SparkChrFunc,
+    spark_murmur3_hash, spark_read_side_padding, spark_round, spark_unhex, spark_unscaled_value,
+    spark_xxhash64, SparkChrFunc,
 };
 use datafusion_common::{DataFusionError, Result as DataFusionResult};
 use datafusion_expr::registry::FunctionRegistry;
@@ -67,9 +67,9 @@ pub fn create_comet_physical_fun(
         "floor" => {
             make_comet_scalar_udf!("floor", spark_floor, data_type)
         }
-        "rpad" => {
-            let func = Arc::new(spark_rpad);
-            make_comet_scalar_udf!("rpad", func, without data_type)
+        "read_side_padding" => {
+            let func = Arc::new(spark_read_side_padding);
+            make_comet_scalar_udf!("read_side_padding", func, without data_type)
         }
         "round" => {
             make_comet_scalar_udf!("round", spark_round, data_type)
```

native/core/src/execution/datafusion/planner.rs

Lines changed: 10 additions & 5 deletions
```diff
@@ -1724,11 +1724,16 @@ impl PhysicalPlanner {

         let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
             Some(t) => t,
-            None => self
-                .session_ctx
-                .udf(fun_name)?
-                .inner()
-                .return_type(&input_expr_types)?,
+            None => {
+                let fun_name = match fun_name.as_str() {
+                    "read_side_padding" => "rpad", // use the same return type as rpad
+                    other => other,
+                };
+                self.session_ctx
+                    .udf(fun_name)?
+                    .inner()
+                    .return_type(&input_expr_types)?
+            }
         };

         let fun_expr =
```

native/spark-expr/Cargo.toml

Lines changed: 0 additions & 1 deletion
```diff
@@ -41,7 +41,6 @@ chrono-tz = { workspace = true }
 num = { workspace = true }
 regex = { workspace = true }
 thiserror = { workspace = true }
-unicode-segmentation = "1.11.0"

 [dev-dependencies]
 arrow-data = {workspace = true}
```

native/spark-expr/src/scalar_funcs.rs

Lines changed: 32 additions & 30 deletions
```diff
@@ -15,15 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.

-use std::{cmp::min, sync::Arc};
-
 use arrow::{
     array::{
-        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray,
-        Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
+        ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
+        Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
+use arrow_array::builder::GenericStringBuilder;
 use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
 use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
@@ -35,7 +34,8 @@ use num::{
     integer::{div_ceil, div_floor},
     BigInt, Signed, ToPrimitive,
 };
-use unicode_segmentation::UnicodeSegmentation;
+use std::fmt::Write;
+use std::{cmp::min, sync::Arc};

 mod unhex;
 pub use unhex::spark_unhex;
@@ -387,52 +387,54 @@ pub fn spark_round(
 }

 /// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
-pub fn spark_rpad(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
     match args {
         [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            match args[0].data_type() {
-                DataType::Utf8 => spark_rpad_internal::<i32>(array, *length),
-                DataType::LargeUtf8 => spark_rpad_internal::<i64>(array, *length),
+            match array.data_type() {
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
+                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
                 // TODO: handle Dictionary types
                 other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function rpad",
+                    "Unsupported data type {other:?} for function read_side_padding",
                 ))),
             }
         }
         other => Err(DataFusionError::Internal(format!(
-            "Unsupported arguments {other:?} for function rpad",
+            "Unsupported arguments {other:?} for function read_side_padding",
         ))),
     }
 }

-fn spark_rpad_internal<T: OffsetSizeTrait>(
+fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
     array: &ArrayRef,
     length: i32,
 ) -> Result<ColumnarValue, DataFusionError> {
     let string_array = as_generic_string_array::<T>(array)?;
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
+
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);

-    let result = string_array
-        .iter()
-        .map(|string| match string {
+    for string in string_array.iter() {
+        match string {
             Some(string) => {
-                let length = if length < 0 { 0 } else { length as usize };
-                if length == 0 {
-                    Ok(Some("".to_string()))
+                // It looks Spark's UTF8String is closer to chars rather than graphemes
+                // https://stackoverflow.com/a/46290728
+                let char_len = string.chars().count();
+                if length <= char_len {
+                    builder.append_value(string);
                 } else {
-                    let graphemes = string.graphemes(true).collect::<Vec<&str>>();
-                    if length < graphemes.len() {
-                        Ok(Some(string.to_string()))
-                    } else {
-                        let mut s = string.to_string();
-                        s.push_str(" ".repeat(length - graphemes.len()).as_str());
-                        Ok(Some(s))
-                    }
+                    // write_str updates only the value buffer, not null nor offset buffer
+                    // This is convenient for concatenating str(s)
+                    builder.write_str(string)?;
+                    builder.append_value(&space_string[char_len..]);
                 }
             }
-            _ => Ok(None),
-        })
-        .collect::<Result<GenericStringArray<T>, DataFusionError>>()?;
-    Ok(ColumnarValue::Array(Arc::new(result)))
+            _ => builder.append_null(),
+        }
+    }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
 }

 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
```
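
The optimization hinges on one `GenericStringBuilder` plus a shared buffer of spaces instead of a per-row `String`. A small self-contained illustration of the builder behavior the diff relies on: `write_str` (via the `std::fmt::Write` impl) appends bytes to the in-progress value without closing the row, while `append_value` writes bytes and then finalizes the row's offset and validity. Slicing `&space_string[char_len..]` by char count is sound only because the padding buffer is pure ASCII, where byte and char offsets coincide.

```rust
use std::fmt::Write;

use arrow_array::builder::GenericStringBuilder;
use arrow_array::Array;

fn main() {
    let length = 4usize;
    let space_string = " ".repeat(length); // all ASCII: byte index == char index

    let mut builder = GenericStringBuilder::<i32>::new();

    // Pad "ab" to 4 chars: write the value bytes into the open row, then
    // close the row by appending the padding suffix from the shared buffer.
    let value = "ab";
    let char_len = value.chars().count();
    builder.write_str(value).unwrap(); // value buffer only; row still open
    builder.append_value(&space_string[char_len..]); // row becomes "ab  "

    builder.append_null();

    let arr = builder.finish();
    assert_eq!(arr.value(0), "ab  ");
    assert!(arr.is_null(1));
}
```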

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -2178,7 +2178,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       }

       // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
-      // char types. Use rpad to achieve the behavior.
+      // char types.
       // See https://github.com/apache/spark/pull/38151
       case s: StaticInvoke
           if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
@@ -2194,7 +2194,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim

           if (argsExpr.forall(_.isDefined)) {
             val builder = ExprOuterClass.ScalarFunc.newBuilder()
-            builder.setFunc("rpad")
+            builder.setFunc("read_side_padding")
             argsExpr.foreach(arg => builder.addArgs(arg.get))

             Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
```
char_type.sql (new TPC-DS micro-benchmark query; see the benchmark list below)

Lines changed: 7 additions & 0 deletions

```diff
@@ -0,0 +1,7 @@
+SELECT
+  cd_gender
+FROM customer_demographics
+WHERE
+  cd_gender = 'M' AND
+  cd_marital_status = 'S' AND
+  cd_education_status = 'College'
```

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 14 additions & 0 deletions
```diff
@@ -1911,6 +1911,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

+  test("readSidePadding") {
+    // https://stackoverflow.com/a/46290728
+    val table = "test"
+    withTable(table) {
+      sql(s"create table $table(col1 CHAR(2)) using parquet")
+      sql(s"insert into $table values('é')") // unicode 'e\\u{301}'
+      sql(s"insert into $table values('é')") // unicode '\\u{e9}'
+      sql(s"insert into $table values('')")
+      sql(s"insert into $table values('ab')")
+
+      checkSparkAnswerAndOperator(s"SELECT * FROM $table")
+    }
+  }
+
   test("isnan") {
     Seq("true", "false").foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary) {
```
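
The two visually identical 'é' rows in the test are not redundant: the first is the decomposed form `e\u{301}` (two chars), the second the precomposed `\u{e9}` (one char), so char-based CHAR(2) padding treats them differently even though each is a single grapheme. A quick check of the char counts in plain Rust:

```rust
fn main() {
    let decomposed = "e\u{301}"; // 'e' + combining acute accent
    let precomposed = "\u{e9}"; // single precomposed 'é'

    // Spark's UTF8String length is char-based, so CHAR(2) read-side padding
    // leaves the decomposed form alone but pads the precomposed one.
    assert_eq!(decomposed.chars().count(), 2); // already 2 chars: no padding
    assert_eq!(precomposed.chars().count(), 1); // 1 char: one space appended

    assert_eq!("".chars().count(), 0); // '' pads to two spaces
    assert_eq!("ab".chars().count(), 2); // exactly CHAR(2): unchanged
}
```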

spark/src/test/scala/org/apache/spark/sql/benchmark/CometTPCDSMicroBenchmark.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -63,6 +63,7 @@ object CometTPCDSMicroBenchmark extends CometTPCQueryBenchmarkBase {
     "agg_sum_integers_no_grouping",
     "case_when_column_or_null",
     "case_when_scalar",
+    "char_type",
     "filter_highly_selective",
     "filter_less_selective",
     "if_column_or_null",
```
