Skip to content

Commit d0997db

Browse files
Alex1OPSluben
authored and committed
Add support for multiple dictionary references in streaming decompression.
1 parent 100c434 commit d0997db

File tree

5 files changed

+115
-9
lines changed

5 files changed

+115
-9
lines changed

src/main/java/com/github/luben/zstd/Zstd.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ public static long decompressDirectByteBufferFastDict(ByteBuffer dst, int dstOff
567567
public static native int setCompressionWorkers(long stream, int workers);
568568
public static native int setDecompressionLongMax(long stream, int windowLogMax);
569569
public static native int setDecompressionMagicless(long stream, boolean useMagicless);
570+
public static native int setRefMultipleDDicts(long stream, boolean useMultiple);
570571

571572
/* Utility methods */
572573

src/main/java/com/github/luben/zstd/ZstdInputStream.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,17 @@ public ZstdInputStream setLongMax(int windowLogMax) throws IOException {
8888
return this;
8989
}
9090

91+
/**
92+
* Enable or disable support for multiple dictionary references
93+
*
94+
* @param useMultiple Enables references table for DDict, so the DDict used for decompression will be
95+
* determined per the dictId in the frame, default: false
96+
*/
97+
public ZstdInputStream setRefMultipleDDicts(boolean useMultiple) throws IOException {
98+
inner.setRefMultipleDDicts(useMultiple);
99+
return this;
100+
}
101+
91102
public int read(byte[] dst, int offset, int len) throws IOException {
92103
return inner.read(dst, offset, len);
93104
}

src/main/java/com/github/luben/zstd/ZstdInputStreamNoFinalizer.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ public synchronized ZstdInputStreamNoFinalizer setLongMax(int windowLogMax) thro
117117
return this;
118118
}
119119

