Skip to content

Commit

Permalink
[SPARK-49661][SQL] Implement trim collation hashing and comparison
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Implement support for hashing and comparison for trim collation.

### Why are the changes needed?
To have full support for trim collation.

### How was this patch tested?
Add tests in CollationFactorySUite and CollationSqlExpressionSuite.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48386 from jovanpavl-db/implement_hashing.

Authored-by: Jovan Pavlovic <jovan.pavlovic@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
  • Loading branch information
jovanpavl-db authored and MaxGekk committed Oct 14, 2024
1 parent 96c4953 commit 74aed77
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,24 @@ public static class Collation {
*/
public final boolean supportsLowercaseEquality;

/**
* Support for Space Trimming implies that that based on specifier (for now only right trim)
* leading, trailing or both spaces are removed from the input string before comparison.
*/
public final boolean supportsSpaceTrimming;

public Collation(
String collationName,
String provider,
Collator collator,
Comparator<UTF8String> comparator,
String version,
ToLongFunction<UTF8String> hashFunction,
BiFunction<UTF8String, UTF8String, Boolean> equalsFunction,
boolean supportsBinaryEquality,
boolean supportsBinaryOrdering,
boolean supportsLowercaseEquality) {
boolean supportsLowercaseEquality,
boolean supportsSpaceTrimming) {
this.collationName = collationName;
this.provider = provider;
this.collator = collator;
Expand All @@ -173,19 +181,15 @@ public Collation(
this.supportsBinaryEquality = supportsBinaryEquality;
this.supportsBinaryOrdering = supportsBinaryOrdering;
this.supportsLowercaseEquality = supportsLowercaseEquality;
this.equalsFunction = equalsFunction;
this.supportsSpaceTrimming = supportsSpaceTrimming;

// De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality
assert(!supportsBinaryOrdering || supportsBinaryEquality);
// No Collation can simultaneously support binary equality and lowercase equality
assert(!supportsBinaryEquality || !supportsLowercaseEquality);

assert(SUPPORTED_PROVIDERS.contains(provider));

if (supportsBinaryEquality) {
this.equalsFunction = UTF8String::equals;
} else {
this.equalsFunction = (s1, s2) -> this.comparator.compare(s1, s2) == 0;
}
}

/**
Expand Down Expand Up @@ -538,27 +542,63 @@ private static boolean isValidCollationId(int collationId) {
@Override
protected Collation buildCollation() {
if (caseSensitivity == CaseSensitivity.UNSPECIFIED) {
Comparator<UTF8String> comparator;
ToLongFunction<UTF8String> hashFunction;
BiFunction<UTF8String, UTF8String, Boolean> equalsFunction;
boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE;

if (spaceTrimming == SpaceTrimming.NONE) {
comparator = UTF8String::binaryCompare;
hashFunction = s -> (long) s.hashCode();
equalsFunction = UTF8String::equals;
} else {
comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare(
applyTrimmingPolicy(s2, spaceTrimming));
hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode();
equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals(
applyTrimmingPolicy(s2, spaceTrimming));
}

return new Collation(
normalizedCollationName(),
PROVIDER_SPARK,
null,
UTF8String::binaryCompare,
comparator,
"1.0",
s -> (long) s.hashCode(),
hashFunction,
equalsFunction,
/* supportsBinaryEquality = */ true,
/* supportsBinaryOrdering = */ true,
/* supportsLowercaseEquality = */ false);
/* supportsLowercaseEquality = */ false,
spaceTrimming != SpaceTrimming.NONE);
} else {
Comparator<UTF8String> comparator;
ToLongFunction<UTF8String> hashFunction;

if (spaceTrimming == SpaceTrimming.NONE) {
comparator = CollationAwareUTF8String::compareLowerCase;
hashFunction = s ->
(long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode();
} else {
comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase(
applyTrimmingPolicy(s1, spaceTrimming),
applyTrimmingPolicy(s2, spaceTrimming));
hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(
applyTrimmingPolicy(s, spaceTrimming)).hashCode();
}

return new Collation(
normalizedCollationName(),
PROVIDER_SPARK,
null,
CollationAwareUTF8String::compareLowerCase,
comparator,
"1.0",
s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(),
hashFunction,
(s1, s2) -> comparator.compare(s1, s2) == 0,
/* supportsBinaryEquality = */ false,
/* supportsBinaryOrdering = */ false,
/* supportsLowercaseEquality = */ true);
/* supportsLowercaseEquality = */ true,
spaceTrimming != SpaceTrimming.NONE);
}
}

