|
1 | 1 | package com.github.luben.zstd
|
2 | 2 |
|
3 | 3 | import org.scalatest.flatspec.AnyFlatSpec
|
| 4 | + |
4 | 5 | import java.io._
|
5 | 6 | import java.nio._
|
6 |
| -import java.nio.channels.FileChannel |
7 |
| -import java.nio.channels.FileChannel.MapMode |
8 |
| -import java.nio.file.StandardOpenOption |
9 |
| - |
10 | 7 | import scala.io._
|
11 |
| -import scala.collection.mutable.WrappedArray |
| 8 | +import scala.util.Using |
12 | 9 |
|
13 | 10 | class ZstdDictSpec extends AnyFlatSpec {
|
14 | 11 |
|
15 | 12 | def source = Source.fromFile("src/test/resources/xml")(Codec.ISO8859).map{_.toByte}
|
16 | 13 |
|
17 |
| - def train(legacy: Boolean): Array[Byte] = { |
18 |
| - val src = source.sliding(1024, 1024).take(1024).map(_.toArray) |
19 |
| - val trainer = new ZstdDictTrainer(1024 * 1024, 32 * 1024) |
| 14 | + def train(legacy: Boolean, sampleSize: Int): Array[Byte] = { |
| 15 | + val src = source.sliding(1024, 1024).take(sampleSize).map(_.toArray) |
| 16 | + val trainer = new ZstdDictTrainer(1024 * sampleSize, 32 * sampleSize) |
20 | 17 | for (sample <- src) {
|
21 | 18 | trainer.addSample(sample)
|
22 | 19 | }
|
@@ -52,7 +49,8 @@ class ZstdDictSpec extends AnyFlatSpec {
|
52 | 49 | val levels = List(1)
|
53 | 50 | for {
|
54 | 51 | legacy <- legacyS
|
55 |
| - dict = train(legacy) |
| 52 | + dict = train(legacy, 1024) |
| 53 | + dict2 = train(legacy, 512) |
56 | 54 | dictInDirectByteBuffer = wrapInDirectByteBuffer(dict)
|
57 | 55 | level <- levels
|
58 | 56 | } {
|
@@ -282,6 +280,83 @@ class ZstdDictSpec extends AnyFlatSpec {
|
282 | 280 | assert(input.toSeq == output.toSeq)
|
283 | 281 | }
|
284 | 282 |
|
| 283 | + it should s"round-trip streaming compression/decompression with multiple fast dicts with legacy $legacy " in { |
| 284 | + // given: compress using first one dictionary, then another |
| 285 | + val cdict = new ZstdDictCompress(dict, 0, dict.length, 1) |
| 286 | + val cdict2 = new ZstdDictCompress(dict2, 0, dict2.length, 1) |
| 287 | + |
| 288 | + val compressedWithDict1 = compressWithDict(cdict) |
| 289 | + val compressedWithDict2 = compressWithDict(cdict2) |
| 290 | + |
| 291 | + // when: decompress with the both dictionaries configured and multiple dict references enabled |
| 292 | + val ddict = new ZstdDictDecompress(dict) |
| 293 | + val ddict2 = new ZstdDictDecompress(dict2) |
| 294 | + |
| 295 | + val dicts = ddict::ddict2::Nil |
| 296 | + val uncompressed1 = uncompressWithMultipleDicts(compressedWithDict1, dicts) |
| 297 | + val uncompressed2 = uncompressWithMultipleDicts(compressedWithDict2, dicts) |
| 298 | + |
| 299 | + // then: both compressed inputs decompressed successfully |
| 300 | + assert(uncompressed1.toSeq == input.toSeq) |
| 301 | + assert(Zstd.getDictIdFromFrame(compressedWithDict1) == Zstd.getDictIdFromDict(dict)) |
| 302 | + |
| 303 | + assert(uncompressed2.toSeq == input.toSeq) |
| 304 | + assert(Zstd.getDictIdFromFrame(compressedWithDict2) == Zstd.getDictIdFromDict(dict2)) |
| 305 | + } |
| 306 | + |
| 307 | + it should s"round-trip streaming compression/decompression with multiple fast dicts with legacy $legacy and disabled multiple dict references" in { |
| 308 | + // given: compress using first one dictionary, then another |
| 309 | + val cdict = new ZstdDictCompress(dict, 0, dict.length, 1) |
| 310 | + val cdict2 = new ZstdDictCompress(dict2, 0, dict2.length, 1) |
| 311 | + |
| 312 | + val compressedWithDict1 = compressWithDict(cdict) |
| 313 | + val compressedWithDict2 = compressWithDict(cdict2) |
| 314 | + |
| 315 | + // when: decompress with the both dictionaries configured and multiple dict references disabled |
| 316 | + // -> should be used only the second one |
| 317 | + val ddict = new ZstdDictDecompress(dict) |
| 318 | + val ddict2 = new ZstdDictDecompress(dict2) |
| 319 | + |
| 320 | + val dicts = ddict :: ddict2 :: Nil |
| 321 | + val uncompressed2 = uncompressWithMultipleDicts(compressedWithDict2, dicts, multipleDdicts = false) |
| 322 | + |
| 323 | + // then: decompression of compressed with the first dict should fail with dictionary mismatch, |
| 324 | + // the second one should be decompressed successfully |
| 325 | + val caughtException = intercept[ZstdIOException] { |
| 326 | + uncompressWithMultipleDicts(compressedWithDict1, dicts, multipleDdicts = false) |
| 327 | + } |
| 328 | + assert(caughtException.getMessage == "Dictionary mismatch") |
| 329 | + |
| 330 | + assert(uncompressed2.toSeq == input.toSeq) |
| 331 | + assert(Zstd.getDictIdFromFrame(compressedWithDict2) == Zstd.getDictIdFromDict(dict2)) |
| 332 | + } |
| 333 | + |
| 334 | + def compressWithDict(cdict: ZstdDictCompress): Array[Byte] = { |
| 335 | + val os = new ByteArrayOutputStream(Zstd.compressBound(input.length.toLong).toInt) |
| 336 | + Using(new ZstdOutputStream(os, 1)) { zos => |
| 337 | + zos.setDict(cdict) |
| 338 | + zos.write(input) |
| 339 | + } |
| 340 | + os.toByteArray |
| 341 | + } |
| 342 | + |
| 343 | + def uncompressWithMultipleDicts( |
| 344 | + compressed: Array[Byte], |
| 345 | + dicts: List[ZstdDictDecompress], |
| 346 | + multipleDdicts: Boolean = true |
| 347 | + ): Array[Byte] = { |
| 348 | + Using.resources( |
| 349 | + new ZstdInputStream(new ByteArrayInputStream(compressed)) |
| 350 | + .setRefMultipleDDicts(multipleDdicts), |
| 351 | + new ByteArrayOutputStream() |
| 352 | + ) { (zis, os) => |
| 353 | + dicts.foreach(zis.setDict) |
| 354 | + |
| 355 | + zis.transferTo(os) |
| 356 | + os.toByteArray |
| 357 | + } |
| 358 | + } |
| 359 | + |
285 | 360 | it should s"round-trip streaming ByteBuffer compression/decompression with byte[] dict with legacy $legacy" in {
|
286 | 361 | val size = input.length
|
287 | 362 | val os = ByteBuffer.allocateDirect(Zstd.compressBound(size.toLong).toInt)
|
|
0 commit comments