Skip to content

[SPARK-48658][SQL] Encode/Decode functions report coding errors instead of mojibake for unmappable characters #47017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -2988,6 +2988,12 @@
],
"sqlState" : "42710"
},
"MALFORMED_CHARACTER_CODING" : {
"message" : [
"Invalid value found when performing <function> with <charset>"
],
"sqlState" : "22000"
},
"MALFORMED_CSV_RECORD" : {
"message" : [
"Malformed CSV record: <badRecord>"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [decode(cast(g#0 as binary), UTF-8, false) AS decode(g, UTF-8)#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.StringDecode, StringType, decode, cast(g#0 as binary), UTF-8, false, false, BinaryType, StringTypeAnyCollation, BooleanType, BooleanType, true, true, true) AS decode(g, UTF-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [encode(g#0, UTF-8, false) AS encode(g, UTF-8)#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.Encode, BinaryType, encode, g#0, UTF-8, false, false, StringTypeAnyCollation, StringTypeAnyCollation, BooleanType, BooleanType, true, true, true) AS encode(g, UTF-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [encode(g#0, UTF-8, false) AS to_binary(g, utf-8)#0]
Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.Encode, BinaryType, encode, g#0, UTF-8, false, false, StringTypeAnyCollation, StringTypeAnyCollation, BooleanType, BooleanType, true, true, true) AS to_binary(g, utf-8)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@

package org.apache.spark.sql.catalyst.expressions

import java.io.UnsupportedEncodingException
import java.nio.{ByteBuffer, CharBuffer}
import java.nio.charset.{CharacterCodingException, Charset, CodingErrorAction, IllegalCharsetNameException, UnsupportedCharsetException}
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{Base64 => JBase64}
import java.util.{HashMap, Locale, Map => JMap}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.QueryContext
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
Expand Down Expand Up @@ -2708,62 +2710,69 @@ case class Decode(params: Seq[Expression], replacement: Expression)
since = "1.5.0",
group = "string_funcs")
// scalastyle:on line.size.limit
case class StringDecode(bin: Expression, charset: Expression, legacyCharsets: Boolean)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
case class StringDecode(
bin: Expression,
charset: Expression,
legacyCharsets: Boolean,
legacyErrorAction: Boolean)
extends RuntimeReplaceable with ImplicitCastInputTypes {

def this(bin: Expression, charset: Expression) =
this(bin, charset, SQLConf.get.legacyJavaCharsets)
this(bin, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction)

override def left: Expression = bin
override def right: Expression = charset
override def dataType: DataType = SQLConf.get.defaultStringType
override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, StringTypeAnyCollation)
override def prettyName: String = "decode"
override def toString: String = s"$prettyName($bin, $charset)"

private val supportedCharsets = Set(
"US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16", "UTF-32")

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val fromCharset = input2.asInstanceOf[UTF8String].toString
try {
if (legacyCharsets || supportedCharsets.contains(fromCharset.toUpperCase(Locale.ROOT))) {
UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset))
} else throw new UnsupportedEncodingException
} catch {
case _: UnsupportedEncodingException =>
throw QueryExecutionErrors.invalidCharsetError(prettyName, fromCharset)
}
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (bytes, charset) => {
val fromCharset = ctx.freshName("fromCharset")
val sc = JavaCode.global(
ctx.addReferenceObj("supportedCharsets", supportedCharsets),
supportedCharsets.getClass)
s"""
String $fromCharset = $charset.toString();
try {
if ($legacyCharsets || $sc.contains($fromCharset.toUpperCase(java.util.Locale.ROOT))) {
${ev.value} = UTF8String.fromString(new String($bytes, $fromCharset));
} else {
throw new java.io.UnsupportedEncodingException();
}
} catch (java.io.UnsupportedEncodingException e) {
throw QueryExecutionErrors.invalidCharsetError("$prettyName", $fromCharset);
}
"""
})
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): StringDecode =
copy(bin = newLeft, charset = newRight)
override def replacement: Expression = StaticInvoke(
classOf[StringDecode],
SQLConf.get.defaultStringType,
"decode",
Seq(bin, charset, Literal(legacyCharsets), Literal(legacyErrorAction)),
Seq(BinaryType, StringTypeAnyCollation, BooleanType, BooleanType))

