@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
24
24
import org .apache .spark .sql .catalyst .trees .TreePattern .{BINARY_COMPARISON , IN }
25
25
import org .apache .spark .sql .catalyst .util .CharVarcharUtils
26
26
import org .apache .spark .sql .execution .datasources .v2 .DataSourceV2Relation
27
+ import org .apache .spark .sql .internal .SQLConf
27
28
import org .apache .spark .sql .types .{CharType , Metadata , StringType }
28
29
import org .apache .spark .unsafe .types .UTF8String
29
30
@@ -66,9 +67,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
66
67
r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols)
67
68
})
68
69
}
69
- paddingForStringComparison(newPlan)
70
+ paddingForStringComparison(newPlan, padCharCol = false )
70
71
} else {
71
- paddingForStringComparison(plan)
72
+ paddingForStringComparison(
73
+ plan, padCharCol = ! conf.getConf(SQLConf .LEGACY_NO_CHAR_PADDING_IN_PREDICATE ))
72
74
}
73
75
}
74
76
@@ -90,7 +92,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
90
92
}
91
93
}
92
94
93
- private def paddingForStringComparison (plan : LogicalPlan ): LogicalPlan = {
95
+ private def paddingForStringComparison (plan : LogicalPlan , padCharCol : Boolean ): LogicalPlan = {
94
96
plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON , IN )) {
95
97
case operator => operator.transformExpressionsUpWithPruning(
96
98
_.containsAnyPattern(BINARY_COMPARISON , IN )) {
@@ -99,12 +101,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
99
101
// String literal is treated as char type when it's compared to a char type column.
100
102
// We should pad the shorter one to the longer length.
101
103
case b @ BinaryComparison (e @ AttrOrOuterRef (attr), lit) if lit.foldable =>
102
- padAttrLitCmp(e, attr.metadata, lit).map { newChildren =>
104
+ padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren =>
103
105
b.withNewChildren(newChildren)
104
106
}.getOrElse(b)
105
107
106
108
case b @ BinaryComparison (lit, e @ AttrOrOuterRef (attr)) if lit.foldable =>
107
- padAttrLitCmp(e, attr.metadata, lit).map { newChildren =>
109
+ padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren =>
108
110
b.withNewChildren(newChildren.reverse)
109
111
}.getOrElse(b)
110
112
@@ -117,9 +119,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
117
119
val literalCharLengths = literalChars.map(_.numChars())
118
120
val targetLen = (length +: literalCharLengths).max
119
121
Some (i.copy(
120
- value = addPadding(e, length, targetLen),
122
+ value = addPadding(e, length, targetLen, alwaysPad = padCharCol ),
121
123
list = list.zip(literalCharLengths).map {
122
- case (lit, charLength) => addPadding(lit, charLength, targetLen)
124
+ case (lit, charLength) =>
125
+ addPadding(lit, charLength, targetLen, alwaysPad = false )
123
126
} ++ nulls.map(Literal .create(_, StringType ))))
124
127
case _ => None
125
128
}.getOrElse(i)
@@ -162,6 +165,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
162
165
private def padAttrLitCmp (
163
166
expr : Expression ,
164
167
metadata : Metadata ,
168
+ padCharCol : Boolean ,
165
169
lit : Expression ): Option [Seq [Expression ]] = {
166
170
if (expr.dataType == StringType ) {
167
171
CharVarcharUtils .getRawType(metadata).flatMap {
@@ -174,7 +178,14 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
174
178
if (length < stringLitLen) {
175
179
Some (Seq (StringRPad (expr, Literal (stringLitLen)), lit))
176
180
} else if (length > stringLitLen) {
177
- Some (Seq (expr, StringRPad (lit, Literal (length))))
181
+ val paddedExpr = if (padCharCol) {
182
+ StringRPad (expr, Literal (length))
183
+ } else {
184
+ expr
185
+ }
186
+ Some (Seq (paddedExpr, StringRPad (lit, Literal (length))))
187
+ } else if (padCharCol) {
188
+ Some (Seq (StringRPad (expr, Literal (length)), lit))
178
189
} else {
179
190
None
180
191
}
@@ -186,7 +197,15 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
186
197
}
187
198
}
188
199
189
- private def addPadding (expr : Expression , charLength : Int , targetLength : Int ): Expression = {
190
- if (targetLength > charLength) StringRPad (expr, Literal (targetLength)) else expr
200
+ private def addPadding (
201
+ expr : Expression ,
202
+ charLength : Int ,
203
+ targetLength : Int ,
204
+ alwaysPad : Boolean ): Expression = {
205
+ if (targetLength > charLength) {
206
+ StringRPad (expr, Literal (targetLength))
207
+ } else if (alwaysPad) {
208
+ StringRPad (expr, Literal (charLength))
209
+ } else expr
191
210
}
192
211
}
0 commit comments