Expand Down Expand Up @@ -917,16 +957,35 @@ protected Collation buildCollation() {
Collator collator = Collator.getInstance(resultLocale);
// Freeze ICU collator to ensure thread safety.
collator.freeze();

Comparator<UTF8String> comparator;
ToLongFunction<UTF8String> hashFunction;

if (spaceTrimming == SpaceTrimming.NONE) {
hashFunction = s -> (long) collator.getCollationKey(
s.toValidString()).hashCode();
comparator = (s1, s2) ->
collator.compare(s1.toValidString(), s2.toValidString());
} else {
comparator = (s1, s2) -> collator.compare(
applyTrimmingPolicy(s1, spaceTrimming).toValidString(),
applyTrimmingPolicy(s2, spaceTrimming).toValidString());
hashFunction = s -> (long) collator.getCollationKey(
applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode();
}

return new Collation(
normalizedCollationName(),
PROVIDER_ICU,
collator,
(s1, s2) -> collator.compare(s1.toValidString(), s2.toValidString()),
comparator,
ICU_COLLATOR_VERSION,
s -> (long) collator.getCollationKey(s.toValidString()).hashCode(),
hashFunction,
(s1, s2) -> comparator.compare(s1, s2) == 0,
/* supportsBinaryEquality = */ false,
/* supportsBinaryOrdering = */ false,
/* supportsLowercaseEquality = */ false);
/* supportsLowercaseEquality = */ false,
spaceTrimming != SpaceTrimming.NONE);
}

@Override
Expand Down Expand Up @@ -1103,14 +1162,6 @@ public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) {
Collation.CollationSpecICU.AccentSensitivity.AI;
}

/**
* Returns whether the collation uses trim collation for the given collation id.
*/
public static boolean usesTrimCollation(int collationId) {
return Collation.CollationSpec.getSpaceTrimming(collationId) !=
Collation.CollationSpec.SpaceTrimming.NONE;
}

public static void assertValidProvider(String provider) throws SparkException {
if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) {
Map<String, String> params = Map.of(
Expand All @@ -1137,7 +1188,7 @@ public static String[] getICULocaleNames() {

public static UTF8String getCollationKey(UTF8String input, int collationId) {
Collation collation = fetchCollation(collationId);
if (usesTrimCollation(collationId)) {
if (collation.supportsSpaceTrimming) {
input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId);
}
if (collation.supportsBinaryEquality) {
Expand All @@ -1153,7 +1204,7 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) {

public static byte[] getCollationKeyBytes(UTF8String input, int collationId) {
Collation collation = fetchCollation(collationId);
if (usesTrimCollation(collationId)) {
if (collation.supportsSpaceTrimming) {
input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId);
}
if (collation.supportsBinaryEquality) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,22 +127,42 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig
CollationTestCase("UTF8_BINARY", "aaa", "AAA", false),
CollationTestCase("UTF8_BINARY", "aaa", "bbb", false),
CollationTestCase("UTF8_BINARY", "å", "a\u030A", false),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa", true),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa ", true),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", true),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa", " aaa ", false),
CollationTestCase("UTF8_BINARY_RTRIM", " ", " ", true),
CollationTestCase("UTF8_LCASE", "aaa", "aaa", true),
CollationTestCase("UTF8_LCASE", "aaa", "AAA", true),
CollationTestCase("UTF8_LCASE", "aaa", "AaA", true),
CollationTestCase("UTF8_LCASE", "aaa", "AaA", true),
CollationTestCase("UTF8_LCASE", "aaa", "aa", false),
CollationTestCase("UTF8_LCASE", "aaa", "bbb", false),
CollationTestCase("UTF8_LCASE", "å", "a\u030A", false),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA", true),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA ", true),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA ", true),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa", " AaA ", false),
CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", true),
CollationTestCase("UNICODE", "aaa", "aaa", true),
CollationTestCase("UNICODE", "aaa", "AAA", false),
CollationTestCase("UNICODE", "aaa", "bbb", false),
CollationTestCase("UNICODE", "å", "a\u030A", true),
CollationTestCase("UNICODE_RTRIM", "aaa", "aaa", true),
CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true),
CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", true),
CollationTestCase("UNICODE_RTRIM", "aaa", " aaa ", false),
CollationTestCase("UNICODE_RTRIM", " ", " ", true),
CollationTestCase("UNICODE_CI", "aaa", "aaa", true),
CollationTestCase("UNICODE_CI", "aaa", "AAA", true),
CollationTestCase("UNICODE_CI", "aaa", "bbb", false),
CollationTestCase("UNICODE_CI", "å", "a\u030A", true),
CollationTestCase("UNICODE_CI", "Å", "a\u030A", true)
CollationTestCase("UNICODE_CI", "Å", "a\u030A", true),
CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA", true),
CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true),
CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true),
CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false),
CollationTestCase("UNICODE_RTRIM", " ", " ", true)
)

