Skip to content

Commit 1a5e762

Browse files
committed
[SPARK-16409][SQL] regexp_extract with optional groups causes NPE
## What changes were proposed in this pull request?

`regexp_extract` returns null when it shouldn't: the regex matches, but the requested optional group did not participate in the match. This change makes it return an empty string in that case, as apparently designed.

## How was this patch tested?

Additional unit test.

Author: Sean Owen <sowen@cloudera.com>

Closes #14504 from srowen/SPARK-16409.

(cherry picked from commit 8d87252)
Signed-off-by: Sean Owen <sowen@cloudera.com>
1 parent c162886 commit 1a5e762

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

python/pyspark/sql/functions.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1299,6 +1299,9 @@ def regexp_extract(str, pattern, idx):
12991299
>>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
13001300
>>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
13011301
[Row(d=u'100')]
1302+
>>> df = spark.createDataFrame([('aaaac',)], ['str'])
1303+
>>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect()
1304+
[Row(d=u'')]
13021305
"""
13031306
sc = SparkContext._active_spark_context
13041307
jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -315,7 +315,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
315315
val m = pattern.matcher(s.toString)
316316
if (m.find) {
317317
val mr: MatchResult = m.toMatchResult
318-
UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
318+
val group = mr.group(r.asInstanceOf[Int])
319+
if (group == null) { // Pattern matched, but not optional group
320+
UTF8String.EMPTY_UTF8
321+
} else {
322+
UTF8String.fromString(group)
323+
}
319324
} else {
320325
UTF8String.EMPTY_UTF8
321326
}
@@ -353,7 +358,11 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
353358
${termPattern}.matcher($subject.toString());
354359
if (${matcher}.find()) {
355360
java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult();
356-
${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
361+
if (${matchResult}.group($idx) == null) {
362+
${ev.value} = UTF8String.EMPTY_UTF8;
363+
} else {
364+
${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
365+
}
357366
$setEvNotNull
358367
} else {
359368
${ev.value} = UTF8String.EMPTY_UTF8;

sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,14 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
7878
Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil)
7979
}
8080

81+
// Regression test for SPARK-16409: when the pattern matches but the requested
// optional group is absent from the match, regexp_extract must yield "".
test("non-matching optional group") {
82+
val df = Seq("aaaac").toDF("s")
83+
checkAnswer(
84+
// Group 2 is the optional `(b)?`; "aaaac" contains no 'b', so that group is unmatched.
df.select(regexp_extract($"s", "(a+)(b)?(c)", 2)),
85+
Row("")
86+
)
87+
}
88+
8189
test("string ascii function") {
8290
val df = Seq(("abc", "")).toDF("a", "b")
8391
checkAnswer(

0 commit comments

Comments
 (0)