Skip to content

UnsafeBufferOperations.forEachSegment implementation #383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/api/kotlinx-io-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ public abstract interface class kotlinx/io/unsafe/SegmentWriteContext {

public final class kotlinx/io/unsafe/UnsafeBufferOperations {
public static final field INSTANCE Lkotlinx/io/unsafe/UnsafeBufferOperations;
public final fun forEachSegment (Lkotlinx/io/Buffer;JLkotlin/jvm/functions/Function3;)V
public final fun forEachSegment (Lkotlinx/io/Buffer;Lkotlin/jvm/functions/Function2;)V
public final fun getMaxSafeWriteCapacity ()I
public final fun iterate (Lkotlinx/io/Buffer;JLkotlin/jvm/functions/Function3;)V
public final fun iterate (Lkotlinx/io/Buffer;Lkotlin/jvm/functions/Function2;)V
Expand Down
2 changes: 2 additions & 0 deletions core/api/kotlinx-io-core.klib.api
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ final object kotlinx.io.unsafe/UnsafeBufferOperations { // kotlinx.io.unsafe/Uns
final fun <get-maxSafeWriteCapacity>(): kotlin/Int // kotlinx.io.unsafe/UnsafeBufferOperations.maxSafeWriteCapacity.<get-maxSafeWriteCapacity>|<get-maxSafeWriteCapacity>(){}[0]

final fun moveToTail(kotlinx.io/Buffer, kotlin/ByteArray, kotlin/Int = ..., kotlin/Int = ...) // kotlinx.io.unsafe/UnsafeBufferOperations.moveToTail|moveToTail(kotlinx.io.Buffer;kotlin.ByteArray;kotlin.Int;kotlin.Int){}[0]
final inline fun forEachSegment(kotlinx.io/Buffer, kotlin/Function2<kotlinx.io.unsafe/SegmentReadContext, kotlinx.io/Segment, kotlin/Unit>) // kotlinx.io.unsafe/UnsafeBufferOperations.forEachSegment|forEachSegment(kotlinx.io.Buffer;kotlin.Function2<kotlinx.io.unsafe.SegmentReadContext,kotlinx.io.Segment,kotlin.Unit>){}[0]
final inline fun forEachSegment(kotlinx.io/Buffer, kotlin/Long, kotlin/Function3<kotlinx.io.unsafe/SegmentReadContext, kotlinx.io/Segment, kotlin/Long, kotlin/Unit>) // kotlinx.io.unsafe/UnsafeBufferOperations.forEachSegment|forEachSegment(kotlinx.io.Buffer;kotlin.Long;kotlin.Function3<kotlinx.io.unsafe.SegmentReadContext,kotlinx.io.Segment,kotlin.Long,kotlin.Unit>){}[0]
final inline fun iterate(kotlinx.io/Buffer, kotlin/Function2<kotlinx.io.unsafe/BufferIterationContext, kotlinx.io/Segment?, kotlin/Unit>) // kotlinx.io.unsafe/UnsafeBufferOperations.iterate|iterate(kotlinx.io.Buffer;kotlin.Function2<kotlinx.io.unsafe.BufferIterationContext,kotlinx.io.Segment?,kotlin.Unit>){}[0]
final inline fun iterate(kotlinx.io/Buffer, kotlin/Long, kotlin/Function3<kotlinx.io.unsafe/BufferIterationContext, kotlinx.io/Segment?, kotlin/Long, kotlin/Unit>) // kotlinx.io.unsafe/UnsafeBufferOperations.iterate|iterate(kotlinx.io.Buffer;kotlin.Long;kotlin.Function3<kotlinx.io.unsafe.BufferIterationContext,kotlinx.io.Segment?,kotlin.Long,kotlin.Unit>){}[0]
final inline fun readFromHead(kotlinx.io/Buffer, kotlin/Function2<kotlinx.io.unsafe/SegmentReadContext, kotlinx.io/Segment, kotlin/Int>): kotlin/Int // kotlinx.io.unsafe/UnsafeBufferOperations.readFromHead|readFromHead(kotlinx.io.Buffer;kotlin.Function2<kotlinx.io.unsafe.SegmentReadContext,kotlinx.io.Segment,kotlin.Int>){}[0]
Expand Down
19 changes: 7 additions & 12 deletions core/apple/src/BuffersApple.kt
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,14 @@ internal fun Buffer.snapshotAsNSData(): NSData {
val bytes = malloc(size.convert())?.reinterpret<uint8_tVar>()
?: throw Error("malloc failed: ${strerror(errno)?.toKString()}")

UnsafeBufferOperations.iterate(this) { ctx, head ->
var curr: Segment? = head
var index = 0
while (curr != null) {
val segment: Segment = curr
ctx.withData(segment) { data, pos, limit ->
val length = limit - pos
data.usePinned {
memcpy(bytes + index, it.addressOf(pos), length.convert())
}
index += length
var index = 0
UnsafeBufferOperations.forEachSegment(this) { ctx, segment ->
ctx.withData(segment) { data, pos, limit ->
val length = limit - pos
data.usePinned {
memcpy(bytes + index, it.addressOf(pos), length.convert())
}
curr = ctx.next(segment)
index += length
}
}
return NSData.create(bytesNoCopy = bytes, length = size.convert())
Expand Down
25 changes: 10 additions & 15 deletions core/common/src/Buffer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -553,21 +553,16 @@ public class Buffer : Source, Sink {
val len = minOf(maxPrintableBytes, size).toInt()

val builder = StringBuilder(len * 2 + if (size > maxPrintableBytes) 1 else 0)

UnsafeBufferOperations.iterate(this) { ctx, head ->
var bytesWritten = 0
var seg: Segment? = head
do {
seg!!
var idx = 0
while (bytesWritten < len && idx < seg.size) {
val b = ctx.getUnchecked(seg, idx++)
bytesWritten++
builder.append(HEX_DIGIT_CHARS[(b shr 4) and 0xf])
.append(HEX_DIGIT_CHARS[b and 0xf])
}
seg = ctx.next(seg)
} while (seg != null)
var bytesWritten = 0
UnsafeBufferOperations.forEachSegment(this) { ctx, segment ->
var idx = 0
while (bytesWritten < len && idx < segment.size) {
val b = ctx.getUnchecked(segment, idx++)
bytesWritten++
builder
.append(HEX_DIGIT_CHARS[(b shr 4) and 0xf])
.append(HEX_DIGIT_CHARS[b and 0xf])
}
}

if (size > maxPrintableBytes) {
Expand Down
8 changes: 2 additions & 6 deletions core/common/src/Buffers.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,8 @@ public fun Buffer.snapshot(): ByteString {
check(size <= Int.MAX_VALUE) { "Buffer is too long ($size) to be converted into a byte string." }

return buildByteString(size.toInt()) {
UnsafeBufferOperations.iterate(this@snapshot) { ctx, head ->
var curr = head
while (curr != null) {
ctx.withData(curr, this::append)
curr = ctx.next(curr)
}
UnsafeBufferOperations.forEachSegment(this@snapshot) { ctx, segment ->
ctx.withData(segment, this::append)
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions core/common/src/Utf8.kt
Original file line number Diff line number Diff line change
Expand Up @@ -607,17 +607,17 @@ private fun Buffer.commonReadUtf8(byteCount: Long): String {
// Invariant: byteCount was request()'ed into this buffer beforehand
if (byteCount == 0L) return ""

UnsafeBufferOperations.iterate(this) { ctx, head ->
head!!
if (head.size >= byteCount) {
UnsafeBufferOperations.forEachSegment(this) { ctx, segment ->
if (segment.size >= byteCount) {
var result = ""
ctx.withData(head) { data, pos, limit ->
ctx.withData(segment) { data, pos, limit ->
result = data.commonToUtf8String(pos, min(limit, pos + byteCount.toInt()))
skip(byteCount)
return result
}
}
}
// If the string spans multiple segments, delegate to readBytes()
return readByteArray(byteCount.toInt()).commonToUtf8String()
}
error("Unreacheable")
}
67 changes: 67 additions & 0 deletions core/common/src/unsafe/UnsafeBufferOperations.kt
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,73 @@ public object UnsafeBufferOperations {
iterationAction(BufferIterationContextImpl, buffer.head)
}


/**
 * Invokes [action] for every segment of the [buffer], one after another, starting from the head segment.
 *
 * Each invocation receives a [SegmentReadContext], allowing to read the data of the supplied segment
 * in an unchecked manner, together with the segment itself.
 *
 * Both [action] arguments are valid only while [action] is running: it is an error to store
 * a [SegmentReadContext] or a [Segment] instance and use it outside the scope of the [action].
 *
 * The [action] is not invoked at all when the [buffer] has no segments, as is the case for an empty buffer.
 *
 * @param buffer a buffer to iterate over
 * @param action a callback invoked once per segment with a read context and the segment being visited
 * @sample kotlinx.io.samples.unsafe.UnsafeReadWriteSamplesJvm.messageDigest2
 * @sample kotlinx.io.samples.unsafe.UnsafeBufferOperationsSamples.crc32GetUnchecked2
 */
public inline fun forEachSegment(
    buffer: Buffer,
    action: (context: SegmentReadContext, segment: Segment) -> Unit
) {
    // Nothing to visit when the buffer has no head segment.
    var current: Segment = buffer.head ?: return
    while (true) {
        action(SegmentReadContextImpl, current)
        // The successor is looked up only after the action has finished with the current segment.
        current = current.next ?: break
    }
}

/**
 * Iterates over [buffer] segments starting from the segment that contains the byte at the specified [offset].
 *
 * [action] is invoked once per segment with an instance of [SegmentReadContext]
 * allowing to read the segment's data in an unchecked manner, the segment itself, and the offset
 * at which that segment starts, counted from the beginning of the [buffer].
 *
 * It is considered an error to use a [SegmentReadContext] or a [Segment] instances outside the scope of
 * the [action]: both are valid only within the [action] scope and must not be stored and reused later.
 *
 * The first segment passed to the [action] may start before the requested [offset]. To locate the [buffer]'s
 * [offset]'th byte within that segment, subtract the supplied start-of-segment offset from [offset].
 * For every following segment the supplied value is the start offset of that segment within the [buffer].
 *
 * As [offset] has to address an existing byte, an empty [buffer] is always rejected
 * with an [IndexOutOfBoundsException].
 *
 * @param buffer a buffer to iterate over
 * @param offset an offset of a byte, relative to the beginning of the [buffer], whose segment is visited first
 * @param action a callback invoked once per segment with a read context, the segment being visited
 * and the offset of the segment's start within the [buffer]
 * @throws IllegalArgumentException when [offset] is negative
 * @throws IndexOutOfBoundsException when [offset] is greater or equal to [Buffer.size]
 */
public inline fun forEachSegment(
    buffer: Buffer, offset: Long,
    action: (context: SegmentReadContext, segment: Segment, startOfTheSegmentOffset: Long) -> Unit
) {
    require(offset >= 0) { "Offset must be non-negative: $offset" }
    if (offset >= buffer.size) {
        throw IndexOutOfBoundsException("Offset should be less than buffer's size (${buffer.size}): $offset")
    }

    buffer.seek(offset) { segment, o ->
        var curr: Segment? = segment
        // Start offset of the segment currently being visited, relative to the beginning of the buffer.
        var segmentStart = o
        while (curr != null) {
            // Capture the size before the action runs, so the next start offset reflects
            // the segment as it was handed to the action.
            val segmentSize = curr.size
            action(SegmentReadContextImpl, curr, segmentStart)
            segmentStart += segmentSize
            curr = curr.next
        }
    }
}

/**
* Provides access to [buffer] segments starting from a segment spanning over a specified [offset].
*
Expand Down
38 changes: 38 additions & 0 deletions core/common/test/samples/unsafe/unsafeSamples.kt
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,42 @@ class UnsafeBufferOperationsSamples {

assertEquals(0x9896d398U, buffer.crc32UsingGetUnchecked())
}

@OptIn(ExperimentalUnsignedTypes::class)
@Test
fun crc32GetUnchecked2() {
    // Precompute the 256-entry CRC-32 lookup table, one entry per possible byte value.
    fun generateCrc32Table(): UIntArray = UIntArray(256) { idx ->
        var entry = idx.toUInt()
        repeat(8) {
            val shifted = entry shr 1
            entry = if (entry % 2U == 0U) shifted else shifted xor 0xEDB88320U
        }
        entry
    }

    val crc32Table = generateCrc32Table()

    @OptIn(UnsafeIoApi::class)
    fun Buffer.crc32UsingGetUnchecked(): UInt {
        var crc32 = 0xffffffffU
        // Visit every segment of the buffer, starting from the head
        UnsafeBufferOperations.forEachSegment(this) { ctx, segment ->
            // Fold each byte of the segment into the running checksum
            for (offset in 0..<segment.size) {
                val index = (ctx.getUnchecked(segment, offset) xor crc32.toByte()).toUByte()
                crc32 = crc32Table[index.toInt()] xor (crc32 shr 8)
            }
        }
        return crc32 xor 0xffffffffU
    }

    val buffer = Buffer().apply { writeString("hello crc32") }

    assertEquals(0x9896d398U, buffer.crc32UsingGetUnchecked())
}
}
25 changes: 14 additions & 11 deletions core/jvm/src/BuffersJvm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,21 @@ public fun Buffer.copyTo(

var remainingByteCount = endIndex - startIndex

UnsafeBufferOperations.iterate(this, startIndex) { ctx, seg, segOffset ->
var curr = seg!!
var currentOffset = (startIndex - segOffset).toInt()
while (remainingByteCount > 0) {
ctx.withData(curr) { data, pos, limit ->
val toCopy = minOf(limit - pos - currentOffset, remainingByteCount).toInt()
out.write(data, pos + currentOffset, toCopy)
remainingByteCount -= toCopy
}
curr = ctx.next(curr) ?: break
currentOffset = 0
var firstSegmentHandled = false
Copy link
Member Author

@qwwdfsad qwwdfsad Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is unfortunate; please take a closer look.
It also adds a few more `if`s this time, though I'm not sure the effect is measurable.

UnsafeBufferOperations.forEachSegment(this, startIndex) { ctx, segment, offset ->
val currentOffset = if (firstSegmentHandled) {
0
} else {
firstSegmentHandled = true
(startIndex - offset).toInt()
}
ctx.withData(segment) { data, pos, limit ->
val toCopy = minOf(limit - pos - currentOffset, remainingByteCount).toInt()
out.write(data, pos + currentOffset, toCopy)
remainingByteCount -= toCopy
}
// TODO this if is untested
if (remainingByteCount <= 0) return
}
}

Expand Down
19 changes: 19 additions & 0 deletions core/jvm/test/samples/unsafeAccessSamplesJvm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,23 @@ class UnsafeReadWriteSamplesJvm {
val buffer = Buffer().also { it.writeString("hello world") }
assertEquals("5eb63bbbe01eeed093cb22bb8f5acdc3", buffer.digest("MD5").toHexString())
}

@Test
@OptIn(UnsafeByteStringApi::class, ExperimentalStdlibApi::class)
fun messageDigest2() {
    fun Buffer.digest(algorithm: String): ByteString {
        val messageDigest = MessageDigest.getInstance(algorithm)
        // Visit every segment and feed the bytes between the reported start and end indices into the digest
        UnsafeBufferOperations.forEachSegment(this) { ctx, segment ->
            ctx.withData(segment) { bytes, from, until ->
                messageDigest.update(bytes, from, until - from)
            }
        }

        return UnsafeByteStringOperations.wrapUnsafe(messageDigest.digest())
    }

    val expectedHex = "5eb63bbbe01eeed093cb22bb8f5acdc3"
    val buffer = Buffer().apply { writeString("hello world") }
    assertEquals(expectedHex, buffer.digest("MD5").toHexString())
}
}