120+
public synchronized ZstdInputStreamNoFinalizer setRefMultipleDDicts(boolean useMultiple) throws IOException {
121+
int size = Zstd.setRefMultipleDDicts(stream, useMultiple);
122+
if (Zstd.isError(size)) {
123+
throw new ZstdIOException(size);
124+
}
125+
return this;
126+
}
127+
120128
public synchronized int read(byte[] dst, int offset, int len) throws IOException {
121129
// guard against buffer overflows
122130
if (offset < 0 || len > dst.length - offset) {

src/main/native/jni_zstd.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,17 @@ JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setCompressionWorkers
332332
return ZSTD_CCtx_setParameter((ZSTD_CCtx *)(intptr_t) stream, ZSTD_c_nbWorkers, workers);
333333
}
334334

335+
/*
336+
* Class: com_github_luben_zstd_Zstd
337+
* Method: setRefMultipleDDicts
338+
* Signature: (JZ)I
339+
*/
340+
JNIEXPORT jint JNICALL Java_com_github_luben_zstd_Zstd_setRefMultipleDDicts
341+
(JNIEnv *env, jclass obj, jlong stream, jboolean enabled) {
342+
ZSTD_refMultipleDDicts_e value = enabled ? ZSTD_rmd_refMultipleDDicts : ZSTD_rmd_refSingleDDict;
343+
return ZSTD_DCtx_setParameter((ZSTD_DCtx *)(intptr_t) stream, ZSTD_d_refMultipleDDicts, value);
344+
}
345+
335346
/*
336347
* Class: com_github_luben_zstd_Zstd
337348
* Methods: header constants access

src/test/scala/ZstdDict.scala

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
11
package com.github.luben.zstd
22

33
import org.scalatest.flatspec.AnyFlatSpec
4+
45
import java.io._
56
import java.nio._
6-
import java.nio.channels.FileChannel
7-
import java.nio.channels.FileChannel.MapMode
8-
import java.nio.file.StandardOpenOption
9-
107
import scala.io._
11-
import scala.collection.mutable.WrappedArray
8+
import scala.util.Using
129

1310
class ZstdDictSpec extends AnyFlatSpec {
1411

1512
def source = Source.fromFile("src/test/resources/xml")(Codec.ISO8859).map{_.toByte}
1613

17-
def train(legacy: Boolean): Array[Byte] = {
18-
val src = source.sliding(1024, 1024).take(1024).map(_.toArray)
19-
val trainer = new ZstdDictTrainer(1024 * 1024, 32 * 1024)
14+
def train(legacy: Boolean, sampleSize: Int): Array[Byte] = {
15+
val src = source.sliding(1024, 1024).take(sampleSize).map(_.toArray)
16+
val trainer = new ZstdDictTrainer(1024 * sampleSize, 32 * sampleSize)
2017
for (sample <- src) {
2118
trainer.addSample(sample)
2219
}
@@ -52,7 +49,8 @@ class ZstdDictSpec extends AnyFlatSpec {
5249
val levels = List(1)
5350
for {
5451
legacy <- legacyS
55-
dict = train(legacy)
52+
dict = train(legacy, 1024)
53+
dict2 = train(legacy, 512)
5654
dictInDirectByteBuffer = wrapInDirectByteBuffer(dict)
5755
level <- levels
5856
} {
@@ -282,6 +280,83 @@ class ZstdDictSpec extends AnyFlatSpec {
282280
assert(input.toSeq == output.toSeq)
283281
}
284282

283+
it should s"round-trip streaming compression/decompression with multiple fast dicts with legacy $legacy " in {
284+
// given: the same input compressed once with the first dictionary and once with the second
285+
val cdict = new ZstdDictCompress(dict, 0, dict.length, 1)
286+
val cdict2 = new ZstdDictCompress(dict2, 0, dict2.length, 1)
287+
288+
val compressedWithDict1 = compressWithDict(cdict)
289+
val compressedWithDict2 = compressWithDict(cdict2)
290+
291+
// when: decompress with both dictionaries configured and multiple dict references enabled
292+
val ddict = new ZstdDictDecompress(dict)
293+
val ddict2 = new ZstdDictDecompress(dict2)
294+
295+
val dicts = ddict::ddict2::Nil
296+
val uncompressed1 = uncompressWithMultipleDicts(compressedWithDict1, dicts)
297+
val uncompressed2 = uncompressWithMultipleDicts(compressedWithDict2, dicts)
298+
299+
// then: both compressed inputs decompressed successfully
300+
assert(uncompressed1.toSeq == input.toSeq)
301+
assert(Zstd.getDictIdFromFrame(compressedWithDict1) == Zstd.getDictIdFromDict(dict))
302+
303+
assert(uncompressed2.toSeq == input.toSeq)
304+
assert(Zstd.getDictIdFromFrame(compressedWithDict2) == Zstd.getDictIdFromDict(dict2))
305+
}
306+
307+
it should s"round-trip streaming compression/decompression with multiple fast dicts with legacy $legacy and disabled multiple dict references" in {
308+
// given: the same input compressed once with the first dictionary and once with the second
309+
val cdict = new ZstdDictCompress(dict, 0, dict.length, 1)
310+
val cdict2 = new ZstdDictCompress(dict2, 0, dict2.length, 1)
311+
312+
val compressedWithDict1 = compressWithDict(cdict)
313+
val compressedWithDict2 = compressWithDict(cdict2)
314+
315+
// when: decompress with both dictionaries configured and multiple dict references disabled
316+
// -> only the second one should be used
317+
val ddict = new ZstdDictDecompress(dict)
318+
val ddict2 = new ZstdDictDecompress(dict2)
319+
320+
val dicts = ddict :: ddict2 :: Nil
321+
val uncompressed2 = uncompressWithMultipleDicts(compressedWithDict2, dicts, multipleDdicts = false)
322+
323+
// then: decompressing the data compressed with the first dict should fail with a dictionary mismatch,
324+
// the second one should be decompressed successfully
325+
val caughtException = intercept[ZstdIOException] {
326+
uncompressWithMultipleDicts(compressedWithDict1, dicts, multipleDdicts = false)
327+
}
328+
assert(caughtException.getMessage == "Dictionary mismatch")
329+
330+
assert(uncompressed2.toSeq == input.toSeq)
331+
assert(Zstd.getDictIdFromFrame(compressedWithDict2) == Zstd.getDictIdFromDict(dict2))
332+
}
333+
334+
/**
 * Compresses the enclosing scope's `input` through a [[ZstdOutputStream]]
 * (constructed with level argument `1`) using the supplied dictionary.
 *
 * @param cdict dictionary handed to `setDict` before any data is written
 * @return the bytes collected by the underlying `ByteArrayOutputStream`
 *         after the zstd stream has been closed
 */
def compressWithDict(cdict: ZstdDictCompress): Array[Byte] = {
  val os = new ByteArrayOutputStream(Zstd.compressBound(input.length.toLong).toInt)
  // Using.resource (unlike Using.apply) rethrows failures instead of wrapping
  // them in a Try that would be discarded here, so a failing setDict/write/close
  // fails the calling test rather than silently yielding truncated output.
  Using.resource(new ZstdOutputStream(os, 1)) { zos =>
    zos.setDict(cdict)
    zos.write(input)
  }
  // Read the buffer only after the zstd stream has been closed.
  os.toByteArray
}
342+
343+
/**
 * Decompresses `compressed` through a [[ZstdInputStream]] after passing every
 * dictionary in `dicts` to `setDict`, in list order.
 *
 * @param compressed     bytes to decompress
 * @param dicts          dictionaries registered on the stream, one `setDict` call per element
 * @param multipleDdicts value forwarded to `setRefMultipleDDicts` before any dictionary is set
 * @return everything the stream produced, as a byte array
 */
def uncompressWithMultipleDicts(
  compressed: Array[Byte],
  dicts: List[ZstdDictDecompress],
  multipleDdicts: Boolean = true
): Array[Byte] = {
  // Build and configure the input stream up front; both resources are released by Using.
  def openStream(): ZstdInputStream =
    new ZstdInputStream(new ByteArrayInputStream(compressed))
      .setRefMultipleDDicts(multipleDdicts)

  Using.resources(openStream(), new ByteArrayOutputStream()) { (in, out) =>
    for (dict <- dicts) in.setDict(dict)

    in.transferTo(out)
    out.toByteArray
  }
}
359+
285360
it should s"round-trip streaming ByteBuffer compression/decompression with byte[] dict with legacy $legacy" in {
286361
val size = input.length
287362
val os = ByteBuffer.allocateDirect(Zstd.compressBound(size.toLong).toInt)

0 commit comments

Comments
 (0)