memory mapped files, branchless parsing, bitwiddle magic #5

Merged
10 commits merged on Jan 3, 2024
8 changes: 7 additions & 1 deletion calculate_average_royvanrijn.sh
@@ -16,5 +16,11 @@
#


JAVA_OPTS=""
# Added for fun, doesn't seem to be making a difference...
if [ -f "target/calculate_average_royvanrijn.jsa" ]; then
JAVA_OPTS="-XX:SharedArchiveFile=target/calculate_average_royvanrijn.jsa -Xshare:on"
else
# First run, create the archive:
JAVA_OPTS="-XX:ArchiveClassesAtExit=target/calculate_average_royvanrijn.jsa"
fi
time java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_royvanrijn
312 changes: 278 additions & 34 deletions src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java
@@ -15,65 +15,309 @@
*/
package dev.morling.onebrc;

import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.AbstractMap;
import java.util.Map;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;
import java.util.stream.Collectors;

/**
* Changelog:
*
* Initial submission: 62000 ms
* Chunked reader: 16000 ms
* Optimized parser: 13000 ms
* Branchless methods: 11000 ms
* Adding memory mapped files: 6500 ms (based on bjhara's submission)
* Skipping string creation: 4700 ms
* Custom hashmap... 4200 ms
* Added SWAR token checks: 3900 ms
* Skipped String creation: 3500 ms (idea from kgonia)
* Improved String skip: 3250 ms
* Segmenting files: 3150 ms (based on spullara's code)
* Not using SWAR for EOL: 2850 ms
*
* Best performing JVM on MacBook M2 Pro: 21.0.1-graal
* `sdk use java 21.0.1-graal`
*
*/
public class CalculateAverage_royvanrijn {

private static final String FILE = "./measurements.txt";

private record Measurement(double min, double max, double sum, long count) {
// mutable state now instead of records, ugh, less instantiation.
static final class Measurement {
int min, max, count;
long sum;

Measurement(double initialMeasurement) {
this(initialMeasurement, initialMeasurement, initialMeasurement, 1);
public Measurement() {
this.min = 10000;
this.max = -10000;
}

public static Measurement combineWith(Measurement m1, Measurement m2) {
return new Measurement(
m1.min < m2.min ? m1.min : m2.min,
m1.max > m2.max ? m1.max : m2.max,
m1.sum + m2.sum,
m1.count + m2.count
);
public Measurement updateWith(int measurement) {
min = min(min, measurement);
max = max(max, measurement);
sum += measurement;
count++;
return this;
}

public Measurement updateWith(Measurement measurement) {
min = min(min, measurement.min);
max = max(max, measurement.max);
sum += measurement.sum;
count += measurement.count;
return this;
}

public String toString() {
return round(min) + "/" + round(sum / count) + "/" + round(max);
return round(min) + "/" + round((1.0 * sum) / count) + "/" + round(max);
}

private double round(double value) {
return Math.round(value * 10.0) / 10.0;
return Math.round(value) / 10.0;
}
}

public static void main(String[] args) throws IOException {
public static final void main(String[] args) throws Exception {
new CalculateAverage_royvanrijn().run();
}

private void run() throws Exception {

var results = getFileSegments(new File(FILE)).stream().map(segment -> {

long segmentEnd = segment.end();
try (var fileChannel = (FileChannel) Files.newByteChannel(Path.of(FILE), StandardOpenOption.READ)) {
var bb = fileChannel.map(FileChannel.MapMode.READ_ONLY, segment.start(), segmentEnd - segment.start());
var buffer = new byte[64];

// Force little endian:
bb.order(ByteOrder.LITTLE_ENDIAN);

// long before = System.currentTimeMillis();
BitTwiddledMap measurements = new BitTwiddledMap();

Map<String, Measurement> resultMap = Files.lines(Path.of(FILE)).parallel()
.map(record -> {
// Map to <String,double>
int pivot = record.indexOf(";");
String key = record.substring(0, pivot);
double measured = Double.parseDouble(record.substring(pivot + 1));
return new AbstractMap.SimpleEntry<>(key, measured);
})
.collect(Collectors.toConcurrentMap(
// Combine/reduce:
AbstractMap.SimpleEntry::getKey,
entry -> new Measurement(entry.getValue()),
Measurement::combineWith));
int startPointer;
int limit = bb.limit();
while ((startPointer = bb.position()) < limit) {

System.out.print("{");
System.out.print(
resultMap.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Object::toString).collect(Collectors.joining(", ")));
System.out.println("}");
// SWAR is faster for ';'
int separatorPointer = findNextSWAR(bb, SEPARATOR_PATTERN, startPointer + 3, limit);

// System.out.println("Took: " + (System.currentTimeMillis() - before));
// Simple is faster for '\n' (just three options)
int endPointer;
if (bb.get(separatorPointer + 4) == '\n') {
Contributor

For my input I get IOOBE here:

Exception in thread "main" java.lang.IndexOutOfBoundsException
	at java.base/jdk.internal.reflect.DirectConstructorHandleAccessor.newInstance(DirectConstructorHandleAccessor.java:62)
	at java.base/java.lang.reflect.Constructor.newInstanceWithCaller(Constructor.java:502)
	at java.base/java.lang.reflect.Constructor.newInstance(Constructor.java:486)
	at java.base/java.util.concurrent.ForkJoinTask.getThrowableException(ForkJoinTask.java:542)
	at java.base/java.util.concurrent.ForkJoinTask.reportException(ForkJoinTask.java:567)
	at java.base/java.util.concurrent.ForkJoinTask.invoke(ForkJoinTask.java:670)
	at java.base/java.util.stream.ReduceOps$ReduceOp.evaluateParallel(ReduceOps.java:927)
	at java.base/java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:233)
	at java.base/java.util.stream.ReferencePipeline.collect(ReferencePipeline.java:682)
	at dev.morling.onebrc.CalculateAverage_royvanrijn.run(CalculateAverage_royvanrijn.java:144)
	at dev.morling.onebrc.CalculateAverage_royvanrijn.main(CalculateAverage_royvanrijn.java:92)
Caused by: java.lang.IndexOutOfBoundsException
	at java.base/java.nio.Buffer$1.apply(Buffer.java:757)
	at java.base/java.nio.Buffer$1.apply(Buffer.java:754)
	at java.base/jdk.internal.util.Preconditions$4.apply(Preconditions.java:213)
	at java.base/jdk.internal.util.Preconditions$4.apply(Preconditions.java:210)
	at java.base/jdk.internal.util.Preconditions.outOfBounds(Preconditions.java:98)
	at java.base/jdk.internal.util.Preconditions.outOfBoundsCheckIndex(Preconditions.java:106)
	at java.base/jdk.internal.util.Preconditions.checkIndex(Preconditions.java:302)
	at java.base/java.nio.Buffer.checkIndex(Buffer.java:768)
	at java.base/java.nio.DirectByteBuffer.get(DirectByteBuffer.java:358)
	at dev.morling.onebrc.CalculateAverage_royvanrijn.lambda$run$0(CalculateAverage_royvanrijn.java:118)
	at java.base/java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:197)
	at java.base/java.util.ArrayList$ArrayListSpliterator.forEachRemaining(ArrayList.java:1708)
	at java.base/java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:509)
	at java.base/java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:499)
	at java.base/java.util.stream.ReduceOps$ReduceTask.doLeaf(ReduceOps.java:960)
	at java.base/java.util.stream.ReduceOps$ReduceTask.doLeaf(ReduceOps.java:934)
	at java.base/java.util.stream.AbstractTask.compute(AbstractTask.java:327)
	at java.base/java.util.concurrent.CountedCompleter.exec(CountedCompleter.java:754)
	at java.base/java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:387)
	at java.base/java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1312)
	at java.base/java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1843)
	at java.base/java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1808)
	at java.base/java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:188)

Let me know if you need the file.

@royvanrijn (Contributor Author), Jan 4, 2024

Ah yes please, perhaps there is some other bug that's platform dependent; do share.

Can you specify what platform you're running it on, and could you please also check this (improved) version:
https://github.com/royvanrijn/1brc/blob/8db31e6a36fbc305765a2393efb06ba6bff23f42/src/main/java/dev/morling/onebrc/CalculateAverage_royvanrijn.java

Contributor

Running on Windows. I'll check the new version later (it is quite late here now, and starting tomorrow I won't have access to a PC for the weekend).

Let me know how you'd suggest I send you the data file. I started bzipping it and it takes forever; even part way through, the archive is 2 GB (you can mail me upload coordinates at dimitar.dimitrov at gmail dot com).

Contributor Author

If possible, can you narrow it down? Perhaps run a very small test? Do they all crash, or just this one?

@DamienOReilly, Jan 4, 2024

@ddimtirov for future reference, the default compression level on zstd will be a lot faster and offer reasonable compression:

On a MacBook Pro 2020 - 2 GHz Quad-Core Intel Core i5:

# time zstd -z measurements.txt
measurements.txt     : 28.24%   (  12.8 GiB =>   3.63 GiB, measurements.txt.zst)
zstd -z measurements.txt  92.10s user 8.05s system 107% cpu 1:33.54 total

Contributor

See related #61

We've also added some basic samples within #82

endPointer = separatorPointer + 4;
}
else if (bb.get(separatorPointer + 5) == '\n') {
endPointer = separatorPointer + 5;
}
else {
endPointer = separatorPointer + 6;
}

// Read the entry in a single get():
bb.get(buffer, 0, endPointer - startPointer);
bb.position(endPointer + 1); // skip to next line.

// Extract the measurement value (10x):
final int nameLength = separatorPointer - startPointer;
final int valueLength = endPointer - separatorPointer - 1;
final int measured = branchlessParseInt(buffer, nameLength + 1, valueLength);
measurements.getOrCreate(buffer, nameLength).updateWith(measured);
}
return measurements;
}
catch (IOException e) {
throw new RuntimeException(e);
}
}).parallel().flatMap(v -> v.values.stream())
.collect(Collectors.toMap(e -> new String(e.key), BitTwiddledMap.Entry::measurement, (m1, m2) -> m1.updateWith(m2), TreeMap::new));

// Seems to perform better than actually using a TreeMap:
System.out.println(results);
}

/**
* -------- This section contains SWAR code (SIMD Within A Register) which processes a bytebuffer as longs to find values:
*/
private static final long SEPARATOR_PATTERN = compilePattern((byte) ';');

private int findNextSWAR(ByteBuffer bb, long pattern, int start, int limit) {
int i;
for (i = start; i <= limit - 8; i += 8) {
long word = bb.getLong(i);
int index = firstAnyPattern(word, pattern);
if (index < Long.BYTES) {
return i + index;
}
}
// Handle remaining bytes
for (; i < limit; i++) {
if (bb.get(i) == (byte) pattern) {
return i;
}
}
return limit; // delimiter not found
}

private static long compilePattern(byte value) {
return ((long) value << 56) | ((long) value << 48) | ((long) value << 40) | ((long) value << 32) |
((long) value << 24) | ((long) value << 16) | ((long) value << 8) | (long) value;
}

private static int firstAnyPattern(long word, long pattern) {
final long match = word ^ pattern;
long mask = match - 0x0101010101010101L;
mask &= ~match;
mask &= 0x8080808080808080L;
return Long.numberOfTrailingZeros(mask) >>> 3;
@franz1981, Jan 3, 2024

Shouldn't this be the number of leading ones here?
@royvanrijn @gunnarmorling

@royvanrijn (Contributor Author), Jan 3, 2024

Weird, I thought that explicitly setting little-endian on line 105 would make the compatibility issues disappear, so running it on my machine would automatically mean it works on the target, perhaps with a performance hit.

AFK at the moment; I'll check tomorrow. If somebody wants to fix it and tell me, be my guest 😂

@franz1981, Jan 3, 2024

I am not sure actually; for these things I need old-school paper and a pencil :)

Contributor Author

I have some local test classes; the problem is that I believe it works on my machine, just not on the target machine. I'll check soon.

Contributor Author

Yeah, annoying, it runs fine locally (the code that was pushed), sigh. Kind of debugging in the dark haha... a challenge!


Thinking about it twice, I'm wrong in #5.

Let's say we have a byte[] data = { 0x01, 0x03 } and assume a short-based version of SWAR.

Reading the content of data as a little-endian short means:

0x0301

which has the less significant part at the lower address; hence the SWAR result obtained (I'm using the Netty algorithm, but here it should be the same) will be

0x8000

and, in order to find the 0x03, we have to use the trailing zeros (here 8 + 7 = 15 -> 15/8 = 1).

Which means it is fine as it is!
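A tiny sketch (not from the thread) that reproduces this worked example with a 16-bit word; the class name and wrapper are made up, only the bit-twiddling mirrors the PR's firstAnyPattern:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class SwarEndianCheck {
    public static void main(String[] args) {
        byte[] data = { 0x01, 0x03 };
        // Little-endian read puts data[0] in the low byte: word == 0x0301
        int word = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN).getShort() & 0xFFFF;
        int pattern = 0x0303;                          // byte we search for, repeated in every lane
        int match = word ^ pattern;                    // 0x0002: the zero byte marks the hit
        int mask = (match - 0x0101) & ~match & 0x8080; // 0x8000
        int index = Integer.numberOfTrailingZeros(mask) >>> 3; // 15 >>> 3 == 1
        System.out.println(index);                     // prints 1, the offset of the 0x03 byte
    }
}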

}

record FileSegment(long start, long end) {
}

/** Using this way to segment the file is much prettier, from spullara */
private static List<FileSegment> getFileSegments(File file) throws IOException {
final int numberOfSegments = Runtime.getRuntime().availableProcessors();
final long fileSize = file.length();
final long segmentSize = fileSize / numberOfSegments;
final List<FileSegment> segments = new ArrayList<>();
try (RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r")) {
for (int i = 0; i < numberOfSegments; i++) {
long segStart = i * segmentSize;
long segEnd = (i == numberOfSegments - 1) ? fileSize : segStart + segmentSize;
segStart = findSegment(i, 0, randomAccessFile, segStart, segEnd);
segEnd = findSegment(i, numberOfSegments - 1, randomAccessFile, segEnd, fileSize);

segments.add(new FileSegment(segStart, segEnd));
}
}
return segments;
}

private static long findSegment(int i, int skipSegment, RandomAccessFile raf, long location, long fileSize) throws IOException {
if (i != skipSegment) {
raf.seek(location);
while (location < fileSize) {
location++;
if (raf.read() == '\n')
return location;
}
}
return location;
}

/**
* Branchless parser, goes from String to int (10x):
* "-1.2" to -12
* "40.1" to 401
Comment on lines +221 to +223
Contributor

@gunnarmorling Is this assumption acceptable? E.g. 1, 0, 2.00 are all valid doubles.

Contributor

Also, there should be an acceptance test suite for the implementations; I am pretty sure this implementation does not produce the same output as the baseline.

@royvanrijn (Contributor Author), Jan 3, 2024

I'm pretty sure it produces the exact same output; I check regularly with each change.

The input and output have one decimal place precision (as stated by the website "rounded to one fractional digit").

I'm internally storing the doubles as 10x integers because the precision is just a single digit, and I'm making sure the rounding is correct afterwards for the average.
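For illustration only (not the PR code), the 10x-int idea and the final rounding look roughly like this; the class name and sample values are made up:

public class TenXDemo {
    public static void main(String[] args) {
        // "23.4", "23.1" and "24.0" are stored as the 10x ints 234, 231 and 240
        long sum = 234 + 231 + 240;   // 705
        long count = 3;
        // Same final step as the PR's toString(): round to the nearest 0.1
        double mean = Math.round((1.0 * sum) / count) / 10.0;
        System.out.println(mean);     // 23.5
    }
}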

@AlexanderYastrebov (Contributor), Jan 3, 2024

The README only talks about the output format, not the input:

1brc/README.md, lines 27 to 28 in e7e7deb:

The task is to write a Java program which reads the file, calculates the min, mean, and max temperature value per weather station, and emits the results on stdout like this
(i.e. sorted alphabetically by station name, and the result values per station in the format `<min>/<mean>/<max>`, rounded to one fractional digit):

so it's worth clarifying.

There are many rounding modes, and this is also not specified; e.g. Go's https://pkg.go.dev/math#Round is not the same as Java's.

Contributor

Also, the reference implementation does ad-hoc rounding with Math.round(value * 10.0) / 10.0, but the actual output depends on the string concatenation, which performs another rounding; see https://docs.oracle.com/en/java/javase/21/docs/api/java.base/java/lang/Double.html#toString(double)

Contributor

Due to the unfortunate output format selection (see #14), one has to use a word diff.

Created #36 to address this.

Contributor

Sorry for not being explicit enough here. Can the behavior of the reference implementation be described using any of the existing values of RoundingMode?

I think the exact mode does not really matter; what matters is that the baseline is correct.

I propose changing the baseline to use BigDecimal for result accumulation, with scale 1 (instead of round(x*10)/10) and the HALF_UP rounding mode (the most common at school) at the final step:

BigDecimal value = new BigDecimal("12.34");
BigDecimal rounded = value.setScale(1, RoundingMode.HALF_UP);
  
System.out.println("=="+rounded.toString()+"=="); // prints ==12.3==

Owner

But that's the thing, I don't think we can change the behavior of the reference implementation at this point, as it would render existing submissions invalid if they implement a different behavior. So I'd rather make the behavior of the RI explicit, even if it's not the most natural one (agreed that HALF_UP behavior would have been better).

Contributor

I don't think it's possible to fix the RI because it uses double division and rounds twice.
Since there are no acceptance tests, I bet a lot of implementations (those that do not parse and calculate values the same way) will not match the RI anyway.

I think the RI should favor correctness over performance; then it can be used to build an acceptance test suite.
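As a side note (not from the thread), here is one value where the two rounding approaches diverge; -1.25 is chosen because it is exactly representable as a double, and the class name is made up:

import java.math.BigDecimal;
import java.math.RoundingMode;

public class RoundingDivergence {
    public static void main(String[] args) {
        double v = -1.25;
        // Reference-implementation style: Math.round ties go toward positive infinity
        double ri = Math.round(v * 10.0) / 10.0;                                   // -1.2
        // Proposed BigDecimal style: HALF_UP ties go away from zero
        BigDecimal bd = new BigDecimal("-1.25").setScale(1, RoundingMode.HALF_UP); // -1.3
        System.out.println(ri + " vs " + bd);                                      // -1.2 vs -1.3
    }
}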

Owner

OK, I've logged #49 to get this one sorted out separately and get this PR merged. Let's continue the rounding topic over there. Thx!

* etc.
*
* @param input
* @return int value x10
*/
private static int branchlessParseInt(final byte[] input, int start, int length) {
// 0 if positive, 1 if negative
final int negative = ~(input[start] >> 4) & 1;
// 0 if nr length is 3, 1 if length is 4
final int has4 = ((length - negative) >> 2) & 1;

final int digit1 = input[start + negative] - '0';
final int digit2 = input[start + negative + has4];
final int digit3 = input[start + negative + has4 + 2];

return (-negative ^ (has4 * (digit1 * 100) + digit2 * 10 + digit3 - 528) - negative); // 528 == ('0' * 10 + '0')
}
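// Hypothetical spot-check, editorial and not part of the PR: exercises the examples from the
// javadoc above. Assumes it sits inside CalculateAverage_royvanrijn so the private method is visible.
static void spotCheckBranchlessParseInt() {
    byte[] a = "-1.2".getBytes();
    byte[] b = "40.1".getBytes();
    byte[] c = "5.9".getBytes();
    System.out.println(branchlessParseInt(a, 0, a.length)); // -12
    System.out.println(branchlessParseInt(b, 0, b.length)); // 401
    System.out.println(branchlessParseInt(c, 0, c.length)); // 59
}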

// branchless max (imprecise for large numbers, but good enough)
static int max(final int a, final int b) {
final int diff = a - b;
final int dsgn = diff >> 31;
return a - (diff & dsgn);
}

// branchless min (imprecise for large numbers, but good enough)
static int min(final int a, final int b) {
final int diff = a - b;
final int dsgn = diff >> 31;
return b + (diff & dsgn);
}
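// Worked example (editorial note, not part of the PR): for a = 3, b = 7,
// diff = -4 and dsgn = diff >> 31 = -1 (all bits set), so
//   max: a - (diff & dsgn) = 3 - (-4) = 7
//   min: b + (diff & dsgn) = 7 + (-4) = 3
// For a = 7, b = 3: diff = 4 and dsgn = 0, so max returns 7 - 0 and min returns 3 + 0.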

/**
* A normal Java HashMap does all these safety things like boundary checks... we don't need that, we need speeeed.
*
* So I've written an extremely simple linear probing hashmap that should work well enough.
*/
class BitTwiddledMap {
private static final int SIZE = 16384; // A bit larger than the number of keys; needs to be a power of two
private int[] indices = new int[SIZE]; // Hashtable is just an int[]

BitTwiddledMap() {
// Optimized fill with -1, fastest method:
int len = indices.length;
if (len > 0) {
indices[0] = -1;
}
// Value of i will be [1, 2, 4, 8, 16, 32, ..., len]
for (int i = 1; i < len; i += i) {
System.arraycopy(indices, 0, indices, i, i);
}
}

private List<Entry> values = new ArrayList<>(512);

record Entry(int hash, byte[] key, Measurement measurement) {
@Override
public String toString() {
return new String(key) + "=" + measurement;
}
}

/**
* Who needs methods like add(), merge(), compute() etc, we need one, getOrCreate.
* @param key
* @return
*/
public Measurement getOrCreate(byte[] key, int length) {
int inHash;
int index = (SIZE - 1) & (inHash = hashCode(key, length));
int valueIndex;
Entry retrievedEntry = null;
while ((valueIndex = indices[index]) != -1 && (retrievedEntry = values.get(valueIndex)).hash != inHash) {
index = (index + 1) % SIZE;
}
if (valueIndex >= 0) {
return retrievedEntry.measurement;
}
// New entry, insert into table and return.
indices[index] = values.size();

// Only parse this once:
byte[] actualKey = new byte[length];
System.arraycopy(key, 0, actualKey, 0, length);

Entry toAdd = new Entry(inHash, actualKey, new Measurement());
values.add(toAdd);
return toAdd.measurement;
}

private static int hashCode(byte[] a, int length) {
@franz1981, Jan 3, 2024

The hash code here has a data dependency: you can either manually unroll this, or relax the hash code by using a var handle and getLong, amortizing the data dependency in batches and handling only the last 7 (or fewer) bytes separately via the array.
That way most of the computation would resolve in far fewer loop iterations too, similar to https://github.com/apache/activemq-artemis/blob/25fc0342275b29cd73123523a46e6e94582597cd/artemis-commons/src/main/java/org/apache/activemq/artemis/utils/ByteUtil.java#L299

int result = 1;
for (int i = 0; i < length; i++) {
result = 31 * result + a[i];
}
return result;
}
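// Rough sketch of the long-at-a-time idea described above (editorial, not part of the PR):
// a relaxed hash that mixes 8 bytes per step instead of 1, then folds in the trailing
// bytes one by one. It is NOT bit-compatible with the 31-based hash above; it only
// illustrates how the loop-carried dependency can be amortized.
// (Uses ByteBuffer.getLong for brevity where the comment suggests a VarHandle.)
private static int relaxedHashCode(byte[] a, int length) {
    java.nio.ByteBuffer bb = java.nio.ByteBuffer.wrap(a).order(java.nio.ByteOrder.LITTLE_ENDIAN);
    long h = 1;
    int i = 0;
    for (; i <= length - 8; i += 8) {
        h = 31 * h + bb.getLong(i);
    }
    for (; i < length; i++) {
        h = 31 * h + a[i];
    }
    return (int) (h ^ (h >>> 32));
}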
}

}