override def prettyName: String = "decode"
override def children: Seq[Expression] = Seq(bin, charset)
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(bin = newChildren(0), charset = newChildren(1))
}

object StringDecode {
def apply(bin: Expression, charset: Expression): StringDecode = new StringDecode(bin, charset)
def decode(
input: Array[Byte],
charset: UTF8String,
legacyCharsets: Boolean,
legacyErrorAction: Boolean): UTF8String = {
val fromCharset = charset.toString
if (legacyCharsets || Encode.VALID_CHARSETS.contains(fromCharset.toUpperCase(Locale.ROOT))) {
val decoder = try {
val codingErrorAction = if (legacyErrorAction) {
CodingErrorAction.REPLACE
} else {
CodingErrorAction.REPORT
}
Charset.forName(fromCharset)
.newDecoder()
.onMalformedInput(codingErrorAction)
.onUnmappableCharacter(codingErrorAction)
} catch {
case _: IllegalCharsetNameException |
_: UnsupportedCharsetException |
_: IllegalArgumentException =>
throw QueryExecutionErrors.invalidCharsetError("decode", fromCharset)
}
try {
val cb = decoder.decode(ByteBuffer.wrap(input))
UTF8String.fromString(cb.toString)
} catch {
case _: CharacterCodingException =>
throw QueryExecutionErrors.malformedCharacterCoding("decode", fromCharset)
}
} else {
throw QueryExecutionErrors.invalidCharsetError("decode", fromCharset)
}
}
}

