Skip to content

Commit 8c4c689

Browse files
authored
Enable 512-bit vector support (simdjson#20)
* 512-bit experiment * use preferred species * nwidth_array initial commit * nwidth_array, make classifier array static * WIP comparing nwidth switch vs method handle dispatch of ::step * WIP - nwitdh * candidate PR code * Fix step methods visibility, remove unused MethodHandle imports
1 parent dddb0d7 commit 8c4c689

9 files changed

+138
-77
lines changed

src/main/java/org/simdjson/CharactersClassifier.java

+38-35
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,42 @@
66
class CharactersClassifier {
77

88
private static final byte LOW_NIBBLE_MASK = 0x0f;
9-
private static final ByteVector WHITESPACE_TABLE = ByteVector.fromArray(
10-
ByteVector.SPECIES_256,
11-
new byte[]{
12-
' ', 100, 100, 100, 17, 100, 113, 2, 100, '\t', '\n', 112, 100, '\r', 100, 100,
13-
' ', 100, 100, 100, 17, 100, 113, 2, 100, '\t', '\n', 112, 100, '\r', 100, 100
14-
},
15-
0
16-
);
17-
private static final ByteVector OP_TABLE = ByteVector.fromArray(
18-
ByteVector.SPECIES_256,
19-
new byte[]{
20-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ':', '{', ',', '}', 0, 0,
21-
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ':', '{', ',', '}', 0, 0
22-
},
23-
0
24-
);
9+
10+
private static final ByteVector WHITESPACE_TABLE =
11+
ByteVector.fromArray(
12+
StructuralIndexer.SPECIES,
13+
repeat(new byte[]{' ', 100, 100, 100, 17, 100, 113, 2, 100, '\t', '\n', 112, 100, '\r', 100, 100}, StructuralIndexer.SPECIES.vectorByteSize() / 4),
14+
0);
15+
16+
private static final ByteVector OP_TABLE =
17+
ByteVector.fromArray(
18+
StructuralIndexer.SPECIES,
19+
repeat(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ':', '{', ',', '}', 0, 0}, StructuralIndexer.SPECIES.vectorByteSize() / 4),
20+
0);
21+
22+
private static byte[] repeat(byte[] array, int n) {
23+
byte[] result = new byte[n * array.length];
24+
for (int dst = 0; dst < result.length; dst += array.length) {
25+
System.arraycopy(array, 0, result, dst, array.length);
26+
}
27+
return result;
28+
}
29+
30+
JsonCharacterBlock classify(ByteVector chunk0) {
31+
VectorShuffle<Byte> chunk0Low = extractLowNibble(chunk0).toShuffle();
32+
long whitespace = eq(chunk0, WHITESPACE_TABLE.rearrange(chunk0Low));
33+
ByteVector curlified0 = curlify(chunk0);
34+
long op = eq(curlified0, OP_TABLE.rearrange(chunk0Low));
35+
return new JsonCharacterBlock(whitespace, op);
36+
}
2537

2638
JsonCharacterBlock classify(ByteVector chunk0, ByteVector chunk1) {
2739
VectorShuffle<Byte> chunk0Low = extractLowNibble(chunk0).toShuffle();
2840
VectorShuffle<Byte> chunk1Low = extractLowNibble(chunk1).toShuffle();
29-
30-
long whitespace = eq(
31-
chunk0,
32-
WHITESPACE_TABLE.rearrange(chunk0Low),
33-
chunk1,
34-
WHITESPACE_TABLE.rearrange(chunk1Low)
35-
);
36-
41+
long whitespace = eq(chunk0, WHITESPACE_TABLE.rearrange(chunk0Low), chunk1, WHITESPACE_TABLE.rearrange(chunk1Low));
3742
ByteVector curlified0 = curlify(chunk0);
3843
ByteVector curlified1 = curlify(chunk1);
39-
long op = eq(
40-
curlified0,
41-
OP_TABLE.rearrange(chunk0Low),
42-
curlified1,
43-
OP_TABLE.rearrange(chunk1Low)
44-
);
45-
44+
long op = eq(curlified0, OP_TABLE.rearrange(chunk0Low), curlified1, OP_TABLE.rearrange(chunk1Low));
4645
return new JsonCharacterBlock(whitespace, op);
4746
}
4847

@@ -55,9 +54,13 @@ private ByteVector curlify(ByteVector vector) {
5554
return vector.or((byte) 0x20);
5655
}
5756

57+
private long eq(ByteVector chunk0, ByteVector mask0) {
58+
return chunk0.eq(mask0).toLong();
59+
}
60+
5861
private long eq(ByteVector chunk0, ByteVector mask0, ByteVector chunk1, ByteVector mask1) {
59-
long rLo = chunk0.eq(mask0).toLong();
60-
long rHi = chunk1.eq(mask1).toLong();
61-
return rLo | (rHi << 32);
62-
}
62+
long r0 = chunk0.eq(mask0).toLong();
63+
long r1 = chunk1.eq(mask1).toLong();
64+
return r0 | (r1 << 32);
65+
}
6366
}

