Skip to content

[SPARK-48440][SQL] Fix StringTranslate behaviour for non-UTF8_BINARY collations #46761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import com.ibm.icu.lang.UCharacter;
import com.ibm.icu.text.BreakIterator;
import com.ibm.icu.text.Collator;
import com.ibm.icu.text.RuleBasedCollator;
import com.ibm.icu.text.StringSearch;
import com.ibm.icu.util.ULocale;

Expand All @@ -26,8 +28,12 @@

import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
import static org.apache.spark.unsafe.Platform.copyMemory;
import static org.apache.spark.unsafe.types.UTF8String.CodePointIteratorType;

import java.text.CharacterIterator;
import java.text.StringCharacterIterator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/**
Expand Down Expand Up @@ -424,27 +430,58 @@ private static UTF8String toLowerCaseSlow(final UTF8String target, final int col
* @param codePoint The code point to convert to lowercase.
* @param sb The StringBuilder to append the lowercase character to.
*/
private static void lowercaseCodePoint(final int codePoint, final StringBuilder sb) {
if (codePoint == 0x0130) {
private static void appendLowercaseCodePoint(final int codePoint, final StringBuilder sb) {
  // `getLowercaseCodePoint` returns either a real code point or the artificial marker
  // `CODE_POINT_COMBINED_LOWERCASE_I_DOT`, which stands for a two-code-point result.
  final int lowercaseCodePoint = getLowercaseCodePoint(codePoint);
  if (lowercaseCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) {
    // Latin capital letter I with dot above is mapped to 2 lowercase characters. The marker
    // itself must never be appended, so expand it into its two components here.
    sb.appendCodePoint(CODE_POINT_LOWERCASE_I);
    sb.appendCodePoint(CODE_POINT_COMBINING_DOT);
  } else {
    // All other characters should follow context-unaware ICU single-code point case mapping.
    sb.appendCodePoint(lowercaseCodePoint);
  }
}

/** Latin small letter i (U+0069): first of the two code points that U+0130 lowercases to. */
private static final int CODE_POINT_LOWERCASE_I = 0x69;
/** Combining dot above (U+0307): second of the two code points that U+0130 lowercases to. */
private static final int CODE_POINT_COMBINING_DOT = 0x307;
/**
 * `CODE_POINT_COMBINED_LOWERCASE_I_DOT` is an internal marker that represents, as a single
 * integer, the two-code-point lowercase form of Latin capital letter I with dot above
 * (U+0130): lowercase letter i (U+0069) followed by a combining dot (U+0307).
 *
 * The marker packs `CODE_POINT_LOWERCASE_I` into the bits above the low 16 bits and
 * `CODE_POINT_COMBINING_DOT` into the low 16 bits (0x69 << 16 | 0x307 == 0x690307). That
 * value is larger than the maximum Unicode code point (0x10FFFF), so it is not a valid code
 * point itself and cannot be confused with a real character. It must be expanded into its two
 * components before being appended to any output string.
 */
private static final int CODE_POINT_COMBINED_LOWERCASE_I_DOT =
CODE_POINT_LOWERCASE_I << 16 | CODE_POINT_COMBINING_DOT;

/**
* Returns the lowercase version of the provided code point, with special handling for
* one-to-many case mappings (i.e. characters that map to multiple characters in lowercase) and
* context-sensitive case mappings (i.e. characters that map to different characters based on
* their position in the string relative to other characters).
*/
private static int getLowercaseCodePoint(final int codePoint) {
  if (codePoint == 0x0130) {
    // Latin capital letter I with dot above is mapped to 2 lowercase characters. Since this
    // method returns a single int, the pair is represented by an artificial marker value that
    // callers are expected to recognize and expand.
    return CODE_POINT_COMBINED_LOWERCASE_I_DOT;
  }
  if (codePoint == 0x03C2) {
    // Greek final and non-final letter sigma should be mapped the same. This is achieved by
    // mapping Greek small final sigma (U+03C2) to Greek small non-final sigma (U+03C3). Capital
    // letter sigma (U+03A3) is mapped to small non-final sigma (U+03C3) by the general ICU
    // mapping below.
    return 0x03C3;
  }
  // All other characters should follow context-unaware ICU single-code point case mapping.
  return UCharacter.toLowerCase(codePoint);
}

/**
* Converts an entire string to lowercase using ICU rules, code point by code point, with
* special handling for one-to-many case mappings (i.e. characters that map to multiple
* characters in lowercase). Also, this method omits information about context-sensitive case
* mappings using special handling in the `appendLowercaseCodePoint` method.
*
* @param target The target string to convert to lowercase.
* @return The string converted to lowercase in a context-unaware manner.
Expand All @@ -455,10 +492,11 @@ public static UTF8String lowerCaseCodePoints(final UTF8String target) {
}

private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) {
String targetString = target.toValidString();
Iterator<Integer> targetIter = target.codePointIterator(
CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
StringBuilder sb = new StringBuilder();
for (int i = 0; i < targetString.length(); ++i) {
lowercaseCodePoint(targetString.codePointAt(i), sb);
while (targetIter.hasNext()) {
appendLowercaseCodePoint(targetIter.next(), sb);
}
return UTF8String.fromString(sb.toString());
}
Expand Down Expand Up @@ -655,38 +693,152 @@ public static UTF8String lowercaseSubStringIndex(final UTF8String string,
}
}

public static Map<String, String> getCollationAwareDict(UTF8String string,
Map<String, String> dict, int collationId) {
// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
String srcStr = string.toString();
/**
* Converts the original translation dictionary (`dict`) to a dictionary with lowercased keys.
* This method is used to create a dictionary that can be used for the UTF8_LCASE collation.
* Note that `StringTranslate.buildDict` will ensure that all strings are validated properly.
*
* The method returns a map with lowercased code points as keys, while the values remain
* unchanged. Note that `dict` is constructed on a character by character basis, and the
* original keys are stored as strings. Keys in the resulting lowercase dictionary are stored
* as integers, which correspond only to single characters from the original `dict`. Also,
* there is special handling for the Turkish dotted uppercase letter I (U+0130).
*/
private static Map<Integer, String> getLowercaseDict(final Map<String, String> dict) {
  final Map<Integer, String> result = new HashMap<>();
  dict.forEach((key, value) -> {
    // Only the first code point of each key is considered; it is reduced to its lowercase
    // form (or to the artificial combined marker for U+0130).
    final int loweredKey = getLowercaseCodePoint(key.codePointAt(0));
    // When several original keys collapse to the same lowercase code point, the one that is
    // visited first keeps its mapping.
    result.putIfAbsent(loweredKey, value);
  });
  return result;
}

/**
* Translates the `input` string using the translation map `dict`, for UTF8_LCASE collation.
* String translation is performed by iterating over the input string, from left to right, and
* repeatedly translating the longest possible substring that matches a key in the dictionary.
* For UTF8_LCASE, the method uses the lowercased substring to perform the lookup in the
* lowercased version of the translation map.
*
* @param input the string to be translated
* @param dict the original translation dictionary (its keys are lowercased internally)
* @return the translated string
*/
public static UTF8String lowercaseTranslate(final UTF8String input,
final Map<String, String> dict) {
// Iterator for the input string.
Iterator<Integer> inputIter = input.codePointIterator(
CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID);
// Lowercased translation dictionary.
Map<Integer, String> lowercaseDict = getLowercaseDict(dict);
// StringBuilder to store the translated string.
StringBuilder sb = new StringBuilder();

Map<String, String> collationAwareDict = new HashMap<>();
for (String key : dict.keySet()) {
StringSearch stringSearch =
CollationFactory.getStringSearch(string, UTF8String.fromString(key), collationId);
// We use buffered code point iteration to handle one-to-many case mappings. We need to handle
// at most two code points at a time (for `CODE_POINT_COMBINED_LOWERCASE_I_DOT`), a buffer of
// size 1 enables us to match two codepoints in the input string with a single codepoint in
// the lowercase translation dictionary.
int codePointBuffer = -1, codePoint;
while (inputIter.hasNext()) {
if (codePointBuffer != -1) {
codePoint = codePointBuffer;
codePointBuffer = -1;
} else {
codePoint = inputIter.next();
}
// Special handling for letter i (U+0069) followed by a combining dot (U+0307). By ensuring
// that `CODE_POINT_LOWERCASE_I` is buffered, we guarantee finding a max-length match.
if (lowercaseDict.containsKey(CODE_POINT_COMBINED_LOWERCASE_I_DOT) &&
codePoint == CODE_POINT_LOWERCASE_I && inputIter.hasNext()) {
int nextCodePoint = inputIter.next();
if (nextCodePoint == CODE_POINT_COMBINING_DOT) {
codePoint = CODE_POINT_COMBINED_LOWERCASE_I_DOT;
} else {
codePointBuffer = nextCodePoint;
}
}
// Translate the code point using the lowercased dictionary.
String translated = lowercaseDict.get(getLowercaseCodePoint(codePoint));
if (translated == null) {
// Append the original code point if no translation is found.
sb.appendCodePoint(codePoint);
} else if (!"\0".equals(translated)) {
// Append the translated code point if the translation is not the null character.
sb.append(translated);
}
// Skip the code point if it maps to the null character.
}
// Append the last code point if it was buffered.
if (codePointBuffer != -1) sb.appendCodePoint(codePointBuffer);

int pos = 0;
while ((pos = stringSearch.next()) != StringSearch.DONE) {
int codePoint = srcStr.codePointAt(pos);
int charCount = Character.charCount(codePoint);
String newKey = srcStr.substring(pos, pos + charCount);
// Return the translated string.
return UTF8String.fromString(sb.toString());
}

boolean exists = false;
for (String existingKey : collationAwareDict.keySet()) {
if (stringSearch.getCollator().compare(existingKey, newKey) == 0) {
collationAwareDict.put(newKey, collationAwareDict.get(existingKey));
exists = true;
break;
/**
* Translates the `input` string using the translation map `dict`, for all ICU collations.
* String translation is performed by iterating over the input string, from left to right, and
* repeatedly translating the longest possible substring that matches a key in the dictionary.
* For ICU collations, the method uses the ICU `StringSearch` class to perform the lookup in
* the translation map, while respecting the rules of the specified ICU collation.
*
* @param input the string to be translated
* @param dict the collation aware translation dictionary
* @param collationId the collation ID to use for string translation
* @return the translated string
*/
public static UTF8String translate(final UTF8String input,
final Map<String, String> dict, final int collationId) {
// Replace invalid UTF-8 sequences with the Unicode replacement character U+FFFD.
String inputString = input.toValidString();
// Create a character iterator for the validated input string. This will be used for searching
// inside the string using ICU `StringSearch` class. We only need to do it once before the
// main loop of the translate algorithm.
CharacterIterator target = new StringCharacterIterator(inputString);
Collator collator = CollationFactory.fetchCollation(collationId).collator;
StringBuilder sb = new StringBuilder();
// Index for the current character in the (validated) input string. This is the character we
// want to determine if we need to replace or not.
int charIndex = 0;
while (charIndex < inputString.length()) {
// We search the replacement dictionary to find a match. If there are more than one matches
// (which is possible for collated strings), we want to choose the match of largest length.
int longestMatchLen = 0;
String longestMatch = "";
for (String key : dict.keySet()) {
StringSearch stringSearch = new StringSearch(key, target, (RuleBasedCollator) collator);
// Point `stringSearch` to start at the current character.
stringSearch.setIndex(charIndex);
int matchIndex = stringSearch.next();
if (matchIndex == charIndex) {
// We have found a match (that is the current position matches with one of the characters
// in the dictionary). However, there might be other matches of larger length, so we need
// to continue searching against the characters in the dictionary and keep track of the
// match of largest length.
int matchLen = stringSearch.getMatchLength();
if (matchLen > longestMatchLen) {
longestMatchLen = matchLen;
longestMatch = key;
}
}

if (!exists) {
collationAwareDict.put(newKey, dict.get(key));
}
if (longestMatchLen == 0) {
// No match was found, so output the current character.
sb.append(inputString.charAt(charIndex));
// Move on to the next character in the input string.
++charIndex;
} else {
// We have found at least one match. Append the match of longest match length to the output.
if (!"\0".equals(dict.get(longestMatch))) {
sb.append(dict.get(longestMatch));
}
// Skip as many characters as the longest match.
charIndex += longestMatchLen;
}
}

return collationAwareDict;
// Return the translated string.
return UTF8String.fromString(sb.toString());
}

public static UTF8String lowercaseTrim(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean
return useICU ? execBinaryICU(v) : execBinary(v);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(v);
} else {
} else {
return execICU(v, collationId);
}
}
Expand All @@ -224,7 +224,7 @@ public static String genCode(final String v, final int collationId, boolean useI
return String.format(expr + "%s(%s)", funcName, v);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s)", v);
} else {
} else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
Expand Down Expand Up @@ -261,7 +261,7 @@ public static String genCode(final String v, final int collationId, boolean useI
return String.format(expr + "%s(%s)", funcName, v);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s)", v);
} else {
} else {
return String.format(expr + "ICU(%s, %d)", v, collationId);
}
}
Expand Down Expand Up @@ -522,26 +522,11 @@ public static UTF8String execBinary(final UTF8String source, Map<String, String>
return source.translate(dict);
}
public static UTF8String execLowercase(final UTF8String source, Map<String, String> dict) {
  // Delegate to the shared lowercase-aware implementation. It lowercases the dictionary keys
  // and the input code points with the same helper, rather than lowercasing only the input
  // with `String.toLowerCase()` as the previous inline loop did.
  return CollationAwareUTF8String.lowercaseTranslate(source, dict);
}
public static UTF8String execICU(final UTF8String source, Map<String, String> dict,
    final int collationId) {
  // Delegate to the collation-aware translation, which matches dictionary keys against the
  // input under the rules of the given ICU collation.
  return CollationAwareUTF8String.translate(source, dict, collationId);
}
}

Expand Down
Loading