Skip to content

[SPARK-48283][SQL] Modify string comparison for UTF8_BINARY_LCASE #46700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,54 @@ private static int lowercaseRFind(
return MATCH_NOT_FOUND;
}

/**
 * Case-insensitive comparison of two UTF8Strings, used for the UTF8_BINARY_LCASE collation.
 * Rather than lowercasing each string as a whole and comparing the results, the strings are
 * lowercased code point by code point (see: lowerCaseCodePoints), which applies ICU rules
 * together with special handling of one-to-many case mappings.
 *
 * @param left The first UTF8String to compare.
 * @param right The second UTF8String to compare.
 * @return A negative, zero, or positive integer when `left` sorts before, equal to, or after
 *         `right` respectively.
 */
public static int compareLowerCase(final UTF8String left, final UTF8String right) {
  // The byte-wise fast path (no string allocations) is only valid when BOTH inputs are ASCII.
  final boolean bothAscii = left.isFullAscii() && right.isFullAscii();
  return bothAscii
    ? compareLowerCaseAscii(left, right)
    : compareLowerCaseSlow(left, right);
}

/**
 * Fast path of the `compareLowerCase` method: lowercases and compares the two strings byte by
 * byte. The caller is expected to pass ASCII-only strings, so that every byte is one character.
 *
 * @param left The first ASCII UTF8String to compare.
 * @param right The second ASCII UTF8String to compare.
 * @return The difference of the first pair of lowercased bytes that differ, or the difference
 *         of the byte lengths when one string is a (case-insensitive) prefix of the other.
 */
private static int compareLowerCaseAscii(final UTF8String left, final UTF8String right) {
  final int leftLength = left.numBytes();
  final int rightLength = right.numBytes();
  final int sharedLength = Math.min(leftLength, rightLength);
  for (int pos = 0; pos < sharedLength; pos++) {
    final int difference =
      Character.toLowerCase(left.getByte(pos)) - Character.toLowerCase(right.getByte(pos));
    if (difference != 0) {
      return difference;
    }
  }
  // The common prefix is equal ignoring case, so the shorter string sorts first.
  return leftLength - rightLength;
}

/**
 * Slow path of the `compareLowerCase` method, taken whenever at least one of the arguments is
 * not a full-ASCII string. Both inputs are converted to Java Strings, lowercased code point by
 * code point (see: lowerCaseCodePoints), and the results are ordered with String.compareTo.
 *
 * @param left The first UTF8String to compare.
 * @param right The second UTF8String to compare.
 * @return An integer representing the comparison result.
 */
private static int compareLowerCaseSlow(final UTF8String left, final UTF8String right) {
  final String loweredLeft = lowerCaseCodePoints(left.toString());
  final String loweredRight = lowerCaseCodePoints(right.toString());
  return loweredLeft.compareTo(loweredRight);
}