checks.foreach(testCase => {
Expand All @@ -162,19 +182,48 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig
CollationTestCase("UTF8_BINARY", "aaa", "AAA", 1),
CollationTestCase("UTF8_BINARY", "aaa", "bbb", -1),
CollationTestCase("UTF8_BINARY", "aaa", "BBB", 1),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa", 0),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", 0),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb", -1),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb ", -1),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1),
CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "BBB " , 1),
CollationTestCase("UTF8_BINARY_RTRIM", " ", " " , 0),
CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0),
CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0),
CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0),
CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0),
CollationTestCase("UTF8_LCASE", "aaa", "aa", 1),
CollationTestCase("UTF8_LCASE", "aaa", "bbb", -1),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA", 0),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA ", 0),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "bbb ", -1),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "bbb ", -1),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1),
CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa ", 1),
CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", 0),
CollationTestCase("UNICODE", "aaa", "aaa", 0),
CollationTestCase("UNICODE", "aaa", "AAA", -1),
CollationTestCase("UNICODE", "aaa", "bbb", -1),
CollationTestCase("UNICODE", "aaa", "BBB", -1),
CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa", 0),
CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0),
CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb", -1),
CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1),
CollationTestCase("UNICODE_RTRIM", "aaa", "BBB" , -1),
CollationTestCase("UNICODE_RTRIM", "aaa ", "BBB " , -1),
CollationTestCase("UNICODE_RTRIM", " ", " ", 0),
CollationTestCase("UNICODE_CI", "aaa", "aaa", 0),
CollationTestCase("UNICODE_CI", "aaa", "AAA", 0),
CollationTestCase("UNICODE_CI", "aaa", "bbb", -1))
CollationTestCase("UNICODE_CI", "aaa", "bbb", -1),
CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0),
CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA ", 0),
CollationTestCase("UNICODE_CI_RTRIM", "aaa", "bbb ", -1),
CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1),
CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1),
CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1),
CollationTestCase("UNICODE_CI_RTRIM", " ", " ", 0)
)

checks.foreach(testCase => {
val collation = fetchCollation(testCase.collationName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ
!CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId)

private[sql] def usesTrimCollation: Boolean =
CollationFactory.usesTrimCollation(collationId)
CollationFactory.fetchCollation(collationId).supportsSpaceTrimming

private[sql] def isUTF8BinaryCollation: Boolean =
collationId == CollationFactory.UTF8_BINARY_COLLATION_ID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ abstract class HashExpression[E] extends Expression {

protected def genHashString(
ctx: CodegenContext, stringType: StringType, input: String, result: String): String = {
if (stringType.supportsBinaryEquality) {
if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) {
val baseObject = s"$input.getBaseObject()"
val baseOffset = s"$input.getBaseOffset()"
val numBytes = s"$input.numBytes()"
Expand Down Expand Up @@ -566,7 +566,7 @@ abstract class InterpretedHashFunction {
hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
case s: UTF8String =>
val st = dataType.asInstanceOf[StringType]
if (st.supportsBinaryEquality) {
if (st.supportsBinaryEquality && !st.usesTrimCollation) {
hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
} else {
val stringHash = CollationFactory
Expand Down Expand Up @@ -817,7 +817,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {

override protected def genHashString(
ctx: CodegenContext, stringType: StringType, input: String, result: String): String = {
if (stringType.supportsBinaryEquality) {
if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) {
val baseObject = s"$input.getBaseObject()"
val baseOffset = s"$input.getBaseOffset()"
val numBytes = s"$input.numBytes()"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ object UnsafeRowUtils {
* can lead to rows being semantically equal even though their binary representations differ).
*/
def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively {
case st: StringType => !CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality
case st: StringType =>
val collation = CollationFactory.fetchCollation(st.collationId)
(!collation.supportsBinaryEquality || collation.supportsSpaceTrimming)
case _ => false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ abstract class HashMapGenerator(
${hashBytes(bytes)}
"""
}
case st: StringType if st.supportsBinaryEquality => hashBytes(s"$input.getBytes()")
case st: StringType if !st.supportsBinaryEquality =>
case st: StringType if st.supportsBinaryEquality && !st.usesTrimCollation =>
hashBytes(s"$input.getBytes()")
case st: StringType if !st.supportsBinaryEquality || st.usesTrimCollation =>
hashLong(s"CollationFactory.fetchCollation(${st.collationId})" +
s".hashFunction.applyAsLong($input)")
case CalendarIntervalType => hashInt(s"$input.hashCode()")
Expand Down
Loading

0 comments on commit 74aed77

Please sign in to comment.