Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

One last improvement for thomaswue #702

Merged
merged 5 commits into from
Feb 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 66 additions & 61 deletions src/main/java/dev/morling/onebrc/CalculateAverage_thomaswue.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
* split into 3 parts and cursors for each of those parts are processing the segment simultaneously in the same thread.
* Results are accumulated into {@link Result} objects and a tree map is used to sequentially accumulate the results in
* the end.
* Runs in 0.39s on an Intel i9-13900K.
* Runs in 0.31 on an Intel i9-13900K while the reference implementation takes 120.37s.
* Credit:
* Quan Anh Mai for branchless number parsing code
* Alfonso² Peterssen for suggesting memory mapping with unsafe and the subprocess idea
* Artsiom Korzun for showing the benefits of work stealing at 2MB segments instead of equal split between workers
* Jaromir Hamala for showing that avoiding the branch misprediction between <8 and 8-16 cases is a big win even if
* more work is performed
* Van Phu DO for demonstrating the lookup tables based on masks instead of bit shifting
*/
public class CalculateAverage_thomaswue {
private static final String FILE = "./measurements.txt";
Expand Down Expand Up @@ -141,9 +144,15 @@ private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart,
long delimiterMask1 = findDelimiter(word1);
long delimiterMask2 = findDelimiter(word2);
long delimiterMask3 = findDelimiter(word3);
Result existingResult1 = findResult(word1, delimiterMask1, scanner1, results, collectedResults);
Result existingResult2 = findResult(word2, delimiterMask2, scanner2, results, collectedResults);
Result existingResult3 = findResult(word3, delimiterMask3, scanner3, results, collectedResults);
long word1b = scanner1.getLongAt(scanner1.pos() + 8);
long word2b = scanner2.getLongAt(scanner2.pos() + 8);
long word3b = scanner3.getLongAt(scanner3.pos() + 8);
long delimiterMask1b = findDelimiter(word1b);
long delimiterMask2b = findDelimiter(word2b);
long delimiterMask3b = findDelimiter(word3b);
Result existingResult1 = findResult(word1, delimiterMask1, word1b, delimiterMask1b, scanner1, results, collectedResults);
Result existingResult2 = findResult(word2, delimiterMask2, word2b, delimiterMask2b, scanner2, results, collectedResults);
Result existingResult3 = findResult(word3, delimiterMask3, word3b, delimiterMask3b, scanner3, results, collectedResults);
long number1 = scanNumber(scanner1);
long number2 = scanNumber(scanner2);
long number3 = scanNumber(scanner3);
Expand All @@ -155,76 +164,70 @@ private static void parseLoop(AtomicLong counter, long fileEnd, long fileStart,
while (scanner1.hasNext()) {
long word = scanner1.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner1, results, collectedResults), scanNumber(scanner1));
long wordB = scanner1.getLongAt(scanner1.pos() + 8);
long posB = findDelimiter(wordB);
record(findResult(word, pos, wordB, posB, scanner1, results, collectedResults), scanNumber(scanner1));
}
while (scanner2.hasNext()) {
long word = scanner2.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner2, results, collectedResults), scanNumber(scanner2));
long wordB = scanner2.getLongAt(scanner2.pos() + 8);
long posB = findDelimiter(wordB);
record(findResult(word, pos, wordB, posB, scanner2, results, collectedResults), scanNumber(scanner2));
}
while (scanner3.hasNext()) {
long word = scanner3.getLong();
long pos = findDelimiter(word);
record(findResult(word, pos, scanner3, results, collectedResults), scanNumber(scanner3));
long wordB = scanner3.getLongAt(scanner3.pos() + 8);
long posB = findDelimiter(wordB);
record(findResult(word, pos, wordB, posB, scanner3, results, collectedResults), scanNumber(scanner3));
}
}
}

