Skip to content

Commit 1c58fa9

Browse files
ericlrxin
authored andcommitted
[SPARK-16514][SQL] Fix various regex codegen bugs
## What changes were proposed in this pull request? RegexExtract and RegexReplace currently crash on non-nullable input due use of a hard-coded local variable name (e.g. compiles fail with `java.lang.Exception: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 85, Column 26: Redefinition of local variable "m" `). This changes those variables to use fresh names, and also in a few other places. ## How was this patch tested? Unit tests. rxin Author: Eric Liang <ekl@databricks.com> Closes apache#14168 from ericl/sc-3906.
1 parent 56bd399 commit 1c58fa9

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,11 @@ case class Like(left: Expression, right: Expression)
108108
""")
109109
}
110110
} else {
111+
val rightStr = ctx.freshName("rightStr")
111112
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
112113
s"""
113-
String rightStr = ${eval2}.toString();
114-
${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr));
114+
String $rightStr = ${eval2}.toString();
115+
${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr));
115116
${ev.value} = $pattern.matcher(${eval1}.toString()).matches();
116117
"""
117118
})
@@ -157,10 +158,11 @@ case class RLike(left: Expression, right: Expression)
157158
""")
158159
}
159160
} else {
161+
val rightStr = ctx.freshName("rightStr")
160162
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
161163
s"""
162-
String rightStr = ${eval2}.toString();
163-
${patternClass} $pattern = ${patternClass}.compile(rightStr);
164+
String $rightStr = ${eval2}.toString();
165+
${patternClass} $pattern = ${patternClass}.compile($rightStr);
164166
${ev.value} = $pattern.matcher(${eval1}.toString()).find(0);
165167
"""
166168
})
@@ -259,6 +261,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
259261
val classNamePattern = classOf[Pattern].getCanonicalName
260262
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
261263

264+
val matcher = ctx.freshName("matcher")
265+
262266
ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
263267
ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
264268
ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;")
@@ -267,6 +271,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
267271
ctx.addMutableState(classNameStringBuffer,
268272
termResult, s"${termResult} = new $classNameStringBuffer();")
269273

274+
val setEvNotNull = if (nullable) {
275+
s"${ev.isNull} = false;"
276+
} else {
277+
""
278+
}
279+
270280
nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => {
271281
s"""
272282
if (!$regexp.equals(${termLastRegex})) {
@@ -280,14 +290,14 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
280290
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
281291
}
282292
${termResult}.delete(0, ${termResult}.length());
283-
java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString());
293+
java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString());
284294

285-
while (m.find()) {
286-
m.appendReplacement(${termResult}, ${termLastReplacement});
295+
while (${matcher}.find()) {
296+
${matcher}.appendReplacement(${termResult}, ${termLastReplacement});
287297
}
288-
m.appendTail(${termResult});
298+
${matcher}.appendTail(${termResult});
289299
${ev.value} = UTF8String.fromString(${termResult}.toString());
290-
${ev.isNull} = false;
300+
$setEvNotNull
291301
"""
292302
})
293303
}
@@ -334,26 +344,34 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
334344
val termLastRegex = ctx.freshName("lastRegex")
335345
val termPattern = ctx.freshName("pattern")
336346
val classNamePattern = classOf[Pattern].getCanonicalName
347+
val matcher = ctx.freshName("matcher")
348+
val matchResult = ctx.freshName("matchResult")
337349

338350
ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
339351
ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
340352

353+
val setEvNotNull = if (nullable) {
354+
s"${ev.isNull} = false;"
355+
} else {
356+
""
357+
}
358+
341359
nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
342360
s"""
343361
if (!$regexp.equals(${termLastRegex})) {
344362
// regex value changed
345363
${termLastRegex} = $regexp.clone();
346364
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
347365
}
348-
java.util.regex.Matcher m =
366+
java.util.regex.Matcher ${matcher} =
349367
${termPattern}.matcher($subject.toString());
350-
if (m.find()) {
351-
java.util.regex.MatchResult mr = m.toMatchResult();
352-
${ev.value} = UTF8String.fromString(mr.group($idx));
353-
${ev.isNull} = false;
368+
if (${matcher}.find()) {
369+
java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult();
370+
${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
371+
$setEvNotNull
354372
} else {
355373
${ev.value} = UTF8String.EMPTY_UTF8;
356-
${ev.isNull} = false;
374+
$setEvNotNull
357375
}"""
358376
})
359377
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
631631
checkEvaluation(expr, null, row4)
632632
checkEvaluation(expr, null, row5)
633633
checkEvaluation(expr, null, row6)
634+
635+
val nonNullExpr = RegExpReplace(Literal("100-200"), Literal("(\\d+)"), Literal("num"))
636+
checkEvaluation(nonNullExpr, "num-num", row1)
634637
}
635638

636639
test("RegexExtract") {
@@ -657,6 +660,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
657660

658661
val expr1 = new RegExpExtract(s, p)
659662
checkEvaluation(expr1, "100", row1)
663+
664+
val nonNullExpr = RegExpExtract(Literal("100-200"), Literal("(\\d+)-(\\d+)"), Literal(1))
665+
checkEvaluation(nonNullExpr, "100", row1)
660666
}
661667

662668
test("SPLIT") {

0 commit comments

Comments
 (0)