
Commit

Further improved performance by reworking the parsing logic so that strings for city names are not allocated with each row. (#323)

Co-authored-by: Bruno Felix <bruno.felix@klarna.com>
felix19350 and Bruno Felix authored Jan 14, 2024
1 parent 990f884 commit bb5679f
Showing 2 changed files with 154 additions and 88 deletions.
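
The gist of the change: instead of decoding a new String for the city name on every row, the hot loop keys its HashMap on a small wrapper around the raw UTF-8 bytes plus a fingerprint computed while scanning the line (the CityRef class in the diff below); city names only become Strings once per distinct city when the results are assembled. A minimal, self-contained sketch of that pattern, assuming illustrative names (ByteKey, ByteKeyDemo) that are not part of this commit:

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Sketch: a byte-array key with a precomputed hash, so the parsing loop never
// pays for UTF-8 decoding per row; only distinct keys are ever decoded.
final class ByteKey {
    private final byte[] bytes; // raw UTF-8 city name
    private final int hash;     // fingerprint computed while scanning the line

    ByteKey(byte[] src, int offset, int length, int hash) {
        this.bytes = Arrays.copyOfRange(src, offset, offset + length); // copies bytes, no String decode
        this.hash = hash;
    }

    @Override
    public int hashCode() {
        return hash;
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof ByteKey other
                && this.hash == other.hash
                && Arrays.equals(this.bytes, other.bytes);
    }

    String asString() {
        return new String(bytes, StandardCharsets.UTF_8); // deferred to output time
    }
}

public class ByteKeyDemo {
    public static void main(String[] args) {
        byte[] row = "Lisbon;12.3".getBytes(StandardCharsets.UTF_8);
        int sep = 6; // index of ';'
        int hash = 0;
        for (int i = 0; i < sep; i++) {
            hash = 31 * hash + row[i]; // same shape as the diff's HASH_FACTOR fingerprint
        }
        Map<ByteKey, Integer> counts = new HashMap<>();
        counts.merge(new ByteKey(row, 0, sep, hash), 1, Integer::sum);
        counts.forEach((k, v) -> System.out.println(k.asString() + " -> " + v)); // prints: Lisbon -> 1
    }
}

In the real code the same trick is applied to a ByteBuffer slice of the memory-mapped file rather than a plain byte array.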
12 changes: 11 additions & 1 deletion calculate_average_felix19350.sh
@@ -15,6 +15,16 @@
# limitations under the License.
#

# ParallelGC test - Time (measured by evaluate2.sh): 00:33.130
# JAVA_OPTS="--enable-preview -XX:+UseParallelGC -XX:+UseTransparentHugePages"

# G1GC test - Time (measured by evaluate2.sh): 00:26.447
# JAVA_OPTS="--enable-preview -XX:+UseG1GC -XX:+UseTransparentHugePages"

# ZGC test - Time (measured by evaluate2.sh): 00:22.813
JAVA_OPTS="--enable-preview -XX:+UseZGC -XX:+UseTransparentHugePages"

# EpsilonGC test - for now doesn't work because heap space gets exhausted
#JAVA_OPTS="--enable-preview -XX:+UnlockExperimentalVMOptions -XX:+UseEpsilonGC -XX:+AlwaysPreTouch"

JAVA_OPTS="--enable-preview -XX:+UseParallelGC -Xms4g -Xmx4g"
java $JAVA_OPTS --class-path target/average-1.0.0-SNAPSHOT.jar dev.morling.onebrc.CalculateAverage_felix19350
230 changes: 143 additions & 87 deletions src/main/java/dev/morling/onebrc/CalculateAverage_felix19350.java
Expand Up @@ -16,17 +16,16 @@
package dev.morling.onebrc;

import java.io.IOException;
import java.io.RandomAccessFile;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
@@ -36,6 +35,55 @@ public class CalculateAverage_felix19350 {
private static final String FILE = "./measurements.txt";
private static final int NEW_LINE_SEEK_BUFFER_LEN = 128;

private static final int EXPECTED_MAX_NUM_CITIES = 15_000; // 10K cities plus a buffer so we do not trigger a resize via the load factor

private static class CityRef {

final int length;
final int fingerprint;
final byte[] stringBytes;

public CityRef(ByteBuffer byteBuffer, int startIdx, int length, int fingerprint) {
this.length = length;
this.stringBytes = new byte[length];
byteBuffer.get(startIdx, this.stringBytes, 0, this.stringBytes.length);
this.fingerprint = fingerprint;
}

public String cityName() {
return new String(stringBytes, StandardCharsets.UTF_8);
}

@Override
public int hashCode() {
return fingerprint;
}

@Override
public boolean equals(Object other) {
if (other instanceof CityRef otherRef) {
if (fingerprint != otherRef.fingerprint) {
return false;
}

if (this.length != otherRef.length) {
return false;
}

for (var i = 0; i < this.length; i++) {
if (this.stringBytes[i] != otherRef.stringBytes[i]) {
return false;
}
}
return true;
}
else {
return false;
}
}

}

private static class ResultRow {

private int min;
@@ -73,95 +121,104 @@ public void mergeResult(ResultRow value) {
}
}

private record AverageAggregatorTask(MemorySegment memSegment) {
private record AverageAggregatorTask(ByteBuffer byteBuffer) {
private static final int HASH_FACTOR = 31; // Mersenne prime

public static Stream<AverageAggregatorTask> createStreamOf(List<MemorySegment> memorySegments) {
return memorySegments.stream().map(AverageAggregatorTask::new);

public static Stream<AverageAggregatorTask> createStreamOf(List<ByteBuffer> byteBuffers) {
return byteBuffers.stream().map(AverageAggregatorTask::new);
}

public Map<String, ResultRow> processChunk() {
final var result = new TreeMap<String, ResultRow>();
var offset = 0L;
var lineStart = 0L;
while (offset < memSegment.byteSize()) {
byte nextByte = memSegment.get(ValueLayout.OfByte.JAVA_BYTE, offset);
if ((char) nextByte == '\n') {
this.processLine(result, memSegment.asSlice(lineStart, (offset - lineStart)).asByteBuffer());
lineStart = offset + ValueLayout.JAVA_BYTE.byteSize();
}
offset += ValueLayout.OfByte.JAVA_BYTE.byteSize();
public Map<CityRef, ResultRow> processChunk() {
final var measurements = new HashMap<CityRef, ResultRow>(EXPECTED_MAX_NUM_CITIES);
var lineStart = 0;
// process the chunk line by line, exploiting the fact that a line is no longer than 106 bytes:
// 100 bytes for the city name + 1 byte for the separator + 1 byte for the negative sign + 4 bytes for the number
while (lineStart < byteBuffer.limit()) {
lineStart = this.processLine(measurements, byteBuffer, lineStart);
}

return result;
return measurements;
}

private void processLine(Map<String, ResultRow> result, ByteBuffer lineBytes) {
private int processLine(Map<CityRef, ResultRow> measurements, ByteBuffer byteBuffer, int start) {
var fingerPrint = 0;
var separatorIdx = -1;
for (int i = 0; i < lineBytes.limit(); i++) {
if ((char) lineBytes.get() == ';') {
separatorIdx = i;
lineBytes.clear();
break;
var sign = 1;
var value = 0;
var lineEnd = -1;
// Lines are processed in two stages:
// 1 - up to the city name separator
// 2 - after the separator
// this keeps the number of if clauses down

// stage 1 loop
{
for (int i = 0; i < NEW_LINE_SEEK_BUFFER_LEN; i++) {
final var currentByte = byteBuffer.get(start + i);
if (currentByte == ';') {
separatorIdx = i;
break;
} else {
fingerPrint = HASH_FACTOR * fingerPrint + currentByte;
}
}
}
assert (separatorIdx > 0);

var valueCapacity = lineBytes.capacity() - (separatorIdx + 1);
var cityBytes = new byte[separatorIdx];
var valueBytes = new byte[valueCapacity];
lineBytes.get(cityBytes, 0, separatorIdx);
lineBytes.get(separatorIdx + 1, valueBytes);
// stage 2 loop:
{
for (int i = separatorIdx + 1; i < NEW_LINE_SEEK_BUFFER_LEN; i++) {
final var currentByte = byteBuffer.get(start + i);
switch (currentByte) {
case '-':
sign = -1;
break;
case '.':
break;
case '\n':
lineEnd = start + i + 1;
break;
default:
// only digits are expected here
value = value * 10 + (currentByte - '0');
}

if (lineEnd != -1) {
break;
}
}
}

var city = new String(cityBytes, StandardCharsets.UTF_8);
var value = parseInt(valueBytes);
assert (separatorIdx > 0);
final var cityRef = new CityRef(byteBuffer, start, separatorIdx, fingerPrint);
value = sign * value;

var latestValue = result.get(city);
if (latestValue != null) {
latestValue.mergeValue(value);
final var existingMeasurement = measurements.get(cityRef);
if (existingMeasurement == null) {
measurements.put(cityRef, new ResultRow(value));
} else {
result.put(city, new ResultRow(value));
existingMeasurement.mergeValue(value);
}
}

private static int parseInt(byte[] valueBytes) {
int multiplier = 1;
int digitValue = 0;
var numDigits = valueBytes.length-1; // there is always one decimal place
var ds = new int[]{1,10,100};

for (byte valueByte : valueBytes) {
switch ((char) valueByte) {
case '-':
multiplier = -1;
numDigits -= 1;
break;
case '.':
break;
default:
digitValue += ((int) valueByte - 48) * (ds[numDigits - 1]);
numDigits -= 1;
break;// TODO continue here
}
}
return multiplier*digitValue;
return lineEnd; // lineEnd already accounts for the trailing '\n'
}
}

public static void main(String[] args) throws IOException {
// memory-map the file and divide it by the number of cores
var numProcessors = Runtime.getRuntime().availableProcessors();
var memorySegments = calculateMemorySegments(numProcessors);
var tasks = AverageAggregatorTask.createStreamOf(memorySegments);
assert (memorySegments.size() == numProcessors);
final var numProcessors = Runtime.getRuntime().availableProcessors();
final var byteBuffers = calculateMemorySegments(numProcessors);
final var tasks = AverageAggregatorTask.createStreamOf(byteBuffers);
assert (byteBuffers.size() <= numProcessors);
assert (!byteBuffers.isEmpty());

try (var pool = Executors.newFixedThreadPool(numProcessors)) {
var results = tasks
final Map<CityRef, ResultRow> aggregatedCities = tasks
.parallel()
.map(task -> CompletableFuture.supplyAsync(task::processChunk, pool))
.map(CompletableFuture::join)
.reduce(new TreeMap<>(), (partialMap, accumulator) -> {
partialMap.forEach((key, value) -> {
var prev = accumulator.get(key);
.reduce(new HashMap<>(EXPECTED_MAX_NUM_CITIES), (currentMap, accumulator) -> {
currentMap.forEach((key, value) -> {
final var prev = accumulator.get(key);
if (prev == null) {
accumulator.put(key, value);
}
@@ -172,6 +229,9 @@ public static void main(String[] args) throws IOException {
return accumulator;
});

var results = new HashMap<String, ResultRow>(EXPECTED_MAX_NUM_CITIES);
aggregatedCities.forEach((key, value) -> results.put(key.cityName(), value));

System.out.print("{");
String output = results.keySet()
.stream()
@@ -183,16 +243,16 @@ public static void main(String[] args) throws IOException {
}
}

private static List<MemorySegment> calculateMemorySegments(int numChunks) throws IOException {
try (RandomAccessFile raf = new RandomAccessFile(FILE, "r")) {
var result = new ArrayList<MemorySegment>(numChunks);
var chunks = new ArrayList<long[]>(numChunks);
private static List<ByteBuffer> calculateMemorySegments(int numChunks) throws IOException {
try (FileChannel fc = FileChannel.open(Paths.get(FILE))) {
var memMappedFile = fc.map(FileChannel.MapMode.READ_ONLY, 0L, fc.size(), Arena.ofAuto());
var result = new ArrayList<ByteBuffer>(numChunks);

var fileSize = raf.length();
var chunkSize = fileSize / numChunks;
var fileSize = fc.size();
var chunkSize = fileSize / numChunks; // TODO: if chunkSize > Integer.MAX_VALUE we will need to adjust
var previousChunkEnd = 0L;

for (int i = 0; i < numChunks; i++) {
var previousChunkEnd = i == 0 ? 0L : chunks.get(i - 1)[1];
if (previousChunkEnd >= fileSize) {
// There is a scenario for very small files where the number of chunks may be greater than
// the number of lines.
@@ -205,31 +265,27 @@ private static List<MemorySegment> calculateMemorySegments(int numChunks) throws
}
else {
// all other chunks end at a newline (\n)
var theoreticalEnd = previousChunkEnd + chunkSize;
var buffer = new byte[NEW_LINE_SEEK_BUFFER_LEN];
raf.seek(theoreticalEnd);
raf.read(buffer, 0, NEW_LINE_SEEK_BUFFER_LEN);

var theoreticalEnd = Math.min(previousChunkEnd + chunkSize, fileSize);
var newLineOffset = 0;
for (byte b : buffer) {
for (int j = 0; j < NEW_LINE_SEEK_BUFFER_LEN; j++) {
var candidateOffset = theoreticalEnd + j;
if (candidateOffset >= fileSize) {
break;
}
byte b = memMappedFile.get(ValueLayout.OfByte.JAVA_BYTE, candidateOffset);
newLineOffset += 1;
if ((char) b == '\n') {
break;
}
}
chunk[1] = Math.min(fileSize, theoreticalEnd + newLineOffset);
previousChunkEnd = chunk[1];
}

assert (chunk[0] >= 0L);
assert (chunk[0] <= fileSize);
assert (chunk[1] > chunk[0]);
assert (chunk[1] <= fileSize);

var memMappedFile = raf.getChannel()
.map(FileChannel.MapMode.READ_ONLY, chunk[0], (chunk[1] - chunk[0]), Arena.ofAuto());
memMappedFile.load();
chunks.add(chunk);
result.add(memMappedFile);
result.add(memMappedFile.asSlice(chunk[0], (chunk[1] - chunk[0])).asByteBuffer());
}
return result;
}
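For readers skimming the interleaved diff above, a minimal standalone sketch of the chunking idea in calculateMemorySegments: memory-map the whole file once, cut it into roughly one chunk per core, and push each chunk's end forward to the next newline so that no line straddles two chunks. The names ChunkingSketch and splitAtNewlines are illustrative assumptions, and the bounded NEW_LINE_SEEK_BUFFER_LEN scan of the real code is simplified into a plain scan to the next '\n'.

import java.io.IOException;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

public class ChunkingSketch {

    // Split a memory-mapped file into ~numChunks slices, each ending just past a '\n',
    // so every slice contains only whole lines.
    static List<ByteBuffer> splitAtNewlines(Path file, int numChunks) throws IOException {
        try (FileChannel fc = FileChannel.open(file)) {
            MemorySegment mapped = fc.map(FileChannel.MapMode.READ_ONLY, 0L, fc.size(), Arena.ofAuto());
            long fileSize = fc.size();
            long chunkSize = fileSize / numChunks; // assumes each slice fits into an int-sized ByteBuffer
            List<ByteBuffer> result = new ArrayList<>(numChunks);

            long start = 0L;
            while (start < fileSize) {
                long end = Math.min(start + chunkSize, fileSize);
                // advance end to just past the next newline (or to the end of the file)
                while (end < fileSize && mapped.get(ValueLayout.JAVA_BYTE, end) != '\n') {
                    end++;
                }
                if (end < fileSize) {
                    end++; // include the '\n' itself
                }
                result.add(mapped.asSlice(start, end - start).asByteBuffer());
                start = end;
            }
            return result;
        }
    }
}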
