@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
19
19
20
20
import java .text .DecimalFormat
21
21
import java .util .Locale
22
- import java .util .regex .Pattern
22
+ import java .util .regex .{ MatchResult , Pattern }
23
23
24
24
import org .apache .spark .sql .catalyst .InternalRow
25
25
import org .apache .spark .sql .catalyst .analysis .UnresolvedException
@@ -876,6 +876,221 @@ case class Encode(value: Expression, charset: Expression)
876
876
}
877
877
}
878
878
879
+ /**
880
+ * Replace all substrings of str that match regexp with rep.
881
+ *
882
+ * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
883
+ */
884
+ case class RegExpReplace (subject : Expression , regexp : Expression , rep : Expression )
885
+ extends Expression with ImplicitCastInputTypes {
886
+
887
+ // last regex in string, we will update the pattern iff regexp value changed.
888
+ @ transient private var lastRegex : UTF8String = _
889
+ // last regex pattern, we cache it for performance concern
890
+ @ transient private var pattern : Pattern = _
891
+ // last replacement string, we don't want to convert a UTF8String => java.langString every time.
892
+ @ transient private var lastReplacement : String = _
893
+ @ transient private var lastReplacementInUTF8 : UTF8String = _
894
+ // result buffer write by Matcher
895
+ @ transient private val result : StringBuffer = new StringBuffer
896
+
897
+ override def nullable : Boolean = subject.nullable || regexp.nullable || rep.nullable
898
+ override def foldable : Boolean = subject.foldable && regexp.foldable && rep.foldable
899
+
900
+ override def eval (input : InternalRow ): Any = {
901
+ val s = subject.eval(input)
902
+ if (null != s) {
903
+ val p = regexp.eval(input)
904
+ if (null != p) {
905
+ val r = rep.eval(input)
906
+ if (null != r) {
907
+ if (! p.equals(lastRegex)) {
908
+ // regex value changed
909
+ lastRegex = p.asInstanceOf [UTF8String ]
910
+ pattern = Pattern .compile(lastRegex.toString)
911
+ }
912
+ if (! r.equals(lastReplacementInUTF8)) {
913
+ // replacement string changed
914
+ lastReplacementInUTF8 = r.asInstanceOf [UTF8String ]
915
+ lastReplacement = lastReplacementInUTF8.toString
916
+ }
917
+ val m = pattern.matcher(s.toString())
918
+ result.delete(0 , result.length())
919
+
920
+ while (m.find) {
921
+ m.appendReplacement(result, lastReplacement)
922
+ }
923
+ m.appendTail(result)
924
+
925
+ return UTF8String .fromString(result.toString)
926
+ }
927
+ }
928
+ }
929
+
930
+ null
931
+ }
932
+
933
+ override def dataType : DataType = StringType
934
+ override def inputTypes : Seq [AbstractDataType ] = Seq (StringType , StringType , StringType )
935
+ override def children : Seq [Expression ] = subject :: regexp :: rep :: Nil
936
+ override def prettyName : String = " regexp_replace"
937
+
938
+ override protected def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
939
+ val termLastRegex = ctx.freshName(" lastRegex" )
940
+ val termPattern = ctx.freshName(" pattern" )
941
+
942
+ val termLastReplacement = ctx.freshName(" lastReplacement" )
943
+ val termLastReplacementInUTF8 = ctx.freshName(" lastReplacementInUTF8" )
944
+
945
+ val termResult = ctx.freshName(" result" )
946
+
947
+ val classNameUTF8String = classOf [UTF8String ].getCanonicalName
948
+ val classNamePattern = classOf [Pattern ].getCanonicalName
949
+ val classNameString = classOf [java.lang.String ].getCanonicalName
950
+ val classNameStringBuffer = classOf [java.lang.StringBuffer ].getCanonicalName
951
+
952
+ ctx.addMutableState(classNameUTF8String,
953
+ termLastRegex, s " ${termLastRegex} = null; " )
954
+ ctx.addMutableState(classNamePattern,
955
+ termPattern, s " ${termPattern} = null; " )
956
+ ctx.addMutableState(classNameString,
957
+ termLastReplacement, s " ${termLastReplacement} = null; " )
958
+ ctx.addMutableState(classNameUTF8String,
959
+ termLastReplacementInUTF8, s " ${termLastReplacementInUTF8} = null; " )
960
+ ctx.addMutableState(classNameStringBuffer,
961
+ termResult, s " ${termResult} = new $classNameStringBuffer(); " )
962
+
963
+ val evalSubject = subject.gen(ctx)
964
+ val evalRegexp = regexp.gen(ctx)
965
+ val evalRep = rep.gen(ctx)
966
+
967
+ s """
968
+ ${evalSubject.code}
969
+ boolean ${ev.isNull} = ${evalSubject.isNull};
970
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
971
+ if (! ${evalSubject.isNull}) {
972
+ ${evalRegexp.code}
973
+ if (! ${evalRegexp.isNull}) {
974
+ ${evalRep.code}
975
+ if (! ${evalRep.isNull}) {
976
+ if (! ${evalRegexp.primitive}.equals( ${termLastRegex})) {
977
+ // regex value changed
978
+ ${termLastRegex} = ${evalRegexp.primitive};
979
+ ${termPattern} = ${classNamePattern}.compile( ${termLastRegex}.toString());
980
+ }
981
+ if (! ${evalRep.primitive}.equals( ${termLastReplacementInUTF8})) {
982
+ // replacement string changed
983
+ ${termLastReplacementInUTF8} = ${evalRep.primitive};
984
+ ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
985
+ }
986
+ ${termResult}.delete(0, ${termResult}.length());
987
+ ${classOf [java.util.regex.Matcher ].getCanonicalName} m =
988
+ ${termPattern}.matcher( ${evalSubject.primitive}.toString());
989
+
990
+ while (m.find()) {
991
+ m.appendReplacement( ${termResult}, ${termLastReplacement});
992
+ }
993
+ m.appendTail( ${termResult});
994
+ ${ev.primitive} = ${classNameUTF8String}.fromString( ${termResult}.toString());
995
+ ${ev.isNull} = false;
996
+ }
997
+ }
998
+ }
999
+ """
1000
+ }
1001
+ }
1002
+
1003
+ /**
1004
+ * Extract a specific(idx) group identified by a Java regex.
1005
+ *
1006
+ * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
1007
+ */
1008
+ case class RegExpExtract (subject : Expression , regexp : Expression , idx : Expression )
1009
+ extends Expression with ImplicitCastInputTypes {
1010
+ def this (s : Expression , r : Expression ) = this (s, r, Literal (1 ))
1011
+
1012
+ // last regex in string, we will update the pattern iff regexp value changed.
1013
+ @ transient private var lastRegex : UTF8String = _
1014
+ // last regex pattern, we cache it for performance concern
1015
+ @ transient private var pattern : Pattern = _
1016
+
1017
+ override def nullable : Boolean = subject.nullable || regexp.nullable || idx.nullable
1018
+ override def foldable : Boolean = subject.foldable && regexp.foldable && idx.foldable
1019
+
1020
+ override def eval (input : InternalRow ): Any = {
1021
+ val s = subject.eval(input)
1022
+ if (null != s) {
1023
+ val p = regexp.eval(input)
1024
+ if (null != p) {
1025
+ val r = idx.eval(input)
1026
+ if (null != r) {
1027
+ if (! p.equals(lastRegex)) {
1028
+ // regex value changed
1029
+ lastRegex = p.asInstanceOf [UTF8String ]
1030
+ pattern = Pattern .compile(lastRegex.toString)
1031
+ }
1032
+ val m = pattern.matcher(s.toString())
1033
+ if (m.find) {
1034
+ val mr : MatchResult = m.toMatchResult
1035
+ return UTF8String .fromString(mr.group(r.asInstanceOf [Int ]))
1036
+ }
1037
+ return UTF8String .EMPTY_UTF8
1038
+ }
1039
+ }
1040
+ }
1041
+
1042
+ null
1043
+ }
1044
+
1045
+ override def dataType : DataType = StringType
1046
+ override def inputTypes : Seq [AbstractDataType ] = Seq (StringType , StringType , IntegerType )
1047
+ override def children : Seq [Expression ] = subject :: regexp :: idx :: Nil
1048
+ override def prettyName : String = " regexp_extract"
1049
+
1050
+ override protected def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
1051
+ val termLastRegex = ctx.freshName(" lastRegex" )
1052
+ val termPattern = ctx.freshName(" pattern" )
1053
+ val classNameUTF8String = classOf [UTF8String ].getCanonicalName
1054
+ val classNamePattern = classOf [Pattern ].getCanonicalName
1055
+
1056
+ ctx.addMutableState(classNameUTF8String, termLastRegex, s " ${termLastRegex} = null; " )
1057
+ ctx.addMutableState(classNamePattern, termPattern, s " ${termPattern} = null; " )
1058
+
1059
+ val evalSubject = subject.gen(ctx)
1060
+ val evalRegexp = regexp.gen(ctx)
1061
+ val evalIdx = idx.gen(ctx)
1062
+
1063
+ s """
1064
+ ${ctx.javaType(dataType)} ${ev.primitive} = null;
1065
+ boolean ${ev.isNull} = true;
1066
+ ${evalSubject.code}
1067
+ if (! ${evalSubject.isNull}) {
1068
+ ${evalRegexp.code}
1069
+ if (! ${evalRegexp.isNull}) {
1070
+ ${evalIdx.code}
1071
+ if (! ${evalIdx.isNull}) {
1072
+ if (! ${evalRegexp.primitive}.equals( ${termLastRegex})) {
1073
+ // regex value changed
1074
+ ${termLastRegex} = ${evalRegexp.primitive};
1075
+ ${termPattern} = ${classNamePattern}.compile( ${termLastRegex}.toString());
1076
+ }
1077
+ ${classOf [java.util.regex.Matcher ].getCanonicalName} m =
1078
+ ${termPattern}.matcher( ${evalSubject.primitive}.toString());
1079
+ if (m.find()) {
1080
+ ${classOf [java.util.regex.MatchResult ].getCanonicalName} mr = m.toMatchResult();
1081
+ ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group( ${evalIdx.primitive}));
1082
+ ${ev.isNull} = false;
1083
+ } else {
1084
+ ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8;
1085
+ ${ev.isNull} = false;
1086
+ }
1087
+ }
1088
+ }
1089
+ }
1090
+ """
1091
+ }
1092
+ }
1093
+
879
1094
/**
880
1095
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
881
1096
* and returns the result as a string. If D is 0, the result has no decimal point or
0 commit comments