Skip to content

Commit 3f6e2d6

Browse files
committed
[SPARK-48498][SQL][FOLLOWUP] do padding for char-char comparison
### What changes were proposed in this pull request? This is a followup of #46832 to handle a missing case: char-char comparison. We should pad both sides if `READ_SIDE_CHAR_PADDING` is not enabled. ### Why are the changes needed? bug fix if people disable read side char padding ### Does this PR introduce _any_ user-facing change? No because it's a followup and the original PR is not released yet ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #47412 from cloud-fan/char. Authored-by: Wenchen Fan <wenchen@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 1645046 commit 3f6e2d6

File tree

3 files changed

+35
-12
lines changed

3 files changed

+35
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,14 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils {
240240
* attributes. When comparing two char type columns/fields, we need to pad the shorter one to
241241
* the longer length.
242242
*/
243-
def addPaddingInStringComparison(attrs: Seq[Attribute]): Seq[Expression] = {
243+
def addPaddingInStringComparison(attrs: Seq[Attribute], alwaysPad: Boolean): Seq[Expression] = {
244244
val rawTypes = attrs.map(attr => getRawType(attr.metadata))
245245
if (rawTypes.exists(_.isEmpty)) {
246246
attrs
247247
} else {
248248
val typeWithTargetCharLength = rawTypes.map(_.get).reduce(typeWithWiderCharLength)
249249
attrs.zip(rawTypes.map(_.get)).map { case (attr, rawType) =>
250-
padCharToTargetLength(attr, rawType, typeWithTargetCharLength).getOrElse(attr)
250+
padCharToTargetLength(attr, rawType, typeWithTargetCharLength, alwaysPad).getOrElse(attr)
251251
}
252252
}
253253
}
@@ -270,9 +270,10 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils {
270270
private def padCharToTargetLength(
271271
expr: Expression,
272272
rawType: DataType,
273-
typeWithTargetCharLength: DataType): Option[Expression] = {
273+
typeWithTargetCharLength: DataType,
274+
alwaysPad: Boolean): Option[Expression] = {
274275
(rawType, typeWithTargetCharLength) match {
275-
case (CharType(len), CharType(target)) if target > len =>
276+
case (CharType(len), CharType(target)) if alwaysPad || target > len =>
276277
Some(StringRPad(expr, Literal(target)))
277278

278279
case (StructType(fields), StructType(targets)) =>
@@ -283,7 +284,8 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils {
283284
while (i < fields.length) {
284285
val field = fields(i)
285286
val fieldExpr = GetStructField(expr, i, Some(field.name))
286-
val padded = padCharToTargetLength(fieldExpr, field.dataType, targets(i).dataType)
287+
val padded = padCharToTargetLength(
288+
fieldExpr, field.dataType, targets(i).dataType, alwaysPad)
287289
needPadding = padded.isDefined
288290
createStructExprs += Literal(field.name)
289291
createStructExprs += padded.getOrElse(fieldExpr)
@@ -293,7 +295,7 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils {
293295

294296
case (ArrayType(et, containsNull), ArrayType(target, _)) =>
295297
val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull)
296-
padCharToTargetLength(param, et, target).map { padded =>
298+
padCharToTargetLength(param, et, target, alwaysPad).map { padded =>
297299
val func = LambdaFunction(padded, Seq(param))
298300
ArrayTransform(expr, func)
299301
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
137137
case (_, _: OuterReference) => Seq(right)
138138
case _ => Nil
139139
}
140-
val newChildren = CharVarcharUtils.addPaddingInStringComparison(Seq(left, right))
140+
val newChildren = CharVarcharUtils.addPaddingInStringComparison(
141+
Seq(left, right), padCharCol)
141142
if (outerRefs.nonEmpty) {
142143
b.withNewChildren(newChildren.map(_.transform {
143144
case a: Attribute if outerRefs.exists(_.semanticEquals(a)) => OuterReference(a)
@@ -148,7 +149,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
148149

149150
case i @ In(e @ AttrOrOuterRef(attr), list) if list.forall(_.isInstanceOf[Attribute]) =>
150151
val newChildren = CharVarcharUtils.addPaddingInStringComparison(
151-
attr +: list.map(_.asInstanceOf[Attribute]))
152+
attr +: list.map(_.asInstanceOf[Attribute]), padCharCol)
152153
if (e.isInstanceOf[OuterReference]) {
153154
i.copy(
154155
value = newChildren.head.transform {

sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -960,25 +960,45 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
960960
import testImplicits._
961961
withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false") {
962962
withTempPath { dir =>
963-
withTable("t") {
963+
withTable("t1", "t2") {
964964
Seq(
965965
"12" -> "12",
966966
"12" -> "12 ",
967967
"12 " -> "12",
968968
"12 " -> "12 "
969969
).toDF("c1", "c2").write.format(format).save(dir.toString)
970-
sql(s"CREATE TABLE t (c1 CHAR(3), c2 STRING) USING $format LOCATION '$dir'")
970+
971+
sql(s"CREATE TABLE t1 (c1 CHAR(3), c2 STRING) USING $format LOCATION '$dir'")
971972
// Comparing CHAR column with STRING column directly compares the stored value.
972973
checkAnswer(
973-
sql("SELECT c1 = c2 FROM t"),
974+
sql("SELECT c1 = c2 FROM t1"),
975+
Seq(Row(true), Row(false), Row(false), Row(true))
976+
)
977+
checkAnswer(
978+
sql("SELECT c1 IN (c2) FROM t1"),
974979
Seq(Row(true), Row(false), Row(false), Row(true))
975980
)
976981
// No matter the CHAR type value is padded or not in the storage, we should always pad it
977982
// before comparison with STRING literals.
978983
checkAnswer(
979-
sql("SELECT c1 = '12', c1 = '12 ', c1 = '12 ' FROM t WHERE c2 = '12'"),
984+
sql("SELECT c1 = '12', c1 = '12 ', c1 = '12 ' FROM t1 WHERE c2 = '12'"),
980985
Seq(Row(true, true, true), Row(true, true, true))
981986
)
987+
checkAnswer(
988+
sql("SELECT c1 IN ('12'), c1 IN ('12 '), c1 IN ('12 ') FROM t1 WHERE c2 = '12'"),
989+
Seq(Row(true, true, true), Row(true, true, true))
990+
)
991+
992+
sql(s"CREATE TABLE t2 (c1 CHAR(3), c2 CHAR(5)) USING $format LOCATION '$dir'")
993+
// Comparing CHAR column with CHAR column compares the padded values.
994+
checkAnswer(
995+
sql("SELECT c1 = c2, c2 = c1 FROM t2"),
996+
Seq(Row(true, true), Row(true, true), Row(true, true), Row(true, true))
997+
)
998+
checkAnswer(
999+
sql("SELECT c1 IN (c2), c2 IN (c1) FROM t2"),
1000+
Seq(Row(true, true), Row(true, true), Row(true, true), Row(true, true))
1001+
)
9821002
}
9831003
}
9841004
}

0 commit comments

Comments
 (0)