diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 113c5f866fd88..01f6c7e0331b0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -154,6 +154,12 @@ public static class Collation { */ public final boolean supportsLowercaseEquality; + /** + * Support for Space Trimming implies that that based on specifier (for now only right trim) + * leading, trailing or both spaces are removed from the input string before comparison. + */ + public final boolean supportsSpaceTrimming; + public Collation( String collationName, String provider, @@ -161,9 +167,11 @@ public Collation( Comparator comparator, String version, ToLongFunction hashFunction, + BiFunction equalsFunction, boolean supportsBinaryEquality, boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality) { + boolean supportsLowercaseEquality, + boolean supportsSpaceTrimming) { this.collationName = collationName; this.provider = provider; this.collator = collator; @@ -173,6 +181,8 @@ public Collation( this.supportsBinaryEquality = supportsBinaryEquality; this.supportsBinaryOrdering = supportsBinaryOrdering; this.supportsLowercaseEquality = supportsLowercaseEquality; + this.equalsFunction = equalsFunction; + this.supportsSpaceTrimming = supportsSpaceTrimming; // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality assert(!supportsBinaryOrdering || supportsBinaryEquality); @@ -180,12 +190,6 @@ public Collation( assert(!supportsBinaryEquality || !supportsLowercaseEquality); assert(SUPPORTED_PROVIDERS.contains(provider)); - - if (supportsBinaryEquality) { - this.equalsFunction = UTF8String::equals; - } else { - this.equalsFunction = (s1, s2) -> this.comparator.compare(s1, s2) == 0; - } } /** @@ -538,27 +542,63 @@ private static boolean isValidCollationId(int collationId) { @Override protected Collation buildCollation() { if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { + Comparator comparator; + ToLongFunction hashFunction; + BiFunction equalsFunction; + boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE; + + if (spaceTrimming == SpaceTrimming.NONE) { + comparator = UTF8String::binaryCompare; + hashFunction = s -> (long) s.hashCode(); + equalsFunction = UTF8String::equals; + } else { + comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare( + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode(); + equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals( + applyTrimmingPolicy(s2, spaceTrimming)); + } + return new Collation( normalizedCollationName(), PROVIDER_SPARK, null, - UTF8String::binaryCompare, + comparator, "1.0", - s -> (long) s.hashCode(), + hashFunction, + equalsFunction, /* supportsBinaryEquality = */ true, /* supportsBinaryOrdering = */ true, - /* supportsLowercaseEquality = */ false); + /* supportsLowercaseEquality = */ false, + spaceTrimming != SpaceTrimming.NONE); } else { + Comparator comparator; + ToLongFunction hashFunction; + + if (spaceTrimming == SpaceTrimming.NONE) { + comparator = CollationAwareUTF8String::compareLowerCase; + hashFunction = s -> + (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); + } else { + comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( + applyTrimmingPolicy(s1, spaceTrimming), + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints( + applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + } + return new Collation( normalizedCollationName(), PROVIDER_SPARK, null, - CollationAwareUTF8String::compareLowerCase, + comparator, "1.0", - s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(), + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ true); + /* supportsLowercaseEquality = */ true, + spaceTrimming != SpaceTrimming.NONE); } } @@ -917,16 +957,35 @@ protected Collation buildCollation() { Collator collator = Collator.getInstance(resultLocale); // Freeze ICU collator to ensure thread safety. collator.freeze(); + + Comparator comparator; + ToLongFunction hashFunction; + + if (spaceTrimming == SpaceTrimming.NONE) { + hashFunction = s -> (long) collator.getCollationKey( + s.toValidString()).hashCode(); + comparator = (s1, s2) -> + collator.compare(s1.toValidString(), s2.toValidString()); + } else { + comparator = (s1, s2) -> collator.compare( + applyTrimmingPolicy(s1, spaceTrimming).toValidString(), + applyTrimmingPolicy(s2, spaceTrimming).toValidString()); + hashFunction = s -> (long) collator.getCollationKey( + applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); + } + return new Collation( normalizedCollationName(), PROVIDER_ICU, collator, - (s1, s2) -> collator.compare(s1.toValidString(), s2.toValidString()), + comparator, ICU_COLLATOR_VERSION, - s -> (long) collator.getCollationKey(s.toValidString()).hashCode(), + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ false); + /* supportsLowercaseEquality = */ false, + spaceTrimming != SpaceTrimming.NONE); } @Override @@ -1103,14 +1162,6 @@ public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { Collation.CollationSpecICU.AccentSensitivity.AI; } - /** - * Returns whether the collation uses trim collation for the given collation id. - */ - public static boolean usesTrimCollation(int collationId) { - return Collation.CollationSpec.getSpaceTrimming(collationId) != - Collation.CollationSpec.SpaceTrimming.NONE; - } - public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( @@ -1137,7 +1188,7 @@ public static String[] getICULocaleNames() { public static UTF8String getCollationKey(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (usesTrimCollation(collationId)) { + if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } if (collation.supportsBinaryEquality) { @@ -1153,7 +1204,7 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { public static byte[] getCollationKeyBytes(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (usesTrimCollation(collationId)) { + if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } if (collation.supportsBinaryEquality) { diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 66ff551193101..a565d2d347636 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -127,6 +127,11 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", false), CollationTestCase("UTF8_BINARY", "aaa", "bbb", false), CollationTestCase("UTF8_BINARY", "å", "a\u030A", false), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " ", true), CollationTestCase("UTF8_LCASE", "aaa", "aaa", true), CollationTestCase("UTF8_LCASE", "aaa", "AAA", true), CollationTestCase("UTF8_LCASE", "aaa", "AaA", true), @@ -134,15 +139,30 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE", "aaa", "aa", false), CollationTestCase("UTF8_LCASE", "aaa", "bbb", false), CollationTestCase("UTF8_LCASE", "å", "a\u030A", false), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", true), CollationTestCase("UNICODE", "aaa", "aaa", true), CollationTestCase("UNICODE", "aaa", "AAA", false), CollationTestCase("UNICODE", "aaa", "bbb", false), CollationTestCase("UNICODE", "å", "a\u030A", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true), CollationTestCase("UNICODE_CI", "aaa", "aaa", true), CollationTestCase("UNICODE_CI", "aaa", "AAA", true), CollationTestCase("UNICODE_CI", "aaa", "bbb", false), CollationTestCase("UNICODE_CI", "å", "a\u030A", true), - CollationTestCase("UNICODE_CI", "Å", "a\u030A", true) + CollationTestCase("UNICODE_CI", "Å", "a\u030A", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true) ) checks.foreach(testCase => { @@ -162,19 +182,48 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", 1), CollationTestCase("UTF8_BINARY", "aaa", "bbb", -1), CollationTestCase("UTF8_BINARY", "aaa", "BBB", 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "BBB " , 1), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " " , 0), CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0), CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "aa", 1), CollationTestCase("UTF8_LCASE", "aaa", "bbb", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE", "aaa", "aaa", 0), CollationTestCase("UNICODE", "aaa", "AAA", -1), CollationTestCase("UNICODE", "aaa", "bbb", -1), CollationTestCase("UNICODE", "aaa", "BBB", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_RTRIM", "aaa", "BBB" , -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "BBB " , -1), + CollationTestCase("UNICODE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), - CollationTestCase("UNICODE_CI", "aaa", "bbb", -1)) + CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UNICODE_CI_RTRIM", " ", " ", 0) + ) checks.foreach(testCase => { val collation = fetchCollation(testCase.collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 29d48e3d1f47f..1c93c2ad550e9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -48,7 +48,7 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) private[sql] def usesTrimCollation: Boolean = - CollationFactory.usesTrimCollation(collationId) + CollationFactory.fetchCollation(collationId).supportsSpaceTrimming private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 3a667f370428e..7128190902550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -415,7 +415,7 @@ abstract class HashExpression[E] extends Expression { protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality) { + if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" @@ -566,7 +566,7 @@ abstract class InterpretedHashFunction { hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) case s: UTF8String => val st = dataType.asInstanceOf[StringType] - if (st.supportsBinaryEquality) { + if (st.supportsBinaryEquality && !st.usesTrimCollation) { hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) } else { val stringHash = CollationFactory @@ -817,7 +817,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality) { + if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index e296b5be6134b..40b8bccafaad2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -205,7 +205,9 @@ object UnsafeRowUtils { * can lead to rows being semantically equal even though their binary representations differ). */ def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively { - case st: StringType => !CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality + case st: StringType => + val collation = CollationFactory.fetchCollation(st.collationId) + (!collation.supportsBinaryEquality || collation.supportsSpaceTrimming) case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 45a71b4da7287..3b1f349520f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -173,8 +173,9 @@ abstract class HashMapGenerator( ${hashBytes(bytes)} """ } - case st: StringType if st.supportsBinaryEquality => hashBytes(s"$input.getBytes()") - case st: StringType if !st.supportsBinaryEquality => + case st: StringType if st.supportsBinaryEquality && !st.usesTrimCollation => + hashBytes(s"$input.getBytes()") + case st: StringType if !st.supportsBinaryEquality || st.usesTrimCollation => hashLong(s"CollationFactory.fetchCollation(${st.collationId})" + s".hashFunction.applyAsLong($input)") case CalendarIntervalType => hashInt(s"$input.hashCode()") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 4c3cd93873bd4..ce6818652d2b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -49,9 +49,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Md5TestCase("Spark", "UTF8_BINARY", "8cde774d6f7333752ed72cacddb05126"), + Md5TestCase("Spark", "UTF8_BINARY_RTRIM", "8cde774d6f7333752ed72cacddb05126"), Md5TestCase("Spark", "UTF8_LCASE", "8cde774d6f7333752ed72cacddb05126"), + Md5TestCase("Spark", "UTF8_LCASE_RTRIM", "8cde774d6f7333752ed72cacddb05126"), Md5TestCase("SQL", "UNICODE", "9778840a0100cb30c982876741b0b5a2"), - Md5TestCase("SQL", "UNICODE_CI", "9778840a0100cb30c982876741b0b5a2") + Md5TestCase("SQL", "UNICODE_RTRIM", "9778840a0100cb30c982876741b0b5a2"), + Md5TestCase("SQL", "UNICODE_CI", "9778840a0100cb30c982876741b0b5a2"), + Md5TestCase("SQL", "UNICODE_CI_RTRIM", "9778840a0100cb30c982876741b0b5a2") ) // Supported collations @@ -81,11 +85,19 @@ class CollationSQLExpressionsSuite val testCases = Seq( Sha2TestCase("Spark", "UTF8_BINARY", 256, "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), + Sha2TestCase("Spark", "UTF8_BINARY_RTRIM", 256, + "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), Sha2TestCase("Spark", "UTF8_LCASE", 256, "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), + Sha2TestCase("Spark", "UTF8_LCASE_RTRIM", 256, + "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), Sha2TestCase("SQL", "UNICODE", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), + Sha2TestCase("SQL", "UNICODE_RTRIM", 256, + "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), Sha2TestCase("SQL", "UNICODE_CI", 256, + "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), + Sha2TestCase("SQL", "UNICODE_CI_RTRIM", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35") ) @@ -114,9 +126,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Sha1TestCase("Spark", "UTF8_BINARY", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), + Sha1TestCase("Spark", "UTF8_BINARY_RTRIM", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), Sha1TestCase("Spark", "UTF8_LCASE", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), + Sha1TestCase("Spark", "UTF8_LCASE_RTRIM", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), Sha1TestCase("SQL", "UNICODE", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), - Sha1TestCase("SQL", "UNICODE_CI", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d") + Sha1TestCase("SQL", "UNICODE_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), + Sha1TestCase("SQL", "UNICODE_CI", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), + Sha1TestCase("SQL", "UNICODE_CI_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d") ) // Supported collations @@ -144,9 +160,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Crc321TestCase("Spark", "UTF8_BINARY", 1557323817), + Crc321TestCase("Spark", "UTF8_BINARY_RTRIM", 1557323817), Crc321TestCase("Spark", "UTF8_LCASE", 1557323817), + Crc321TestCase("Spark", "UTF8_LCASE_RTRIM", 1557323817), Crc321TestCase("SQL", "UNICODE", 1299261525), - Crc321TestCase("SQL", "UNICODE_CI", 1299261525) + Crc321TestCase("SQL", "UNICODE_RTRIM", 1299261525), + Crc321TestCase("SQL", "UNICODE_CI", 1299261525), + Crc321TestCase("SQL", "UNICODE_CI_RTRIM", 1299261525) ) // Supported collations @@ -172,9 +192,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Murmur3HashTestCase("Spark", "UTF8_BINARY", 228093765), + Murmur3HashTestCase("Spark ", "UTF8_BINARY_RTRIM", 1779328737), Murmur3HashTestCase("Spark", "UTF8_LCASE", -1928694360), + Murmur3HashTestCase("Spark ", "UTF8_LCASE_RTRIM", -1928694360), Murmur3HashTestCase("SQL", "UNICODE", -1923567940), - Murmur3HashTestCase("SQL", "UNICODE_CI", 1029527950) + Murmur3HashTestCase("SQL ", "UNICODE_RTRIM", -1923567940), + Murmur3HashTestCase("SQL", "UNICODE_CI", 1029527950), + Murmur3HashTestCase("SQL ", "UNICODE_CI_RTRIM", 1029527950) ) // Supported collations @@ -200,9 +224,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( XxHash64TestCase("Spark", "UTF8_BINARY", -4294468057691064905L), + XxHash64TestCase("Spark ", "UTF8_BINARY_RTRIM", 6480371823304753502L), XxHash64TestCase("Spark", "UTF8_LCASE", -3142112654825786434L), + XxHash64TestCase("Spark ", "UTF8_LCASE_RTRIM", -3142112654825786434L), XxHash64TestCase("SQL", "UNICODE", 5964849564945649886L), - XxHash64TestCase("SQL", "UNICODE_CI", 3732497619779520590L) + XxHash64TestCase("SQL ", "UNICODE_RTRIM", 5964849564945649886L), + XxHash64TestCase("SQL", "UNICODE_CI", 3732497619779520590L), + XxHash64TestCase("SQL ", "UNICODE_CI_RTRIM", 3732497619779520590L) ) // Supported collations diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index b19af542dabf2..4234d73c1794d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -101,8 +101,12 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("collate function syntax") { assert(sql(s"select collate('aaa', 'utf8_binary')").schema(0).dataType == StringType("UTF8_BINARY")) + assert(sql(s"select collate('aaa', 'utf8_binary_rtrim')").schema(0).dataType == + StringType("UTF8_BINARY_RTRIM")) assert(sql(s"select collate('aaa', 'utf8_lcase')").schema(0).dataType == StringType("UTF8_LCASE")) + assert(sql(s"select collate('aaa', 'utf8_lcase_rtrim')").schema(0).dataType == + StringType("UTF8_LCASE_RTRIM")) } test("collate function syntax with default collation set") { @@ -260,14 +264,23 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq( ("utf8_binary", "aaa", "AAA", false), ("utf8_binary", "aaa", "aaa", true), + ("utf8_binary_rtrim", "aaa", "AAA", false), + ("utf8_binary_rtrim", "aaa", "aaa ", true), ("utf8_lcase", "aaa", "aaa", true), ("utf8_lcase", "aaa", "AAA", true), ("utf8_lcase", "aaa", "bbb", false), + ("utf8_lcase_rtrim", "aaa", "AAA ", true), + ("utf8_lcase_rtrim", "aaa", "bbb", false), ("unicode", "aaa", "aaa", true), ("unicode", "aaa", "AAA", false), + ("unicode_rtrim", "aaa ", "aaa ", true), + ("unicode_rtrim", "aaa", "AAA", false), ("unicode_CI", "aaa", "aaa", true), ("unicode_CI", "aaa", "AAA", true), - ("unicode_CI", "aaa", "bbb", false) + ("unicode_CI", "aaa", "bbb", false), + ("unicode_CI_rtrim", "aaa", "aaa", true), + ("unicode_CI_rtrim", "aaa ", "AAA ", true), + ("unicode_CI_rtrim", "aaa", "bbb", false) ).foreach { case (collationName, left, right, expected) => checkAnswer( @@ -284,15 +297,19 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ("utf8_binary", "AAA", "aaa", true), ("utf8_binary", "aaa", "aaa", false), ("utf8_binary", "aaa", "BBB", false), + ("utf8_binary_rtrim", "aaa ", "aaa ", false), ("utf8_lcase", "aaa", "aaa", false), ("utf8_lcase", "AAA", "aaa", false), ("utf8_lcase", "aaa", "bbb", true), + ("utf8_lcase_rtrim", "AAA ", "aaa", false), ("unicode", "aaa", "aaa", false), ("unicode", "aaa", "AAA", true), ("unicode", "aaa", "BBB", true), + ("unicode_rtrim", "aaa ", "aaa", false), ("unicode_CI", "aaa", "aaa", false), ("unicode_CI", "aaa", "AAA", false), - ("unicode_CI", "aaa", "bbb", true) + ("unicode_CI", "aaa", "bbb", true), + ("unicode_CI_rtrim", "aaa ", "aaa", false) ).foreach { case (collationName, left, right, expected) => checkAnswer( @@ -355,18 +372,22 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("aggregates count respects collation") { Seq( + ("utf8_binary_rtrim", Seq("aaa", "aaa "), Seq(Row(2, "aaa"))), ("utf8_binary", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), ("utf8_binary", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("utf8_binary", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), ("utf8_lcase", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("utf8_lcase", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), ("utf8_lcase", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("utf8_lcase_rtrim", Seq("aaa", "AAA "), Seq(Row(2, "aaa"))), ("unicode", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), ("unicode", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("unicode", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode_rtrim", Seq("aaa", "aaa "), Seq(Row(2, "aaa"))), ("unicode_CI", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("unicode_CI", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), - ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))) + ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode_CI_rtrim", Seq("aaa", "AAA "), Seq(Row(2, "aaa"))) ).foreach { case (collationName: String, input: Seq[String], expected: Seq[Row]) => checkAnswer(sql(