Skip to content

Commit

Permalink
Make Inflater and Deflater symmetric (#1426)
Browse files Browse the repository at this point in the history
I expect this to simplify adding DeflaterSink and InflaterSource.
  • Loading branch information
squarejesse authored Feb 7, 2024
1 parent 1d5f262 commit 260710e
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 97 deletions.
64 changes: 64 additions & 0 deletions okio/src/nativeMain/kotlin/okio/DataProcessor.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (C) 2024 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio

private val emptyByteArray = byteArrayOf()

/**
* Transform a stream of source bytes into a stream of target bytes, one segment at a time. The
* relationship between input byte count and output byte count is arbitrary: a sequence of input
* bytes may produce zero output bytes, or many segments of output bytes.
*
* To use:
*
* 1. Create an instance.
*
* 2. Populate [source] with input data. Set [sourcePos] and [sourceLimit] to a readable slice of
* this array.
*
* 3. Populate [target] with a destination for output data. Set [targetPos] and [targetLimit] to a
* writable slice of this array.
*
* 4. Call [process] to read input data from [source] and write output to [target]. This function
* advances [sourcePos] if input data was read and [targetPos] if compressed output was written.
* If the input array is exhausted (`sourcePos == sourceLimit`) or the output array is full
* (`targetPos == targetLimit`), make an adjustment and call [process] again.
*
* 5. Repeat steps 2 through 4 until the input data is completely exhausted.
*
* 6. Close the processor.
*
* See also, the [zlib manual](https://www.zlib.net/manual.html).
*/
internal abstract class DataProcessor : Closeable {
var source: ByteArray = emptyByteArray
var sourcePos: Int = 0
var sourceLimit: Int = 0

var target: ByteArray = emptyByteArray
var targetPos: Int = 0
var targetLimit: Int = 0

var closed: Boolean = false
protected set

/**
* Returns true if no further calls to [process] are required to complete the operation.
* Otherwise, make space available in [target] and call [process] again.
*/
@Throws(ProtocolException::class)
abstract fun process(): Boolean
}
59 changes: 10 additions & 49 deletions okio/src/nativeMain/kotlin/okio/Deflater.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,36 +37,16 @@ import platform.zlib.deflateEnd
import platform.zlib.deflateInit2
import platform.zlib.z_stream_s

internal val emptyByteArray = byteArrayOf()

/**
* Deflate using Kotlin/Native's built-in zlib bindings. This uses the raw deflate format and omits
* the zlib header and trailer, and does not compute a check value.
*
* To use:
*
* 1. Create an instance.
*
* 2. Populate [source] with uncompressed data. Set [sourcePos] and [sourceLimit] to a readable
* slice of this array.
*
* 3. Populate [target] with a destination for compressed data. Set [targetPos] and [targetLimit] to
* a writable slice of this array.
*
* 4. Call [deflate] to read input data from [source] and write compressed output to [target]. This
* function advances [sourcePos] if input data was read and [targetPos] if compressed output was
* written. If the input array is exhausted (`sourcePos == sourceLimit`) or the output array is
* full (`targetPos == targetLimit`), make an adjustment and call [deflate] again.
*
* 5. Repeat steps 2 through 4 until the input data is completely exhausted. Set [sourceFinished]
* to true before the last call to [deflate]. (It is okay to call deflate() when the source is
* exhausted.)
*
* 6. Close the Deflater.
* Note that you must set [flush] to [Z_FINISH] before the last call to [process]. (It is okay to
* call process() when the source is exhausted.)
*
* See also, the [zlib manual](https://www.zlib.net/manual.html).
*/
internal class Deflater : Closeable {
internal class Deflater : DataProcessor() {
private val zStream: z_stream_s = nativeHeap.alloc<z_stream_s> {
zalloc = null
zfree = null
Expand All @@ -83,22 +63,10 @@ internal class Deflater : Closeable {
)
}

var source: ByteArray = emptyByteArray
var sourcePos: Int = 0
var sourceLimit: Int = 0
var sourceFinished = false

var target: ByteArray = emptyByteArray
var targetPos: Int = 0
var targetLimit: Int = 0

private var closed = false
/** Probably [Z_NO_FLUSH], [Z_FINISH], or [Z_SYNC_FLUSH]. */
var flush: Int = Z_NO_FLUSH

/**
* Returns true if no further calls to [deflate] are required to complete the operation.
* Otherwise, make space available in [target] and call [deflate] again with the same arguments.
*/
fun deflate(flush: Boolean = false): Boolean {
override fun process(): Boolean {
check(!closed) { "closed" }
require(0 <= sourcePos && sourcePos <= sourceLimit && sourceLimit <= source.size)
require(0 <= targetPos && targetPos <= targetLimit && targetLimit <= target.size)
Expand All @@ -119,23 +87,16 @@ internal class Deflater : Closeable {
}
zStream.avail_out = targetByteCount.toUInt()

val deflateFlush = when {
sourceFinished -> Z_FINISH
flush -> Z_SYNC_FLUSH
else -> Z_NO_FLUSH
}

// One of Z_OK, Z_STREAM_END, Z_STREAM_ERROR, or Z_BUF_ERROR.
val deflateResult = deflate(zStream.ptr, deflateFlush)
val deflateResult = deflate(zStream.ptr, flush)
check(deflateResult != Z_STREAM_ERROR)

sourcePos += sourceByteCount - zStream.avail_in.toInt()
targetPos += targetByteCount - zStream.avail_out.toInt()

return when {
sourceFinished -> deflateResult == Z_STREAM_END
flush -> targetPos < targetLimit
else -> true
return when (deflateResult) {
Z_STREAM_END -> true
else -> targetPos < targetLimit
}
}
}
Expand Down
37 changes: 15 additions & 22 deletions okio/src/nativeMain/kotlin/okio/Inflater.kt
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,8 @@ import platform.zlib.z_stream_s

/**
* Inflate using Kotlin/Native's built-in zlib bindings.
*
* The API is symmetric with [Deflater].
*/
internal class Inflater : Closeable {
internal class Inflater : DataProcessor() {
private val zStream: z_stream_s = nativeHeap.alloc<z_stream_s> {
zalloc = null
zfree = null
Expand All @@ -50,22 +48,11 @@ internal class Inflater : Closeable {
)
}

var source: ByteArray = emptyByteArray
var sourcePos: Int = 0
var sourceLimit: Int = 0

var target: ByteArray = emptyByteArray
var targetPos: Int = 0
var targetLimit: Int = 0

private var closed = false
var sourceFinished: Boolean = false
private set

/**
* Returns true if no further calls to [inflate] are required because the source stream is
* finished. Otherwise, ensure there's input data in [source] and output space in [target] and
* call this again.
*/
fun inflate(): Boolean {
@Throws(ProtocolException::class)
override fun process(): Boolean {
check(!closed) { "closed" }
require(0 <= sourcePos && sourcePos <= sourceLimit && sourceLimit <= source.size)
require(0 <= targetPos && targetPos <= targetLimit && targetLimit <= target.size)
Expand All @@ -91,10 +78,16 @@ internal class Inflater : Closeable {
sourcePos += sourceByteCount - zStream.avail_in.toInt()
targetPos += targetByteCount - zStream.avail_out.toInt()

return when (inflateResult) {
Z_OK -> false
Z_BUF_ERROR -> false // Non-fatal but the caller needs to update source and/or target.
Z_STREAM_END -> true
when (inflateResult) {
Z_OK, Z_BUF_ERROR -> {
return targetPos < targetLimit
}

Z_STREAM_END -> {
sourceFinished = true
return true
}

Z_DATA_ERROR -> throw ProtocolException("Z_DATA_ERROR")

// One of Z_NEED_DICT, Z_STREAM_ERROR, Z_MEM_ERROR.
Expand Down
36 changes: 20 additions & 16 deletions okio/src/nativeTest/kotlin/okio/DeflaterTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ import kotlin.test.assertTrue
import okio.ByteString.Companion.decodeBase64
import okio.ByteString.Companion.encodeUtf8
import okio.ByteString.Companion.toByteString
import platform.zlib.Z_FINISH
import platform.zlib.Z_NO_FLUSH
import platform.zlib.Z_SYNC_FLUSH

class DeflaterTest {
@Test
Expand All @@ -31,14 +34,14 @@ class DeflaterTest {
source = "God help us, we're in the hands of engineers.".encodeUtf8().toByteArray()
sourcePos = 0
sourceLimit = source.size
sourceFinished = true
flush = Z_FINISH

target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

assertTrue(deflater.deflate())
assertTrue(deflater.process())
assertEquals(deflater.sourceLimit, deflater.sourcePos)
val deflated = deflater.target.toByteString(0, deflater.targetPos)

Expand All @@ -62,15 +65,14 @@ class DeflaterTest {
deflater.source = "God help us, we're in the hands".encodeUtf8().toByteArray()
deflater.sourcePos = 0
deflater.sourceLimit = deflater.source.size
deflater.sourceFinished = false
assertTrue(deflater.deflate())
assertTrue(deflater.process())
assertEquals(deflater.sourceLimit, deflater.sourcePos)

deflater.source = " of engineers.".encodeUtf8().toByteArray()
deflater.sourcePos = 0
deflater.sourceLimit = deflater.source.size
deflater.sourceFinished = true
assertTrue(deflater.deflate())
deflater.flush = Z_FINISH
assertTrue(deflater.process())
assertEquals(deflater.sourceLimit, deflater.sourcePos)

val deflated = deflater.target.toByteString(0, deflater.targetPos)
Expand All @@ -97,19 +99,21 @@ class DeflaterTest {
deflater.target = ByteArray(10)
deflater.targetPos = 0
deflater.targetLimit = deflater.target.size
assertFalse(deflater.deflate(flush = true))
deflater.flush = Z_SYNC_FLUSH
assertFalse(deflater.process())
assertEquals(deflater.targetLimit, deflater.targetPos)
targetBuffer.write(deflater.target)

deflater.target = ByteArray(256)
deflater.targetPos = 0
deflater.targetLimit = deflater.target.size
assertTrue(deflater.deflate())
deflater.flush = Z_NO_FLUSH
assertTrue(deflater.process())
assertEquals(deflater.sourcePos, deflater.sourceLimit)
targetBuffer.write(deflater.target, 0, deflater.targetPos)

deflater.sourceFinished = true
assertTrue(deflater.deflate())
deflater.flush = Z_FINISH
assertTrue(deflater.process())

// Golden compressed output.
assertEquals(
Expand All @@ -128,20 +132,20 @@ class DeflaterTest {
source = "God help us, we're in the hands of engineers.".encodeUtf8().toByteArray()
sourcePos = 0
sourceLimit = source.size
sourceFinished = true
flush = Z_FINISH
}

deflater.target = ByteArray(10)
deflater.targetPos = 0
deflater.targetLimit = deflater.target.size
assertFalse(deflater.deflate())
assertFalse(deflater.process())
assertEquals(deflater.targetLimit, deflater.targetPos)
targetBuffer.write(deflater.target)

deflater.target = ByteArray(256)
deflater.targetPos = 0
deflater.targetLimit = deflater.target.size
assertTrue(deflater.deflate())
assertTrue(deflater.process())
assertEquals(deflater.sourcePos, deflater.sourceLimit)
targetBuffer.write(deflater.target, 0, deflater.targetPos)

Expand All @@ -157,14 +161,14 @@ class DeflaterTest {
@Test
fun deflateEmptySource() {
val deflater = Deflater().apply {
sourceFinished = true
flush = Z_FINISH

target = ByteArray(256)
targetPos = 0
targetLimit = target.size
}

assertTrue(deflater.deflate())
assertTrue(deflater.process())
val deflated = deflater.target.toByteString(0, deflater.targetPos)

// Golden compressed output.
Expand All @@ -182,7 +186,7 @@ class DeflaterTest {
deflater.close()

assertFailsWith<IllegalStateException> {
deflater.deflate()
deflater.process()
}
}

Expand Down
Loading

0 comments on commit 260710e

Please sign in to comment.