Skip to content

[WIP] [SPARK-47547] BloomFilter fpp degradation #50933

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
7 changes: 7 additions & 0 deletions common/sketch/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit-pioneer</groupId>
<artifactId>junit-pioneer</artifactId>
<version>2.3.0</version>
<scope>test</scope>
</dependency>

</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ public static BloomFilter create(long expectedNumItems, double fpp) {
* pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter.
*/
public static BloomFilter create(long expectedNumItems, long numBits) {
return create(expectedNumItems, numBits, BloomFilterImpl.DEFAULT_SEED);
}
public static BloomFilter create(long expectedNumItems, long numBits, int seed) {
if (expectedNumItems <= 0) {
throw new IllegalArgumentException("Expected insertions must be positive");
}
Expand All @@ -264,6 +267,6 @@ public static BloomFilter create(long expectedNumItems, long numBits) {
throw new IllegalArgumentException("Number of bits must be positive");
}

return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits);
return new BloomFilterImpl(optimalNumOfHashFunctions(expectedNumItems, numBits), numBits, seed);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,25 @@

class BloomFilterImpl extends BloomFilter implements Serializable {

public static final int DEFAULT_SEED = 0;

private int seed;
private int numHashFunctions;

private BitArray bits;

BloomFilterImpl(int numHashFunctions, long numBits) {
this(new BitArray(numBits), numHashFunctions);
this(numHashFunctions, numBits, DEFAULT_SEED);
}

BloomFilterImpl(int numHashFunctions, long numBits, int seed) {
this(new BitArray(numBits), numHashFunctions, seed);
}

private BloomFilterImpl(BitArray bits, int numHashFunctions) {
private BloomFilterImpl(BitArray bits, int numHashFunctions, int seed) {
this.bits = bits;
this.numHashFunctions = numHashFunctions;
this.seed = seed;
}

private BloomFilterImpl() {}
Expand Down Expand Up @@ -82,13 +90,16 @@ public boolean putString(String item) {

@Override
public boolean putBinary(byte[] item) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0);
int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, seed);
int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1);

long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);

// Integer.MAX_VALUE takes care of scrambling the higher four bytes of combinedHash
long combinedHash = (long) h1 * Integer.MAX_VALUE;
for (long i = 0; i < numHashFunctions; i++) {
combinedHash += h2;
// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
Expand All @@ -105,12 +116,15 @@ public boolean mightContainString(String item) {

@Override
public boolean mightContainBinary(byte[] item) {
int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, 0);
int h1 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, seed);
int h2 = Murmur3_x86_32.hashUnsafeBytes(item, Platform.BYTE_ARRAY_OFFSET, item.length, h1);

long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);

// Integer.MAX_VALUE takes care of scrambling the higher four bytes of combinedHash
long combinedHash = (long) h1 * Integer.MAX_VALUE;
for (long i = 0; i < numHashFunctions; i++) {
combinedHash += h2;
// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
Expand All @@ -129,13 +143,17 @@ public boolean putLong(long item) {
// Note that `CountMinSketch` use a different strategy, it hash the input long element with
// every i to produce n hash values.
// TODO: the strategy of `CountMinSketch` looks more advanced, should we follow it here?
int h1 = Murmur3_x86_32.hashLong(item, 0);
int h1 = Murmur3_x86_32.hashLong(item, seed);
int h2 = Murmur3_x86_32.hashLong(item, h1);

long bitSize = bits.bitSize();
boolean bitsChanged = false;
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);

// Integer.MAX_VALUE takes care of scrambling the higher four bytes of combinedHash
long combinedHash = (long) h1 * Integer.MAX_VALUE;
for (long i = 0; i < numHashFunctions; i++) {
combinedHash += h2;

// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
Expand All @@ -147,12 +165,16 @@ public boolean putLong(long item) {

@Override
public boolean mightContainLong(long item) {
int h1 = Murmur3_x86_32.hashLong(item, 0);
int h1 = Murmur3_x86_32.hashLong(item, seed);
int h2 = Murmur3_x86_32.hashLong(item, h1);

long bitSize = bits.bitSize();
for (int i = 1; i <= numHashFunctions; i++) {
int combinedHash = h1 + (i * h2);

// Integer.MAX_VALUE takes care of scrambling the higher four bytes of combinedHash
long combinedHash = (long) h1 * Integer.MAX_VALUE;
for (long i = 0; i < numHashFunctions; i++) {
combinedHash += h2;

// Flip all the bits if it's negative (guaranteed positive number)
if (combinedHash < 0) {
combinedHash = ~combinedHash;
Expand Down Expand Up @@ -226,6 +248,12 @@ private BloomFilterImpl checkCompatibilityForMerge(BloomFilter other)
throw new IncompatibleMergeException("Cannot merge bloom filters with different bit size");
}

if (this.seed != that.seed) {
throw new IncompatibleMergeException(
"Cannot merge bloom filters with different seeds"
);
}

if (this.numHashFunctions != that.numHashFunctions) {
throw new IncompatibleMergeException(
"Cannot merge bloom filters with different number of hash functions"
Expand All @@ -241,6 +269,7 @@ public void writeTo(OutputStream out) throws IOException {
dos.writeInt(Version.V1.getVersionNumber());
dos.writeInt(numHashFunctions);
bits.writeTo(dos);
dos.writeInt(seed);
}

private void readFrom0(InputStream in) throws IOException {
Expand All @@ -253,6 +282,13 @@ private void readFrom0(InputStream in) throws IOException {

this.numHashFunctions = dis.readInt();
this.bits = BitArray.readFrom(dis);

// compatibility with "seedless" serialization streams.
try {
this.seed = dis.readInt();
} catch (EOFException e) {
this.seed = DEFAULT_SEED;
}
}

public static BloomFilterImpl readFrom(InputStream in) throws IOException {
Expand Down
Loading