private static Result findResult(long initialWord, long initialDelimiterMask, Scanner scanner, Result[] results, List<Result> collectedResults) {
private static final long[] MASK1 = new long[]{ 0xFFL, 0xFFFFL, 0xFFFFFFL, 0xFFFFFFFFL, 0xFFFFFFFFFFL, 0xFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL,
0xFFFFFFFFFFFFFFFFL };
private static final long[] MASK2 = new long[]{ 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0x00L, 0xFFFFFFFFFFFFFFFFL };

private static Result findResult(long initialWord, long initialDelimiterMask, long wordB, long delimiterMaskB, Scanner scanner, Result[] results,
List<Result> collectedResults) {
Result existingResult;
long word = initialWord;
long delimiterMask = initialDelimiterMask;
long hash;
long nameAddress = scanner.pos();

// Search for ';', one long at a time. There are two common cases that a specially treated:
// (b) the ';' is found in the first 16 bytes
if (delimiterMask != 0) {
// Special case for when the ';' is found in the first 8 bytes.
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash = word;
long word2 = wordB;
long delimiterMask2 = delimiterMaskB;
if ((delimiterMask | delimiterMask2) != 0) {
int letterCount1 = Long.numberOfTrailingZeros(delimiterMask) >>> 3; // value between 1 and 8
int letterCount2 = Long.numberOfTrailingZeros(delimiterMask2) >>> 3; // value between 0 and 8
long mask = MASK2[letterCount1];
word = word & MASK1[letterCount1];
word2 = mask & word2 & MASK1[letterCount2];
hash = word ^ word2;
existingResult = results[hashToIndex(hash, results)];
if (existingResult != null && existingResult.lastNameLong == word) {
scanner.add(letterCount1 + (letterCount2 & mask));
if (existingResult != null && existingResult.firstNameWord == word && existingResult.secondNameWord == word2) {
return existingResult;
}
}
else {
// Special case for when the ';' is found in bytes 9-16.
hash = word;
long prevWord = word;
scanner.add(8);
word = scanner.getLong();
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
existingResult = results[hashToIndex(hash, results)];
if (existingResult != null && existingResult.lastNameLong == word && existingResult.secondLastNameLong == prevWord) {
return existingResult;
// Slow-path for when the ';' could not be found in the first 16 bytes.
hash = word ^ word2;
scanner.add(16);
while (true) {
word = scanner.getLong();
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
break;
}
}
else {
// Slow-path for when the ';' could not be found in the first 16 bytes.
scanner.add(8);
hash ^= word;
while (true) {
word = scanner.getLong();
delimiterMask = findDelimiter(word);
if (delimiterMask != 0) {
int trailingZeros = Long.numberOfTrailingZeros(delimiterMask);
word = (word << (63 - trailingZeros));
scanner.add(trailingZeros >>> 3);
hash ^= word;
break;
}
else {
scanner.add(8);
hash ^= word;
}
else {
scanner.add(8);
hash ^= word;
}
}
}
Expand All @@ -249,8 +252,8 @@ private static Result findResult(long initialWord, long initialDelimiterMask, Sc
}
}

int remainingShift = (64 - (nameLength + 1 - i) << 3);
if (existingResult.lastNameLong == (scanner.getLongAt(nameAddress + i) << remainingShift)) {
int remainingShift = (64 - ((nameLength + 1 - i) << 3));
if (((scanner.getLongAt(existingResult.nameAddress + i) ^ (scanner.getLongAt(nameAddress + i))) << remainingShift) == 0) {
break;
}
else {
Expand Down Expand Up @@ -297,7 +300,7 @@ private static void record(Result existingResult, long number) {
}

private static int hashToIndex(long hash, Result[] results) {
long hashAsInt = hash ^ (hash >>> 37) ^ (hash >>> 17);
long hashAsInt = hash ^ (hash >>> 33) ^ (hash >>> 15);
return (int) (hashAsInt & (results.length - 1));
}

Expand All @@ -324,21 +327,23 @@ private static long findDelimiter(long word) {
private static Result newEntry(Result[] results, long nameAddress, int hash, int nameLength, Scanner scanner, List<Result> collectedResults) {
Result r = new Result();
results[hash] = r;
int i = 0;
for (; i < nameLength + 1 - Long.BYTES; i += Long.BYTES) {
int totalLength = nameLength + 1;
r.firstNameWord = scanner.getLongAt(nameAddress);
r.secondNameWord = scanner.getLongAt(nameAddress + 8);
if (totalLength <= 8) {
r.firstNameWord = r.firstNameWord & MASK1[totalLength - 1];
r.secondNameWord = 0;
}
if (nameLength + 1 > 8) {
r.secondLastNameLong = scanner.getLongAt(nameAddress + i - 8);
else if (totalLength < 16) {
r.secondNameWord = r.secondNameWord & MASK1[totalLength - 9];
}
int remainingShift = (64 - (nameLength + 1 - i) << 3);
r.lastNameLong = (scanner.getLongAt(nameAddress + i) << remainingShift);
r.nameAddress = nameAddress;
collectedResults.add(r);
return r;
}

private static final class Result {
long lastNameLong, secondLastNameLong;
long firstNameWord, secondNameWord;
short min, max;
int count;
long sum;
Expand Down
Loading