src/main/java/org/simdjson/JsonStringScanner.java

+19-6
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,17 @@ class JsonStringScanner {
1515
private long prevEscaped = 0;
1616

1717
JsonStringScanner() {
18-
VectorSpecies<Byte> species = ByteVector.SPECIES_256;
19-
this.backslashMask = ByteVector.broadcast(species, (byte) '\\');
20-
this.quoteMask = ByteVector.broadcast(species, (byte) '"');
18+
this.backslashMask = ByteVector.broadcast(StructuralIndexer.SPECIES, (byte) '\\');
19+
this.quoteMask = ByteVector.broadcast(StructuralIndexer.SPECIES, (byte) '"');
20+
}
21+
22+
JsonStringBlock next(ByteVector chunk0) {
23+
long backslash = eq(chunk0, backslashMask);
24+
long escaped = findEscaped(backslash);
25+
long quote = eq(chunk0, quoteMask) & ~escaped;
26+
long inString = prefixXor(quote) ^ prevInString;
27+
prevInString = inString >> 63;
28+
return new JsonStringBlock(quote, inString);
2129
}
2230

2331
JsonStringBlock next(ByteVector chunk0, ByteVector chunk1) {
@@ -29,10 +37,15 @@ JsonStringBlock next(ByteVector chunk0, ByteVector chunk1) {
2937
return new JsonStringBlock(quote, inString);
3038
}
3139

40+
private long eq(ByteVector chunk0, ByteVector mask) {
41+
long r = chunk0.eq(mask).toLong();
42+
return r;
43+
}
44+
3245
private long eq(ByteVector chunk0, ByteVector chunk1, ByteVector mask) {
33-
long rLo = chunk0.eq(mask).toLong();
34-
long rHi = chunk1.eq(mask).toLong();
35-
return rLo | (rHi << 32);
46+
long r0 = chunk0.eq(mask).toLong();
47+
long r1 = chunk1.eq(mask).toLong();
48+
return r0 | (r1 << 32);
3649
}
3750

3851
private long findEscaped(long backslash) {

src/main/java/org/simdjson/StructuralIndexer.java

+42-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
package org.simdjson;
22

33
import jdk.incubator.vector.ByteVector;
4+
import jdk.incubator.vector.VectorSpecies;
5+
import java.lang.invoke.MethodType;
46

5-
import static jdk.incubator.vector.ByteVector.SPECIES_256;
67
import static jdk.incubator.vector.VectorOperators.UNSIGNED_LE;
78

89
class StructuralIndexer {
910

11+
static final VectorSpecies<Byte> SPECIES;
12+
static final int N_CHUNKS;
13+
14+
static {
15+
SPECIES = ByteVector.SPECIES_PREFERRED;
16+
N_CHUNKS = 64 / SPECIES.vectorByteSize();
17+
if (SPECIES != ByteVector.SPECIES_256 && SPECIES != ByteVector.SPECIES_512) {
18+
throw new IllegalArgumentException("Unsupported vector species: " + SPECIES);
19+
}
20+
}
21+
1022
private final JsonStringScanner stringScanner;
1123
private final CharactersClassifier classifier;
1224
private final BitIndexes bitIndexes;
@@ -22,29 +34,52 @@ class StructuralIndexer {
2234
}
2335

2436
void step(byte[] buffer, int offset, int blockIndex) {
25-
ByteVector chunk0 = ByteVector.fromArray(SPECIES_256, buffer, offset);
26-
ByteVector chunk1 = ByteVector.fromArray(SPECIES_256, buffer, offset + 32);
37+
switch (N_CHUNKS) {
38+
case 1: step1(buffer, offset, blockIndex); break;
39+
case 2: step2(buffer, offset, blockIndex); break;
40+
default: throw new RuntimeException("Unsupported vector width: " + N_CHUNKS * 64);
41+
}
42+
}
2743

44+
private void step1(byte[] buffer, int offset, int blockIndex) {
45+
ByteVector chunk0 = ByteVector.fromArray(ByteVector.SPECIES_512, buffer, offset);
46+
JsonStringBlock strings = stringScanner.next(chunk0);
47+
JsonCharacterBlock characters = classifier.classify(chunk0);
48+
long unescaped = lteq(chunk0, (byte) 0x1F);
49+
finishStep(characters, strings, unescaped, blockIndex);
50+
}
51+
52+
private void step2(byte[] buffer, int offset, int blockIndex) {
53+
ByteVector chunk0 = ByteVector.fromArray(ByteVector.SPECIES_256, buffer, offset);
54+
ByteVector chunk1 = ByteVector.fromArray(ByteVector.SPECIES_256, buffer, offset + 32);
2855
JsonStringBlock strings = stringScanner.next(chunk0, chunk1);
2956
JsonCharacterBlock characters = classifier.classify(chunk0, chunk1);
57+
long unescaped = lteq(chunk0, chunk1, (byte) 0x1F);
58+
finishStep(characters, strings, unescaped, blockIndex);
59+
}
3060

61+
private void finishStep(JsonCharacterBlock characters, JsonStringBlock strings, long unescaped, int blockIndex) {
3162
long scalar = characters.scalar();
3263
long nonQuoteScalar = scalar & ~strings.quote();
3364
long followsNonQuoteScalar = nonQuoteScalar << 1 | prevScalar;
3465
prevScalar = nonQuoteScalar >>> 63;
35-
long unescaped = lteq(chunk0, chunk1, (byte) 0x1F);
3666
// TODO: utf-8 validation
3767
long potentialScalarStart = scalar & ~followsNonQuoteScalar;
3868
long potentialStructuralStart = characters.op() | potentialScalarStart;
3969
bitIndexes.write(blockIndex, prevStructurals);
4070
prevStructurals = potentialStructuralStart & ~strings.stringTail();
4171
unescapedCharsError |= strings.nonQuoteInsideString(unescaped);
72+
}
73+
74+
private long lteq(ByteVector chunk0, byte scalar) {
75+
long r = chunk0.compare(UNSIGNED_LE, scalar).toLong();
76+
return r;
4277
}
4378

4479
private long lteq(ByteVector chunk0, ByteVector chunk1, byte scalar) {
45-
long rLo = chunk0.compare(UNSIGNED_LE, scalar).toLong();
46-
long rHi = chunk1.compare(UNSIGNED_LE, scalar).toLong();
47-
return rLo | (rHi << 32);
80+
long r0 = chunk0.compare(UNSIGNED_LE, scalar).toLong();
81+
long r1 = chunk1.compare(UNSIGNED_LE, scalar).toLong();
82+
return r0 | (r1 << 32);
4883
}
4984

5085
void finish(int blockIndex) {

src/main/java/org/simdjson/TapeBuilder.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414
import static org.simdjson.Tape.START_OBJECT;
1515
import static org.simdjson.Tape.STRING;
1616
import static org.simdjson.Tape.TRUE_VALUE;
17-
import static jdk.incubator.vector.ByteVector.SPECIES_256;
1817

1918
class TapeBuilder {
2019

2120
private static final byte SPACE = 0x20;
2221
private static final byte BACKSLASH = '\\';
2322
private static final byte QUOTE = '"';
24-
private static final int BYTES_PROCESSED = 32;
23+
private static final int BYTES_PROCESSED = StructuralIndexer.SPECIES.vectorByteSize();
2524
private static final byte[] ESCAPE_MAP = new byte[]{
2625
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x0.
2726
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
@@ -198,7 +197,7 @@ private void visitString(byte[] buffer, int idx) {
198197
int src = idx + 1;
199198
int dst = stringBufferIdx + Integer.BYTES;
200199
while (true) {
201-
ByteVector srcVec = ByteVector.fromArray(SPECIES_256, buffer, src);
200+
ByteVector srcVec = ByteVector.fromArray(StructuralIndexer.SPECIES, buffer, src);
202201
srcVec.intoArray(stringBuffer, dst);
203202
long backslashBits = srcVec.eq(BACKSLASH).toLong();
204203
long quoteBits = srcVec.eq(QUOTE).toLong();

src/test/java/org/simdjson/BenchmarkCorrectnessTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@ public void numberParserTest(String input, Double expected) {
5959
private static byte[] loadTwitterJson() throws IOException {
6060
try (InputStream is = BenchmarkCorrectnessTest.class.getResourceAsStream("/twitter.json")) {
6161
return is.readAllBytes();
62-
}
62+
}
6363
}
6464
}

src/test/java/org/simdjson/CharactersClassifierTest.java

+13-5
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
import static java.nio.charset.StandardCharsets.UTF_8;
66
import static org.assertj.core.api.Assertions.assertThat;
7-
import static org.simdjson.StringUtils.chunk0;
8-
import static org.simdjson.StringUtils.chunk1;
7+
import static org.simdjson.StringUtils.chunk;
98

109
public class CharactersClassifierTest {
1110

@@ -16,7 +15,7 @@ public void classifiesOperators() {
1615
String str = "a{bc}1:2,3[efg]aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
1716

1817
// when
19-
JsonCharacterBlock block = classifier.classify(chunk0(str), chunk1(str));
18+
JsonCharacterBlock block = classify(classifier, str);
2019

2120
// then
2221
assertThat(block.op()).isEqualTo(0x4552);
@@ -39,7 +38,7 @@ public void classifiesControlCharactersAsOperators() {
3938
}, UTF_8);
4039

4140
// when
42-
JsonCharacterBlock block = classifier.classify(chunk0(str), chunk1(str));
41+
JsonCharacterBlock block = classify(classifier, str);
4342

4443
// then
4544
assertThat(block.op()).isEqualTo(0x28);
@@ -53,10 +52,19 @@ public void classifiesWhitespaces() {
5352
String str = "a bc\t1\n2\r3efgaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
5453

5554
// when
56-
JsonCharacterBlock block = classifier.classify(chunk0(str), chunk1(str));
55+
JsonCharacterBlock block = classify(classifier, str);
5756

5857
// then
5958
assertThat(block.whitespace()).isEqualTo(0x152);
6059
assertThat(block.op()).isEqualTo(0);
6160
}
61+
62+
private JsonCharacterBlock classify(CharactersClassifier classifier, String str) {
63+
return switch (StructuralIndexer.N_CHUNKS) {
64+
case 1 -> classifier.classify(chunk(str, 0));
65+
case 2 -> classifier.classify(chunk(str, 0), chunk(str, 1));
66+
default -> throw new RuntimeException("Unsupported chunk count: " + StructuralIndexer.N_CHUNKS);
67+
};
68+
}
69+
6270
}

src/test/java/org/simdjson/JsonStringScannerTest.java

+20-13
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
import org.junit.jupiter.params.provider.ValueSource;
66

77
import static org.assertj.core.api.Assertions.assertThat;
8-
import static org.simdjson.StringUtils.chunk0;
9-
import static org.simdjson.StringUtils.chunk1;
8+
import static org.simdjson.StringUtils.chunk;
109
import static org.simdjson.StringUtils.padWithSpaces;
1110

1211
public class JsonStringScannerTest {
@@ -18,7 +17,7 @@ public void testUnquotedString() {
1817
String str = padWithSpaces("abc 123");
1918

2019
// when
21-
JsonStringBlock block = stringScanner.next(chunk0(str), chunk1(str));
20+
JsonStringBlock block = next(stringScanner, str);
2221

2322
// then
2423
assertThat(block.quote()).isEqualTo(0);
@@ -31,7 +30,7 @@ public void testQuotedString() {
3130
String str = padWithSpaces("\"abc 123\"");
3231

3332
// when
34-
JsonStringBlock block = stringScanner.next(chunk0(str), chunk1(str));
33+
JsonStringBlock block = next(stringScanner, str);
3534

3635
// then
3736
assertThat(block.quote()).isEqualTo(0x101);
@@ -44,7 +43,7 @@ public void testStartingQuotes() {
4443
String str = padWithSpaces("\"abc 123");
4544

4645
// when
47-
JsonStringBlock block = stringScanner.next(chunk0(str), chunk1(str));
46+
JsonStringBlock block = next(stringScanner, str);
4847

4948
// then
5049
assertThat(block.quote()).isEqualTo(0x1);
@@ -58,8 +57,8 @@ public void testQuotedStringSpanningMultipleBlocks() {
5857
String str1 = " c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 d0 d1 d2 d3 d4 d5 d6 d7 d8 d\" def";
5958

6059
// when
61-
JsonStringBlock firstBlock = stringScanner.next(chunk0(str0), chunk1(str0));
62-
JsonStringBlock secondBlock = stringScanner.next(chunk0(str1), chunk1(str1));
60+
JsonStringBlock firstBlock = next(stringScanner, str0);
61+
JsonStringBlock secondBlock = next(stringScanner, str1);
6362

6463
// then
6564
assertThat(firstBlock.quote()).isEqualTo(0x10);
@@ -77,7 +76,7 @@ public void testEscapedQuote(String str) {
7776
String padded = padWithSpaces(str);
7877

7978
// when
80-
JsonStringBlock block = stringScanner.next(chunk0(padded), chunk1(padded));
79+
JsonStringBlock block = next(stringScanner, padded);
8180

8281
// then
8382
assertThat(block.quote()).isEqualTo(0);
@@ -91,8 +90,8 @@ public void testEscapedQuoteSpanningMultipleBlocks() {
9190
String str1 = padWithSpaces("\"def");
9291

9392
// when
94-
JsonStringBlock firstBlock = stringScanner.next(chunk0(str0), chunk1(str0));
95-
JsonStringBlock secondBlock = stringScanner.next(chunk0(str1), chunk1(str1));
93+
JsonStringBlock firstBlock = next(stringScanner, str0);
94+
JsonStringBlock secondBlock = next(stringScanner, str1);
9695

9796
// then
9897
assertThat(firstBlock.quote()).isEqualTo(0);
@@ -110,7 +109,7 @@ public void testUnescapedQuote(String str) {
110109
String padded = padWithSpaces(str);
111110

112111
// when
113-
JsonStringBlock block = stringScanner.next(chunk0(padded), chunk1(padded));
112+
JsonStringBlock block = next(stringScanner, padded);
114113

115114
// then
116115
assertThat(block.quote()).isEqualTo(0x1L << str.indexOf('"'));
@@ -124,11 +123,19 @@ public void testUnescapedQuoteSpanningMultipleBlocks() {
124123
String str1 = padWithSpaces("\\\"abc");
125124

126125
// when
127-
JsonStringBlock firstBlock = stringScanner.next(chunk0(str0), chunk1(str0));
128-
JsonStringBlock secondBlock = stringScanner.next(chunk0(str1), chunk1(str1));
126+
JsonStringBlock firstBlock = next(stringScanner, str0);
127+
JsonStringBlock secondBlock = next(stringScanner, str1);
129128

130129
// then
131130
assertThat(firstBlock.quote()).isEqualTo(0);
132131
assertThat(secondBlock.quote()).isEqualTo(0x2);
133132
}
133+
134+
private JsonStringBlock next(JsonStringScanner scanner, String str) {
135+
return switch (StructuralIndexer.N_CHUNKS) {
136+
case 1 -> scanner.next(chunk(str, 0));
137+
case 2 -> scanner.next(chunk(str, 0), chunk(str, 1));
138+
default -> throw new RuntimeException("Unsupported chunk count: " + StructuralIndexer.N_CHUNKS);
139+
};
140+
}
134141
}

0 commit comments

Comments
 (0)