/**
Expand All @@ -2785,59 +2794,76 @@ object StringDecode {
since = "1.5.0",
group = "string_funcs")
// scalastyle:on line.size.limit
case class Encode(str: Expression, charset: Expression, legacyCharsets: Boolean)
extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {
case class Encode(
str: Expression,
charset: Expression,
legacyCharsets: Boolean,
legacyErrorAction: Boolean)
extends RuntimeReplaceable with ImplicitCastInputTypes {

def this(value: Expression, charset: Expression) =
this(value, charset, SQLConf.get.legacyJavaCharsets)
this(value, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction)

override def left: Expression = str
override def right: Expression = charset
override def dataType: DataType = BinaryType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation)

private val supportedCharsets = Set(
"US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16", "UTF-32")

protected override def nullSafeEval(input1: Any, input2: Any): Any = {
val toCharset = input2.asInstanceOf[UTF8String].toString
try {
if (legacyCharsets || supportedCharsets.contains(toCharset.toUpperCase(Locale.ROOT))) {
input1.asInstanceOf[UTF8String].toString.getBytes(toCharset)
} else throw new UnsupportedEncodingException
} catch {
case _: UnsupportedEncodingException =>
throw QueryExecutionErrors.invalidCharsetError(prettyName, toCharset)
}
}
override val replacement: Expression = StaticInvoke(
classOf[Encode],
BinaryType,
"encode",
Seq(
str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType)),
Seq(StringTypeAnyCollation, StringTypeAnyCollation, BooleanType, BooleanType))

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (string, charset) => {
val toCharset = ctx.freshName("toCharset")
val sc = JavaCode.global(
ctx.addReferenceObj("supportedCharsets", supportedCharsets),
supportedCharsets.getClass)
s"""
String $toCharset = $charset.toString();
try {
if ($legacyCharsets || $sc.contains($toCharset.toUpperCase(java.util.Locale.ROOT))) {
${ev.value} = $string.toString().getBytes($toCharset);
} else {
throw new java.io.UnsupportedEncodingException();
}
} catch (java.io.UnsupportedEncodingException e) {
throw QueryExecutionErrors.invalidCharsetError("$prettyName", $toCharset);
}"""
})
}
override def toString: String = s"$prettyName($str, $charset)"

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Encode = copy(str = newLeft, charset = newRight)
override def children: Seq[Expression] = Seq(str, charset)

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(str = newChildren.head, charset = newChildren(1))
}

object Encode {
def apply(value: Expression, charset: Expression): Encode = new Encode(value, charset)

private[expressions] final lazy val VALID_CHARSETS =
Set("US-ASCII", "ISO-8859-1", "UTF-8", "UTF-16BE", "UTF-16LE", "UTF-16", "UTF-32")

def encode(
input: UTF8String,
charset: UTF8String,
legacyCharsets: Boolean,
legacyErrorAction: Boolean): Array[Byte] = {
val toCharset = charset.toString
if (legacyCharsets || VALID_CHARSETS.contains(toCharset.toUpperCase(Locale.ROOT))) {
val encoder = try {
val codingErrorAction = if (legacyErrorAction) {
CodingErrorAction.REPLACE
} else {
CodingErrorAction.REPORT
}
Charset.forName(toCharset)
.newEncoder()
.onMalformedInput(codingErrorAction)
.onUnmappableCharacter(codingErrorAction)
} catch {
case _: IllegalCharsetNameException |
_: UnsupportedCharsetException |
_: IllegalArgumentException =>
throw QueryExecutionErrors.invalidCharsetError("encode", toCharset)
}
try {
val bb = encoder.encode(CharBuffer.wrap(input.toString))
JavaUtils.bufferToArray(bb)
} catch {
case _: CharacterCodingException =>
throw QueryExecutionErrors.malformedCharacterCoding("encode", toCharset)
}
} else {
throw QueryExecutionErrors.invalidCharsetError("encode", toCharset)
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2741,6 +2741,14 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"charset" -> charset))
}

def malformedCharacterCoding(functionName: String, charset: String): RuntimeException = {
new SparkRuntimeException(
errorClass = "MALFORMED_CHARACTER_CODING",
messageParameters = Map(
"function" -> toSQLId(functionName),
"charset" -> charset))
}

def invalidWriterCommitMessageError(details: String): Throwable = {
new SparkRuntimeException(
errorClass = "INVALID_WRITER_COMMIT_MESSAGE",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5010,6 +5010,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_CODING_ERROR_ACTION = buildConf("spark.sql.legacy.codingErrorAction")
.internal()
.doc("When set to true, encode/decode functions replace unmappable characters with mojibake " +
"instead of reporting coding errors.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it should be a fallback conf to ANSI.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reasons I'd like to make it independent of ANSI are:

  • Part of the implication of ANSI is Hive-incompatibility,
  • Hive also reports coding errors, so it was a mistake when we ported this from hive
  • These functions are not ANSI-defined
  • The error behaviors are also not found in ANSI

The reasons mentioned above indicate that this behavior is more of a legacy trait of Spark itself.


val LEGACY_EVAL_CURRENT_TIME = buildConf("spark.sql.legacy.earlyEvalCurrentTime")
.internal()
.doc("When set to true, evaluation and constant folding will happen for now() and " +
Expand Down Expand Up @@ -5986,6 +5994,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def legacyJavaCharsets: Boolean = getConf(SQLConf.LEGACY_JAVA_CHARSETS)

def legacyCodingErrorAction: Boolean = getConf(SQLConf.LEGACY_CODING_ERROR_ACTION)

def legacyEvalCurrentTime: Boolean = getConf(SQLConf.LEGACY_EVAL_CURRENT_TIME)

/** ********************** SQLConf functionality methods ************ */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") {
var strExpr: Expression = Literal("abc")
for (_ <- 1 to 150) {
strExpr = StringDecode(Encode(strExpr, "utf-8"), "utf-8")
strExpr = StringTrimRight(StringTrimLeft(strExpr))
}

val expressions = Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,15 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB
new ArrayBasedMapData(keyArray, valueArray)
}

protected def replace(expr: Expression): Expression = expr match {
case r: RuntimeReplaceable => replace(r.replacement)
case _ => expr.mapChildren(replace)
}

private def prepareEvaluation(expression: Expression): Expression = {
val serializer = new JavaSerializer(new SparkConf()).newInstance()
val resolver = ResolveTimeZone
val expr = resolver.resolveTimeZones(expression)
val expr = resolver.resolveTimeZones(replace(expression))
assert(expr.resolved)
serializer.deserialize(serializer.serialize(expr))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringDecode(b, Literal.create(null, StringType)), null, create_row(null))

// Test escaping of charset
GenerateUnsafeProjection.generate(Encode(a, Literal("\"quote")) :: Nil)
GenerateUnsafeProjection.generate(StringDecode(b, Literal("\"quote")) :: Nil)
GenerateUnsafeProjection.generate(Encode(a, Literal("\"quote")).replacement :: Nil)
GenerateUnsafeProjection.generate(StringDecode(b, Literal("\"quote")).replacement :: Nil)
}

test("initcap unit test") {
Expand Down
Loading