17
17
18
18
package org .apache .spark .sql .catalyst .expressions
19
19
20
- import java .io .UnsupportedEncodingException
20
+ import java .nio .{ByteBuffer , CharBuffer }
21
+ import java .nio .charset .{CharacterCodingException , Charset , CodingErrorAction , IllegalCharsetNameException , UnsupportedCharsetException }
21
22
import java .text .{BreakIterator , DecimalFormat , DecimalFormatSymbols }
22
23
import java .util .{Base64 => JBase64 }
23
24
import java .util .{HashMap , Locale , Map => JMap }
24
25
25
26
import scala .collection .mutable .ArrayBuffer
26
27
27
28
import org .apache .spark .QueryContext
29
+ import org .apache .spark .network .util .JavaUtils
28
30
import org .apache .spark .sql .catalyst .InternalRow
29
31
import org .apache .spark .sql .catalyst .analysis .{ExpressionBuilder , FunctionRegistry , TypeCheckResult }
30
32
import org .apache .spark .sql .catalyst .analysis .TypeCheckResult .DataTypeMismatch
@@ -2716,62 +2718,69 @@ case class Decode(params: Seq[Expression], replacement: Expression)
2716
2718
since = " 1.5.0" ,
2717
2719
group = " string_funcs" )
2718
2720
// scalastyle:on line.size.limit
2719
- case class StringDecode (bin : Expression , charset : Expression , legacyCharsets : Boolean )
2720
- extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
2721
+ case class StringDecode (
2722
+ bin : Expression ,
2723
+ charset : Expression ,
2724
+ legacyCharsets : Boolean ,
2725
+ legacyErrorAction : Boolean )
2726
+ extends RuntimeReplaceable with ImplicitCastInputTypes {
2721
2727
2722
2728
def this (bin : Expression , charset : Expression ) =
2723
- this (bin, charset, SQLConf .get.legacyJavaCharsets)
2729
+ this (bin, charset, SQLConf .get.legacyJavaCharsets, SQLConf .get.legacyCodingErrorAction )
2724
2730
2725
- override def left : Expression = bin
2726
- override def right : Expression = charset
2727
2731
override def dataType : DataType = SQLConf .get.defaultStringType
2728
2732
override def inputTypes : Seq [AbstractDataType ] = Seq (BinaryType , StringTypeAnyCollation )
2733
+ override def prettyName : String = " decode"
2734
+ override def toString : String = s " $prettyName( $bin, $charset) "
2729
2735
2730
- private val supportedCharsets = Set (
2731
- " US-ASCII" , " ISO-8859-1" , " UTF-8" , " UTF-16BE" , " UTF-16LE" , " UTF-16" , " UTF-32" )
2732
-
2733
- protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = {
2734
- val fromCharset = input2.asInstanceOf [UTF8String ].toString
2735
- try {
2736
- if (legacyCharsets || supportedCharsets.contains(fromCharset.toUpperCase(Locale .ROOT ))) {
2737
- UTF8String .fromString(new String (input1.asInstanceOf [Array [Byte ]], fromCharset))
2738
- } else throw new UnsupportedEncodingException
2739
- } catch {
2740
- case _ : UnsupportedEncodingException =>
2741
- throw QueryExecutionErrors .invalidCharsetError(prettyName, fromCharset)
2742
- }
2743
- }
2744
-
2745
- override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
2746
- nullSafeCodeGen(ctx, ev, (bytes, charset) => {
2747
- val fromCharset = ctx.freshName(" fromCharset" )
2748
- val sc = JavaCode .global(
2749
- ctx.addReferenceObj(" supportedCharsets" , supportedCharsets),
2750
- supportedCharsets.getClass)
2751
- s """
2752
- String $fromCharset = $charset.toString();
2753
- try {
2754
- if ( $legacyCharsets || $sc.contains( $fromCharset.toUpperCase(java.util.Locale.ROOT))) {
2755
- ${ev.value} = UTF8String.fromString(new String( $bytes, $fromCharset));
2756
- } else {
2757
- throw new java.io.UnsupportedEncodingException();
2758
- }
2759
- } catch (java.io.UnsupportedEncodingException e) {
2760
- throw QueryExecutionErrors.invalidCharsetError(" $prettyName", $fromCharset);
2761
- }
2762
- """
2763
- })
2764
- }
2765
-
2766
- override protected def withNewChildrenInternal (
2767
- newLeft : Expression , newRight : Expression ): StringDecode =
2768
- copy(bin = newLeft, charset = newRight)
2736
+ override def replacement : Expression = StaticInvoke (
2737
+ classOf [StringDecode ],
2738
+ SQLConf .get.defaultStringType,
2739
+ " decode" ,
2740
+ Seq (bin, charset, Literal (legacyCharsets), Literal (legacyErrorAction)),
2741
+ Seq (BinaryType , StringTypeAnyCollation , BooleanType , BooleanType ))
2769
2742
2770
- override def prettyName : String = " decode"
2743
+ override def children : Seq [Expression ] = Seq (bin, charset)
2744
+ override protected def withNewChildrenInternal (newChildren : IndexedSeq [Expression ]): Expression =
2745
+ copy(bin = newChildren(0 ), charset = newChildren(1 ))
2771
2746
}
2772
2747
2773
2748
object StringDecode {
2774
2749
def apply (bin : Expression , charset : Expression ): StringDecode = new StringDecode (bin, charset)
2750
+ def decode (
2751
+ input : Array [Byte ],
2752
+ charset : UTF8String ,
2753
+ legacyCharsets : Boolean ,
2754
+ legacyErrorAction : Boolean ): UTF8String = {
2755
+ val fromCharset = charset.toString
2756
+ if (legacyCharsets || Encode .VALID_CHARSETS .contains(fromCharset.toUpperCase(Locale .ROOT ))) {
2757
+ val decoder = try {
2758
+ val codingErrorAction = if (legacyErrorAction) {
2759
+ CodingErrorAction .REPLACE
2760
+ } else {
2761
+ CodingErrorAction .REPORT
2762
+ }
2763
+ Charset .forName(fromCharset)
2764
+ .newDecoder()
2765
+ .onMalformedInput(codingErrorAction)
2766
+ .onUnmappableCharacter(codingErrorAction)
2767
+ } catch {
2768
+ case _ : IllegalCharsetNameException |
2769
+ _ : UnsupportedCharsetException |
2770
+ _ : IllegalArgumentException =>
2771
+ throw QueryExecutionErrors .invalidCharsetError(" decode" , fromCharset)
2772
+ }
2773
+ try {
2774
+ val cb = decoder.decode(ByteBuffer .wrap(input))
2775
+ UTF8String .fromString(cb.toString)
2776
+ } catch {
2777
+ case _ : CharacterCodingException =>
2778
+ throw QueryExecutionErrors .malformedCharacterCoding(" decode" , fromCharset)
2779
+ }
2780
+ } else {
2781
+ throw QueryExecutionErrors .invalidCharsetError(" decode" , fromCharset)
2782
+ }
2783
+ }
2775
2784
}
2776
2785
2777
2786
/**
@@ -2793,59 +2802,76 @@ object StringDecode {
2793
2802
since = " 1.5.0" ,
2794
2803
group = " string_funcs" )
2795
2804
// scalastyle:on line.size.limit
2796
- case class Encode (str : Expression , charset : Expression , legacyCharsets : Boolean )
2797
- extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
2805
+ case class Encode (
2806
+ str : Expression ,
2807
+ charset : Expression ,
2808
+ legacyCharsets : Boolean ,
2809
+ legacyErrorAction : Boolean )
2810
+ extends RuntimeReplaceable with ImplicitCastInputTypes {
2798
2811
2799
2812
def this (value : Expression , charset : Expression ) =
2800
- this (value, charset, SQLConf .get.legacyJavaCharsets)
2813
+ this (value, charset, SQLConf .get.legacyJavaCharsets, SQLConf .get.legacyCodingErrorAction )
2801
2814
2802
- override def left : Expression = str
2803
- override def right : Expression = charset
2804
2815
override def dataType : DataType = BinaryType
2805
2816
override def inputTypes : Seq [AbstractDataType ] =
2806
2817
Seq (StringTypeAnyCollation , StringTypeAnyCollation )
2807
2818
2808
- private val supportedCharsets = Set (
2809
- " US-ASCII" , " ISO-8859-1" , " UTF-8" , " UTF-16BE" , " UTF-16LE" , " UTF-16" , " UTF-32" )
2810
-
2811
- protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = {
2812
- val toCharset = input2.asInstanceOf [UTF8String ].toString
2813
- try {
2814
- if (legacyCharsets || supportedCharsets.contains(toCharset.toUpperCase(Locale .ROOT ))) {
2815
- input1.asInstanceOf [UTF8String ].toString.getBytes(toCharset)
2816
- } else throw new UnsupportedEncodingException
2817
- } catch {
2818
- case _ : UnsupportedEncodingException =>
2819
- throw QueryExecutionErrors .invalidCharsetError(prettyName, toCharset)
2820
- }
2821
- }
2819
+ override val replacement : Expression = StaticInvoke (
2820
+ classOf [Encode ],
2821
+ BinaryType ,
2822
+ " encode" ,
2823
+ Seq (
2824
+ str, charset, Literal (legacyCharsets, BooleanType ), Literal (legacyErrorAction, BooleanType )),
2825
+ Seq (StringTypeAnyCollation , StringTypeAnyCollation , BooleanType , BooleanType ))
2822
2826
2823
- override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
2824
- nullSafeCodeGen(ctx, ev, (string, charset) => {
2825
- val toCharset = ctx.freshName(" toCharset" )
2826
- val sc = JavaCode .global(
2827
- ctx.addReferenceObj(" supportedCharsets" , supportedCharsets),
2828
- supportedCharsets.getClass)
2829
- s """
2830
- String $toCharset = $charset.toString();
2831
- try {
2832
- if ( $legacyCharsets || $sc.contains( $toCharset.toUpperCase(java.util.Locale.ROOT))) {
2833
- ${ev.value} = $string.toString().getBytes( $toCharset);
2834
- } else {
2835
- throw new java.io.UnsupportedEncodingException();
2836
- }
2837
- } catch (java.io.UnsupportedEncodingException e) {
2838
- throw QueryExecutionErrors.invalidCharsetError(" $prettyName", $toCharset);
2839
- } """
2840
- })
2841
- }
2827
+ override def toString : String = s " $prettyName( $str, $charset) "
2842
2828
2843
- override protected def withNewChildrenInternal (
2844
- newLeft : Expression , newRight : Expression ): Encode = copy(str = newLeft, charset = newRight)
2829
+ override def children : Seq [Expression ] = Seq (str, charset)
2830
+
2831
+ override protected def withNewChildrenInternal (newChildren : IndexedSeq [Expression ]): Expression =
2832
+ copy(str = newChildren.head, charset = newChildren(1 ))
2845
2833
}
2846
2834
2847
2835
object Encode {
2848
2836
def apply (value : Expression , charset : Expression ): Encode = new Encode (value, charset)
2837
+
2838
+ private [expressions] final lazy val VALID_CHARSETS =
2839
+ Set (" US-ASCII" , " ISO-8859-1" , " UTF-8" , " UTF-16BE" , " UTF-16LE" , " UTF-16" , " UTF-32" )
2840
+
2841
+ def encode (
2842
+ input : UTF8String ,
2843
+ charset : UTF8String ,
2844
+ legacyCharsets : Boolean ,
2845
+ legacyErrorAction : Boolean ): Array [Byte ] = {
2846
+ val toCharset = charset.toString
2847
+ if (legacyCharsets || VALID_CHARSETS .contains(toCharset.toUpperCase(Locale .ROOT ))) {
2848
+ val encoder = try {
2849
+ val codingErrorAction = if (legacyErrorAction) {
2850
+ CodingErrorAction .REPLACE
2851
+ } else {
2852
+ CodingErrorAction .REPORT
2853
+ }
2854
+ Charset .forName(toCharset)
2855
+ .newEncoder()
2856
+ .onMalformedInput(codingErrorAction)
2857
+ .onUnmappableCharacter(codingErrorAction)
2858
+ } catch {
2859
+ case _ : IllegalCharsetNameException |
2860
+ _ : UnsupportedCharsetException |
2861
+ _ : IllegalArgumentException =>
2862
+ throw QueryExecutionErrors .invalidCharsetError(" encode" , toCharset)
2863
+ }
2864
+ try {
2865
+ val bb = encoder.encode(CharBuffer .wrap(input.toString))
2866
+ JavaUtils .bufferToArray(bb)
2867
+ } catch {
2868
+ case _ : CharacterCodingException =>
2869
+ throw QueryExecutionErrors .malformedCharacterCoding(" encode" , toCharset)
2870
+ }
2871
+ } else {
2872
+ throw QueryExecutionErrors .invalidCharsetError(" encode" , toCharset)
2873
+ }
2874
+ }
2849
2875
}
2850
2876
2851
2877
/**
0 commit comments