Skip to content

Commit 8c8f0ef

Browse files
chenghao-inteldavies
authored andcommitted
[SPARK-8255] [SPARK-8256] [SQL] Add regex_extract/regex_replace
Add expressions `regex_extract` & `regex_replace` Author: Cheng Hao <hao.cheng@intel.com> Closes #7468 from chenghao-intel/regexp and squashes the following commits: e5ea476 [Cheng Hao] minor update for documentation ef96fd6 [Cheng Hao] update the code gen 72cf28f [Cheng Hao] Add more log for compilation error 4e11381 [Cheng Hao] Add regexp_replace / regexp_extract support
1 parent d38c502 commit 8c8f0ef

File tree

8 files changed

+323
-4
lines changed

8 files changed

+323
-4
lines changed

python/pyspark/sql/functions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
'monotonicallyIncreasingId',
4747
'rand',
4848
'randn',
49+
'regexp_extract',
50+
'regexp_replace',
4951
'sha1',
5052
'sha2',
5153
'sparkPartitionId',
@@ -343,6 +345,34 @@ def levenshtein(left, right):
343345
return Column(jc)
344346

345347

348+
@ignore_unicode_prefix
349+
@since(1.5)
350+
def regexp_extract(str, pattern, idx):
351+
"""Extract a specific(idx) group identified by a java regex, from the specified string column.
352+
353+
>>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
354+
>>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
355+
[Row(d=u'100')]
356+
"""
357+
sc = SparkContext._active_spark_context
358+
jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
359+
return Column(jc)
360+
361+
362+
@ignore_unicode_prefix
363+
@since(1.5)
364+
def regexp_replace(str, pattern, replacement):
365+
"""Replace all substrings of the specified string value that match regexp with rep.
366+
367+
>>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
368+
>>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect()
369+
[Row(d=u'##-##')]
370+
"""
371+
sc = SparkContext._active_spark_context
372+
jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
373+
return Column(jc)
374+
375+
346376
@ignore_unicode_prefix
347377
@since(1.5)
348378
def md5(col):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ object FunctionRegistry {
161161
expression[Lower]("lower"),
162162
expression[Length]("length"),
163163
expression[Levenshtein]("levenshtein"),
164+
expression[RegExpExtract]("regexp_extract"),
165+
expression[RegExpReplace]("regexp_replace"),
164166
expression[StringInstr]("instr"),
165167
expression[StringLocate]("locate"),
166168
expression[StringLPad]("lpad"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
297297
evaluator.cook(code)
298298
} catch {
299299
case e: Exception =>
300-
logError(s"failed to compile:\n $code", e)
301-
throw e
300+
val msg = s"failed to compile:\n $code"
301+
logError(msg, e)
302+
throw new Exception(msg, e)
302303
}
303304
evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
304305
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala

Lines changed: 216 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import java.text.DecimalFormat
2121
import java.util.Locale
22-
import java.util.regex.Pattern
22+
import java.util.regex.{MatchResult, Pattern}
2323

2424
import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
@@ -876,6 +876,221 @@ case class Encode(value: Expression, charset: Expression)
876876
}
877877
}
878878