public static UTF8String replace(final UTF8String src, final UTF8String search,
final UTF8String replace, final int collationId) {
// This collation aware implementation is based on existing implementation on UTF8String
Expand Down Expand Up @@ -296,6 +344,48 @@ public static String toLowerCase(final String target, final int collationId) {
return UCharacter.toLowerCase(locale, target);
}

/**
 * Appends the lowercase form of a single code point to the given StringBuilder. ICU rules are
 * used, with special handling for one-to-many case mappings (i.e. characters that map to
 * multiple characters in lowercase) and for context-sensitive case mappings (i.e. characters
 * that map to different characters based on string context - e.g. the position in the string
 * relative to other characters).
 *
 * @param codePoint The code point to convert to lowercase.
 * @param sb The StringBuilder to append the lowercase character(s) to.
 */
private static void lowercaseCodePoint(final int codePoint, final StringBuilder sb) {
  switch (codePoint) {
    case 0x0130:
      // Latin capital letter I with dot above lowercases to 2 code points:
      // Latin small letter i (U+0069) followed by combining dot above (U+0307).
      sb.appendCodePoint(0x0069);
      sb.appendCodePoint(0x0307);
      break;
    case 0x03C2:
      // Greek small letter final sigma is folded into the non-final small sigma (U+03C3), so
      // that both forms map to the same character regardless of position in the string.
      sb.appendCodePoint(0x03C3);
      break;
    default:
      // Every other character follows the context-unaware ICU single-code point case mapping.
      sb.appendCodePoint(UCharacter.toLowerCase(codePoint));
      break;
  }
}

/**
 * Converts an entire string to lowercase using ICU rules, code point by code point, with
 * special handling for one-to-many case mappings (i.e. characters that map to multiple
 * characters in lowercase). Context-sensitive case mappings are deliberately ignored; that
 * special handling lives in the `lowercaseCodePoint` method.
 *
 * The string is traversed by code point, not by UTF-16 code unit: a supplementary character
 * (encoded as a surrogate pair) is lowercased exactly once, and its low surrogate is never
 * visited on its own.
 *
 * @param target The target string to convert to lowercase.
 * @return The string converted to lowercase in a context-unaware manner.
 */
public static String lowerCaseCodePoints(final String target) {
  final int length = target.length();
  final StringBuilder sb = new StringBuilder(length);
  int i = 0;
  while (i < length) {
    final int codePoint = target.codePointAt(i);
    lowercaseCodePoint(codePoint, sb);
    // Advance by 2 chars for a surrogate pair and by 1 otherwise. Stepping by a single char
    // would re-read the trailing low surrogate and append it a second time.
    i += Character.charCount(codePoint);
  }
  return sb.toString();
}

public static String toTitleCase(final String target, final int collationId) {
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,9 @@ protected Collation buildCollation() {
"UTF8_BINARY_LCASE",
PROVIDER_SPARK,
null,
UTF8String::compareLowerCase,
CollationAwareUTF8String::compareLowerCase,
"1.0",
s -> (long) s.toLowerCase().hashCode(),
s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s.toString()).hashCode(),
/* supportsBinaryEquality = */ false,
/* supportsBinaryOrdering = */ false,
/* supportsLowercaseEquality = */ true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,34 +388,6 @@ private UTF8String toUpperCaseSlow() {
return fromString(toString().toUpperCase());
}

/**
 * Optimized lowercase comparison for UTF8_BINARY_LCASE collation.
 * a.compareLowerCase(b) is equivalent to a.toLowerCase().binaryCompare(b.toLowerCase()).
 *
 * Bytes are lowercased and compared one at a time for as long as both strings hold
 * single-byte (non-negative) values; at the first negative byte on either side, the
 * comparison of the remaining suffixes is delegated to `compareLowerCaseSuffixSlow`.
 */
public int compareLowerCase(UTF8String other) {
  final int sharedLength = Math.min(numBytes, other.numBytes);
  for (int pos = 0; pos < sharedLength; pos++) {
    final byte thisByte = getByte(pos);
    if (thisByte < 0) {
      return compareLowerCaseSuffixSlow(other, pos);
    }
    final byte otherByte = other.getByte(pos);
    if (otherByte < 0) {
      return compareLowerCaseSuffixSlow(other, pos);
    }
    final int difference = Character.toLowerCase(thisByte) - Character.toLowerCase(otherByte);
    if (difference != 0) {
      return difference;
    }
  }
  // One string is a case-insensitive prefix of the other: the shorter one sorts first.
  return numBytes - other.numBytes;
}

/**
 * Compares the suffixes of this string and `other` that start at byte offset `pref`, by
 * lowercasing both suffixes with `toLowerCaseSlow` and binary-comparing the results.
 *
 * @param other The string to compare against.
 * @param pref The number of leading bytes to leave out of both strings.
 * @return The result of the binary comparison of the two lowercased suffixes.
 */
private int compareLowerCaseSuffixSlow(UTF8String other, int pref) {
  final UTF8String thisTail =
    UTF8String.fromAddress(base, offset + pref, numBytes - pref);
  final UTF8String otherTail =
    UTF8String.fromAddress(other.base, other.offset + pref, other.numBytes - pref);
  final UTF8String thisLowered = thisTail.toLowerCaseSlow();
  final UTF8String otherLowered = otherTail.toLowerCaseSlow();
  return thisLowered.binaryCompare(otherLowered);
}

/**
* Returns the lower case of this string
*/
Expand All @@ -427,7 +399,7 @@ public UTF8String toLowerCase() {
return isFullAscii() ? toLowerCaseAscii() : toLowerCaseSlow();
}

private boolean isFullAscii() {
public boolean isFullAscii() {
for (var i = 0; i < numBytes; i++) {
if (getByte(i) < 0) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.spark.unsafe.types;

import org.apache.spark.SparkException;
import org.apache.spark.sql.catalyst.util.CollationAwareUTF8String;
import org.apache.spark.sql.catalyst.util.CollationFactory;
import org.apache.spark.sql.catalyst.util.CollationSupport;
import org.junit.jupiter.api.Test;
Expand All @@ -26,6 +27,156 @@
// checkstyle.off: AvoidEscapedUnicodeCharacters
public class CollationSupportSuite {

/**
* A representative subset of the collations supported in Spark, with entries for each of the
* important collation groups (binary, lowercase, icu). Iterate over this list in tests whose
* expected result is the same regardless of the specified collation (as often seen in some
* pass-through Spark expressions).
*/
private final String[] testSupportedCollations =
{"UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI"};

/**
* Collation-aware UTF8String comparison.
*/

/**
 * Asserts that comparing `s1` with `s2` under the given collation gives a result with the
 * same sign as `expected`. Only the sign is checked, never the magnitude.
 *
 * @param s1 The left-hand string.
 * @param s2 The right-hand string.
 * @param collationName The name of the collation whose comparator is used.
 * @param expected Any integer with the expected sign (negative, zero or positive).
 * @throws SparkException Declared because the collation lookup may throw it.
 */
private void assertStringCompare(String s1, String s2, String collationName, int expected)
    throws SparkException {
  final UTF8String lhs = UTF8String.fromString(s1);
  final UTF8String rhs = UTF8String.fromString(s2);
  final int actual = CollationFactory.fetchCollation(collationName).comparator.compare(lhs, rhs);
  final int expectedSign = Integer.signum(expected);
  final int actualSign = Integer.signum(actual);
  assertEquals(expectedSign, actualSign);
}

/**
* Exercises the comparator of each tested collation through `assertStringCompare`, which
* checks only the sign of the comparison result. The first part runs cases whose expected
* result is the same for every collation in `testSupportedCollations`; the second part pins
* results that differ between collations.
*/
@Test
public void testCompare() throws SparkException {
// Part 1: cases expected to give the same result under every tested collation.
for (String collationName: testSupportedCollations) {
// Edge cases
assertStringCompare("", "", collationName, 0);
assertStringCompare("a", "", collationName, 1);
assertStringCompare("", "a", collationName, -1);
// Basic tests
assertStringCompare("a", "a", collationName, 0);
assertStringCompare("a", "b", collationName, -1);
assertStringCompare("b", "a", collationName, 1);
assertStringCompare("A", "A", collationName, 0);
assertStringCompare("A", "B", collationName, -1);
assertStringCompare("B", "A", collationName, 1);
assertStringCompare("aa", "a", collationName, 1);
assertStringCompare("b", "bb", collationName, -1);
assertStringCompare("abc", "a", collationName, 1);
assertStringCompare("abc", "b", collationName, -1);
assertStringCompare("abc", "ab", collationName, 1);
assertStringCompare("abc", "abc", collationName, 0);
// ASCII strings
assertStringCompare("aaaa", "aaa", collationName, 1);
assertStringCompare("hello", "world", collationName, -1);
assertStringCompare("Spark", "Spark", collationName, 0);
// Non-ASCII strings
assertStringCompare("ü", "ü", collationName, 0);
assertStringCompare("ü", "", collationName, 1);
assertStringCompare("", "ü", collationName, -1);
assertStringCompare("äü", "äü", collationName, 0);
assertStringCompare("äxx", "äx", collationName, 1);
assertStringCompare("a", "ä", collationName, -1);
}
// Part 2: collation-specific expectations.
// Non-ASCII strings (the expected sign differs between the UTF8_* and UNICODE* collations)
assertStringCompare("äü", "bü", "UTF8_BINARY", 1);
assertStringCompare("bxx", "bü", "UTF8_BINARY", -1);
assertStringCompare("äü", "bü", "UTF8_BINARY_LCASE", 1);
assertStringCompare("bxx", "bü", "UTF8_BINARY_LCASE", -1);
assertStringCompare("äü", "bü", "UNICODE", -1);
assertStringCompare("bxx", "bü", "UNICODE", 1);
assertStringCompare("äü", "bü", "UNICODE_CI", -1);
assertStringCompare("bxx", "bü", "UNICODE_CI", 1);
// Case variation
assertStringCompare("AbCd", "aBcD", "UTF8_BINARY", -1);
assertStringCompare("ABCD", "abcd", "UTF8_BINARY_LCASE", 0);
assertStringCompare("AbcD", "aBCd", "UNICODE", 1);
assertStringCompare("abcd", "ABCD", "UNICODE_CI", 0);
// Accent variation
assertStringCompare("aBćD", "ABĆD", "UTF8_BINARY", 1);
assertStringCompare("AbCδ", "ABCΔ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("äBCd", "ÄBCD", "UNICODE", -1);
assertStringCompare("Ab́cD", "AB́CD", "UNICODE_CI", 0);
// Case-variable character length: "İ" versus the two code point sequence "i\u0307".
// The case-insensitive collations are expected to treat them as equal.
assertStringCompare("i\u0307", "İ", "UTF8_BINARY", -1);
assertStringCompare("İ", "i\u0307", "UTF8_BINARY", 1);
assertStringCompare("i\u0307", "İ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("İ", "i\u0307", "UTF8_BINARY_LCASE", 0);
assertStringCompare("i\u0307", "İ", "UNICODE", -1);
assertStringCompare("İ", "i\u0307", "UNICODE", 1);
assertStringCompare("i\u0307", "İ", "UNICODE_CI", 0);
assertStringCompare("İ", "i\u0307", "UNICODE_CI", 0);
assertStringCompare("i\u0307İ", "i\u0307İ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("i\u0307İ", "İi\u0307", "UTF8_BINARY_LCASE", 0);
assertStringCompare("İi\u0307", "i\u0307İ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("İi\u0307", "İi\u0307", "UTF8_BINARY_LCASE", 0);
assertStringCompare("i\u0307İ", "i\u0307İ", "UNICODE_CI", 0);
assertStringCompare("i\u0307İ", "İi\u0307", "UNICODE_CI", 0);
assertStringCompare("İi\u0307", "i\u0307İ", "UNICODE_CI", 0);
assertStringCompare("İi\u0307", "İi\u0307", "UNICODE_CI", 0);
// Conditional case mapping: the sigma variants "ς", "σ" and "Σ" are expected to compare as
// equal under the case-insensitive collations and as distinct under the others.
assertStringCompare("ς", "σ", "UTF8_BINARY", -1);
assertStringCompare("ς", "Σ", "UTF8_BINARY", 1);
assertStringCompare("σ", "Σ", "UTF8_BINARY", 1);
assertStringCompare("ς", "σ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("ς", "Σ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("σ", "Σ", "UTF8_BINARY_LCASE", 0);
assertStringCompare("ς", "σ", "UNICODE", 1);
assertStringCompare("ς", "Σ", "UNICODE", 1);
assertStringCompare("σ", "Σ", "UNICODE", -1);
assertStringCompare("ς", "σ", "UNICODE_CI", 0);
assertStringCompare("ς", "Σ", "UNICODE_CI", 0);
assertStringCompare("σ", "Σ", "UNICODE_CI", 0);
}

/**
 * Asserts that lowercasing `target` gives `expected`, using one of two implementations.
 *
 * @param target The string to convert to lowercase.
 * @param expected The expected lowercase result.
 * @param useCodePoints When true, `CollationAwareUTF8String.lowerCaseCodePoints` is checked
 *                      (results compared as Java Strings); when false, `UTF8String.toLowerCase`
 *                      is checked (results compared as UTF8Strings).
 */
private void assertLowerCaseCodePoints(UTF8String target, UTF8String expected,
    Boolean useCodePoints) {
  if (!useCodePoints) {
    assertEquals(expected, target.toLowerCase());
    return;
  }
  final String expectedText = expected.toString();
  final String actualText = CollationAwareUTF8String.lowerCaseCodePoints(target.toString());
  assertEquals(expectedText, actualText);
}

/**
* Checks lowercase conversion through `assertLowerCaseCodePoints`. The last argument of each
* call selects the implementation under test: `true` checks
* `CollationAwareUTF8String.lowerCaseCodePoints`, `false` checks `UTF8String.toLowerCase`.
*/
@Test
public void testLowerCaseCodePoints() {
// Edge cases
assertLowerCaseCodePoints(UTF8String.fromString(""), UTF8String.fromString(""), false);
assertLowerCaseCodePoints(UTF8String.fromString(""), UTF8String.fromString(""), true);
// Basic tests
assertLowerCaseCodePoints(UTF8String.fromString("abcd"), UTF8String.fromString("abcd"), false);
assertLowerCaseCodePoints(UTF8String.fromString("AbCd"), UTF8String.fromString("abcd"), false);
assertLowerCaseCodePoints(UTF8String.fromString("abcd"), UTF8String.fromString("abcd"), true);
assertLowerCaseCodePoints(UTF8String.fromString("aBcD"), UTF8String.fromString("abcd"), true);
// Accent variation
assertLowerCaseCodePoints(UTF8String.fromString("AbĆd"), UTF8String.fromString("abćd"), false);
assertLowerCaseCodePoints(UTF8String.fromString("aBcΔ"), UTF8String.fromString("abcδ"), true);
// Case-variable character length: "İ" is expected to lowercase to "i" followed by a
// combining dot above, and an already-lowercase "i" + combining dot is expected to be kept.
assertLowerCaseCodePoints(
UTF8String.fromString("İoDiNe"), UTF8String.fromString("i̇odine"), false);
assertLowerCaseCodePoints(
UTF8String.fromString("Abi̇o12"), UTF8String.fromString("abi̇o12"), false);
assertLowerCaseCodePoints(
UTF8String.fromString("İodInE"), UTF8String.fromString("i̇odine"), true);
assertLowerCaseCodePoints(
UTF8String.fromString("aBi̇o12"), UTF8String.fromString("abi̇o12"), true);
// Conditional case mapping: for a trailing capital sigma, `toLowerCase` is expected to
// produce the final form "ς", while `lowerCaseCodePoints` is expected to produce "σ".
assertLowerCaseCodePoints(
UTF8String.fromString("ΘΑΛΑΣΣΙΝΟΣ"), UTF8String.fromString("θαλασσινος"), false);
assertLowerCaseCodePoints(
UTF8String.fromString("ΘΑΛΑΣΣΙΝΟΣ"), UTF8String.fromString("θαλασσινοσ"), true);
// Surrogate pairs are treated as invalid UTF8 sequences
// (both implementations are expected to yield two U+FFFD replacement characters here).
assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[]
{(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}),
UTF8String.fromString("\ufffd\ufffd"), false);
assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[]
{(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}),
UTF8String.fromString("\ufffd\ufffd"), true);
}

/**
* Collation-aware string expressions.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,29 +107,6 @@ public void binaryCompareTo() {
assertTrue(fromString("你好123").binaryCompare(fromString("你好122")) > 0);
}

/**
 * SPARK-47693: Tests the optimized lowercase comparison of UTF8String instances
 * (`compareLowerCase`) on ASCII and non-ASCII inputs.
 */
@Test
public void lowercaseComparison() {
  // Equality checks pass the expected value first, as `assertEquals(expected, actual)`
  // requires, so that a failure message reports the two values the right way round.
  // ASCII
  assertEquals(0, fromString("aaa").compareLowerCase(fromString("AAA")));
  assertTrue(fromString("aaa").compareLowerCase(fromString("AAAA")) < 0);
  assertTrue(fromString("AAA").compareLowerCase(fromString("aaaa")) < 0);
  assertTrue(fromString("a").compareLowerCase(fromString("B")) < 0);
  assertTrue(fromString("b").compareLowerCase(fromString("A")) > 0);
  assertEquals(0, fromString("aAa").compareLowerCase(fromString("AaA")));
  assertTrue(fromString("abcd").compareLowerCase(fromString("abC")) > 0);
  assertTrue(fromString("ABC").compareLowerCase(fromString("abcd")) < 0);
  assertEquals(0, fromString("abcd").compareLowerCase(fromString("abcd")));
  // non-ASCII
  assertEquals(0, fromString("ü").compareLowerCase(fromString("Ü")));
  assertEquals(0, fromString("Äü").compareLowerCase(fromString("äÜ")));
  assertTrue(fromString("a").compareLowerCase(fromString("ä")) < 0);
  assertTrue(fromString("a").compareLowerCase(fromString("Ä")) < 0);
  assertTrue(fromString("A").compareLowerCase(fromString("ä")) < 0);
  assertTrue(fromString("bä").compareLowerCase(fromString("aü")) > 0);
  assertTrue(fromString("bxxxxxxxxxx").compareLowerCase(fromString("bü")) < 0);
}

protected static void testUpperandLower(String upper, String lower) {
UTF8String us = fromString(upper);
UTF8String ls = fromString(lower);
Expand Down