879+
/**
880+
* Replace all substrings of str that match regexp with rep.
881+
*
882+
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
883+
*/
884+
case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
885+
extends Expression with ImplicitCastInputTypes {
886+
887+
// last regex in string, we will update the pattern iff regexp value changed.
888+
@transient private var lastRegex: UTF8String = _
889+
// last regex pattern, we cache it for performance concern
890+
@transient private var pattern: Pattern = _
891+
// last replacement string, we don't want to convert a UTF8String => java.langString every time.
892+
@transient private var lastReplacement: String = _
893+
@transient private var lastReplacementInUTF8: UTF8String = _
894+
// result buffer write by Matcher
895+
@transient private val result: StringBuffer = new StringBuffer
896+
897+
override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable
898+
override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable
899+
900+
override def eval(input: InternalRow): Any = {
901+
val s = subject.eval(input)
902+
if (null != s) {
903+
val p = regexp.eval(input)
904+
if (null != p) {
905+
val r = rep.eval(input)
906+
if (null != r) {
907+
if (!p.equals(lastRegex)) {
908+
// regex value changed
909+
lastRegex = p.asInstanceOf[UTF8String]
910+
pattern = Pattern.compile(lastRegex.toString)
911+
}
912+
if (!r.equals(lastReplacementInUTF8)) {
913+
// replacement string changed
914+
lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
915+
lastReplacement = lastReplacementInUTF8.toString
916+
}
917+
val m = pattern.matcher(s.toString())
918+
result.delete(0, result.length())
919+
920+
while (m.find) {
921+
m.appendReplacement(result, lastReplacement)
922+
}
923+
m.appendTail(result)
924+
925+
return UTF8String.fromString(result.toString)
926+
}
927+
}
928+
}
929+
930+
null
931+
}
932+
933+
override def dataType: DataType = StringType
934+
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
935+
override def children: Seq[Expression] = subject :: regexp :: rep :: Nil
936+
override def prettyName: String = "regexp_replace"
937+
938+
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
939+
val termLastRegex = ctx.freshName("lastRegex")
940+
val termPattern = ctx.freshName("pattern")
941+
942+
val termLastReplacement = ctx.freshName("lastReplacement")
943+
val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")
944+
945+
val termResult = ctx.freshName("result")
946+
947+
val classNameUTF8String = classOf[UTF8String].getCanonicalName
948+
val classNamePattern = classOf[Pattern].getCanonicalName
949+
val classNameString = classOf[java.lang.String].getCanonicalName
950+
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
951+
952+
ctx.addMutableState(classNameUTF8String,
953+
termLastRegex, s"${termLastRegex} = null;")
954+
ctx.addMutableState(classNamePattern,
955+
termPattern, s"${termPattern} = null;")
956+
ctx.addMutableState(classNameString,
957+
termLastReplacement, s"${termLastReplacement} = null;")
958+
ctx.addMutableState(classNameUTF8String,
959+
termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
960+
ctx.addMutableState(classNameStringBuffer,
961+
termResult, s"${termResult} = new $classNameStringBuffer();")
962+
963+
val evalSubject = subject.gen(ctx)
964+
val evalRegexp = regexp.gen(ctx)
965+
val evalRep = rep.gen(ctx)
966+
967+
s"""
968+
${evalSubject.code}
969+
boolean ${ev.isNull} = ${evalSubject.isNull};
970+
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
971+
if (!${evalSubject.isNull}) {
972+
${evalRegexp.code}
973+
if (!${evalRegexp.isNull}) {
974+
${evalRep.code}
975+
if (!${evalRep.isNull}) {
976+
if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
977+
// regex value changed
978+
${termLastRegex} = ${evalRegexp.primitive};
979+
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
980+
}
981+
if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) {
982+
// replacement string changed
983+
${termLastReplacementInUTF8} = ${evalRep.primitive};
984+
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
985+
}
986+
${termResult}.delete(0, ${termResult}.length());
987+
${classOf[java.util.regex.Matcher].getCanonicalName} m =
988+
${termPattern}.matcher(${evalSubject.primitive}.toString());
989+
990+
while (m.find()) {
991+
m.appendReplacement(${termResult}, ${termLastReplacement});
992+
}
993+
m.appendTail(${termResult});
994+
${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString());
995+
${ev.isNull} = false;
996+
}
997+
}
998+
}
999+
"""
1000+
}
1001+
}
1002+
1003+
/**
1004+
* Extract a specific(idx) group identified by a Java regex.
1005+
*
1006+
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
1007+
*/
1008+
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
1009+
extends Expression with ImplicitCastInputTypes {
1010+
def this(s: Expression, r: Expression) = this(s, r, Literal(1))
1011+
1012+
// last regex in string, we will update the pattern iff regexp value changed.
1013+
@transient private var lastRegex: UTF8String = _
1014+
// last regex pattern, we cache it for performance concern
1015+
@transient private var pattern: Pattern = _
1016+
1017+
override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable
1018+
override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable
1019+
1020+
override def eval(input: InternalRow): Any = {
1021+
val s = subject.eval(input)
1022+
if (null != s) {
1023+
val p = regexp.eval(input)
1024+
if (null != p) {
1025+
val r = idx.eval(input)
1026+
if (null != r) {
1027+
if (!p.equals(lastRegex)) {
1028+
// regex value changed
1029+
lastRegex = p.asInstanceOf[UTF8String]
1030+
pattern = Pattern.compile(lastRegex.toString)
1031+
}
1032+
val m = pattern.matcher(s.toString())
1033+
if (m.find) {
1034+
val mr: MatchResult = m.toMatchResult
1035+
return UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
1036+
}
1037+
return UTF8String.EMPTY_UTF8
1038+
}
1039+
}
1040+
}
1041+
1042+
null
1043+
}
1044+
1045+
override def dataType: DataType = StringType
1046+
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
1047+
override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
1048+
override def prettyName: String = "regexp_extract"
1049+
1050+
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
1051+
val termLastRegex = ctx.freshName("lastRegex")
1052+
val termPattern = ctx.freshName("pattern")
1053+
val classNameUTF8String = classOf[UTF8String].getCanonicalName
1054+
val classNamePattern = classOf[Pattern].getCanonicalName
1055+
1056+
ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;")
1057+
ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
1058+
1059+
val evalSubject = subject.gen(ctx)
1060+
val evalRegexp = regexp.gen(ctx)
1061+
val evalIdx = idx.gen(ctx)
1062+
1063+
s"""
1064+
${ctx.javaType(dataType)} ${ev.primitive} = null;
1065+
boolean ${ev.isNull} = true;
1066+
${evalSubject.code}
1067+
if (!${evalSubject.isNull}) {
1068+
${evalRegexp.code}
1069+
if (!${evalRegexp.isNull}) {
1070+
${evalIdx.code}
1071+
if (!${evalIdx.isNull}) {
1072+
if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
1073+
// regex value changed
1074+
${termLastRegex} = ${evalRegexp.primitive};
1075+
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
1076+
}
1077+
${classOf[java.util.regex.Matcher].getCanonicalName} m =
1078+
${termPattern}.matcher(${evalSubject.primitive}.toString());
1079+
if (m.find()) {
1080+
${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult();
1081+
${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive}));
1082+
${ev.isNull} = false;
1083+
} else {
1084+
${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8;
1085+
${ev.isNull} = false;
1086+
}
1087+
}
1088+
}
1089+
}
1090+
"""
1091+
}
1092+
}
1093+
8791094
/**
8801095
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
8811096
* and returns the result as a string. If D is 0, the result has no decimal point or

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ trait ExpressionEvalHelper {
7979
fail(
8080
s"""
8181
|Code generation of $expression failed:
82-
|${evaluated.code}
8382
|$e
8483
""".stripMargin)
8584
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,41 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
464464
checkEvaluation(StringSpace(s1), null, row2)
465465
}
466466

467+
test("RegexReplace") {
468+
val row1 = create_row("100-200", "(\\d+)", "num")
469+
val row2 = create_row("100-200", "(\\d+)", "###")
470+
val row3 = create_row("100-200", "(-)", "###")
471+
472+
val s = 's.string.at(0)
473+
val p = 'p.string.at(1)
474+
val r = 'r.string.at(2)
475+
476+
val expr = RegExpReplace(s, p, r)
477+
checkEvaluation(expr, "num-num", row1)
478+
checkEvaluation(expr, "###-###", row2)
479+
checkEvaluation(expr, "100###200", row3)
480+
}
481+
482+
test("RegexExtract") {
483+
val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1)
484+
val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)
485+
val row3 = create_row("100-200", "(\\d+).*", 1)
486+
val row4 = create_row("100-200", "([a-z])", 1)
487+
488+
val s = 's.string.at(0)
489+
val p = 'p.string.at(1)
490+
val r = 'r.int.at(2)
491+
492+
val expr = RegExpExtract(s, p, r)
493+
checkEvaluation(expr, "100", row1)
494+
checkEvaluation(expr, "200", row2)
495+
checkEvaluation(expr, "100", row3)
496+
checkEvaluation(expr, "", row4) // will not match anything, empty string get
497+
498+
val expr1 = new RegExpExtract(s, p)
499+
checkEvaluation(expr1, "100", row1)
500+
}
501+
467502
test("SPLIT") {
468503
val s1 = 'a.string.at(0)
469504
val s2 = 'b.string.at(1)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,6 +1781,27 @@ object functions {
17811781
StringLocate(lit(substr).expr, str.expr, lit(pos).expr)
17821782
}
17831783

1784+
1785+
/**
1786+
* Extract a specific(idx) group identified by a java regex, from the specified string column.
1787+
*
1788+
* @group string_funcs
1789+
* @since 1.5.0
1790+
*/
1791+
def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = {
1792+
RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr)
1793+
}
1794+
1795+
/**
1796+
* Replace all substrings of the specified string value that match regexp with rep.
1797+
*
1798+
* @group string_funcs
1799+
* @since 1.5.0
1800+
*/
1801+
def regexp_replace(e: Column, pattern: String, replacement: String): Column = {
1802+
RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr)
1803+
}
1804+
17841805
/**
17851806
* Computes the BASE64 encoding of a binary column and returns it as a string column.
17861807
* This is the reverse of unbase64.

sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@ class StringFunctionsSuite extends QueryTest {
5656
checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
5757
}
5858

59+
test("string regex_replace / regex_extract") {
60+
val df = Seq(("100-200", "")).toDF("a", "b")
61+
62+
checkAnswer(
63+
df.select(
64+
regexp_replace($"a", "(\\d+)", "num"),
65+
regexp_extract($"a", "(\\d+)-(\\d+)", 1)),
66+
Row("num-num", "100"))
67+
68+
checkAnswer(
69+
df.selectExpr(
70+
"regexp_replace(a, '(\\d+)', 'num')",
71+
"regexp_extract(a, '(\\d+)-(\\d+)', 2)"),
72+
Row("num-num", "200"))
73+
}
74+
5975
test("string ascii function") {
6076
val df = Seq(("abc", "")).toDF("a", "b")
6177
checkAnswer(

0 commit comments

Comments
 (0)