
Add readArrowReader method to allow loading a dataframe from an ArrowReader #529


Merged · 1 commit · Dec 8, 2023
@@ -1,6 +1,7 @@
 package org.jetbrains.kotlinx.dataframe.io
 
 import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.commons.compress.utils.SeekableInMemoryByteChannel
 import org.jetbrains.kotlinx.dataframe.AnyFrame
 import org.jetbrains.kotlinx.dataframe.DataFrame
@@ -170,3 +171,18 @@ public fun DataFrame.Companion.readArrowFeather(
 } else {
     readArrowFeather(File(path), nullability)
 }
+
+/**
+ * Read [Arrow any format](https://arrow.apache.org/docs/java/ipc.html#reading-writing-ipc-formats) data from existing [reader]
+ */
+public fun DataFrame.Companion.readArrow(
+    reader: ArrowReader,
+    nullability: NullabilityOptions = NullabilityOptions.Infer
+): AnyFrame = readArrowImpl(reader, nullability)
+
+/**
+ * Read [Arrow any format](https://arrow.apache.org/docs/java/ipc.html#reading-writing-ipc-formats) data from existing [ArrowReader]
+ */
+public fun ArrowReader.toDataFrame(
+    nullability: NullabilityOptions = NullabilityOptions.Infer
+): AnyFrame = DataFrame.Companion.readArrowImpl(this, nullability)
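
For orientation, a minimal usage sketch of the two entry points added above. The stream construction and file name are illustrative assumptions, not part of this diff; either call consumes the reader and closes it, since the implementation wraps it in `use`.

import java.io.FileInputStream
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.readArrow

// "data.arrow" is a placeholder path for an IPC stream written elsewhere.
val reader = ArrowStreamReader(FileInputStream("data.arrow"), RootAllocator())

// Companion-style call; reader.toDataFrame() is the equivalent extension form.
val df: AnyFrame = DataFrame.readArrow(reader)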
@@ -32,6 +32,7 @@ import org.apache.arrow.vector.VarCharVector
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.complex.StructVector
 import org.apache.arrow.vector.ipc.ArrowFileReader
+import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.apache.arrow.vector.types.pojo.Field
 import org.apache.arrow.vector.util.DateUtility
@@ -262,17 +263,7 @@ internal fun DataFrame.Companion.readArrowIPCImpl(
     allocator: RootAllocator = Allocator.ROOT,
     nullability: NullabilityOptions = NullabilityOptions.Infer,
 ): AnyFrame {
-    ArrowStreamReader(channel, allocator).use { reader ->
-        val flattened = buildList {
-            val root = reader.vectorSchemaRoot
-            val schema = root.schema
-            while (reader.loadNextBatch()) {
-                val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
-                add(df)
-            }
-        }
-        return flattened.concatKeepingSchema()
-    }
+    return readArrowImpl(ArrowStreamReader(channel, allocator), nullability)
 }
 
 /**
@@ -283,14 +274,36 @@ internal fun DataFrame.Companion.readArrowFeatherImpl(
     allocator: RootAllocator = Allocator.ROOT,
     nullability: NullabilityOptions = NullabilityOptions.Infer,
 ): AnyFrame {
-    ArrowFileReader(channel, allocator).use { reader ->
+    return readArrowImpl(ArrowFileReader(channel, allocator), nullability)
+}
+
+/**
+ * Read [Arrow any format](https://arrow.apache.org/docs/java/ipc.html#reading-writing-ipc-formats) data from existing [reader]
+ */
+internal fun DataFrame.Companion.readArrowImpl(
+    reader: ArrowReader,
+    nullability: NullabilityOptions = NullabilityOptions.Infer
+): AnyFrame {
+    reader.use {
         val flattened = buildList {
-            reader.recordBlocks.forEach { block ->
-                reader.loadRecordBatch(block)
-                val root = reader.vectorSchemaRoot
-                val schema = root.schema
-                val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
-                add(df)
+            when (reader) {
+                is ArrowFileReader -> {
+                    reader.recordBlocks.forEach { block ->
+                        reader.loadRecordBatch(block)
+                        val root = reader.vectorSchemaRoot
+                        val schema = root.schema
+                        val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
+                        add(df)
+                    }
+                }
+                is ArrowStreamReader -> {
+                    val root = reader.vectorSchemaRoot
+                    val schema = root.schema
+                    while (reader.loadNextBatch()) {
+                        val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame()
+                        add(df)
+                    }
+                }
             }
         }
         return flattened.concatKeepingSchema()
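
A note on the `when (reader)` dispatch above: it handles exactly the two `ArrowReader` subclasses the library itself constructs; any other subclass falls through the `when` and contributes no batches. Since `loadNextBatch()` and `getVectorSchemaRoot()` are declared on the `ArrowReader` base class, a fully generic loop is also possible. The sketch below is an assumption for illustration, not part of the merged change, and reuses this file's internal helpers (`readField`, `concatKeepingSchema`):

// Hypothetical generic variant: the base-class loadNextBatch() contract
// covers any ArrowReader, at the cost of dropping the explicit
// random-access (recordBlocks) path used for ArrowFileReader above.
internal fun DataFrame.Companion.readArrowGenericSketch(
    reader: ArrowReader,
    nullability: NullabilityOptions = NullabilityOptions.Infer,
): AnyFrame = reader.use {
    val batches = buildList {
        val root = reader.vectorSchemaRoot
        val schema = root.schema
        while (reader.loadNextBatch()) {
            add(schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame())
        }
    }
    batches.concatKeepingSchema()
}

`ArrowFileReader` also implements `loadNextBatch()` sequentially, so the explicit branch in the merged code reads as a design choice that preserves block-wise loading for files rather than a necessity.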
@@ -9,14 +9,17 @@ import org.apache.arrow.vector.TimeStampMilliVector
 import org.apache.arrow.vector.TimeStampNanoVector
 import org.apache.arrow.vector.TimeStampSecVector
 import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowFileReader
 import org.apache.arrow.vector.ipc.ArrowFileWriter
+import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.arrow.vector.types.FloatingPointPrecision
 import org.apache.arrow.vector.types.TimeUnit
 import org.apache.arrow.vector.types.pojo.ArrowType
 import org.apache.arrow.vector.types.pojo.Field
 import org.apache.arrow.vector.types.pojo.FieldType
 import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
 import org.apache.arrow.vector.util.Text
 import org.jetbrains.kotlinx.dataframe.DataColumn
 import org.jetbrains.kotlinx.dataframe.DataFrame
@@ -32,6 +35,7 @@ import org.jetbrains.kotlinx.dataframe.api.remove
 import org.jetbrains.kotlinx.dataframe.api.toColumn
 import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
 import org.junit.Test
+import java.io.ByteArrayInputStream
 import java.io.ByteArrayOutputStream
 import java.io.File
 import java.net.URL
@@ -553,4 +557,30 @@ internal class ArrowKtTest {
             }
         }
     }
+
+    @Test
+    fun testArrowReaderExtension() {
+        val dates = listOf(
+            LocalDateTime.of(2023, 11, 23, 9, 30, 25),
+            LocalDateTime.of(2015, 5, 25, 14, 20, 13),
+            LocalDateTime.of(2013, 6, 19, 11, 20, 13),
+            LocalDateTime.of(2000, 1, 1, 0, 0, 0)
+        )
+
+        val expected = dataFrameOf(
+            "string" to listOf("a", "b", "c", "d"),
+            "int" to listOf(1, 2, 3, 4),
+            "float" to listOf(1.0f, 2.0f, 3.0f, 4.0f),
+            "double" to listOf(1.0, 2.0, 3.0, 4.0),
+            "datetime" to dates
+        )
+
+        val featherChannel = ByteArrayReadableSeekableByteChannel(expected.saveArrowFeatherToByteArray())
+        val arrowFileReader = ArrowFileReader(featherChannel, RootAllocator())
+        arrowFileReader.toDataFrame() shouldBe expected
+
+        val ipcInputStream = ByteArrayInputStream(expected.saveArrowIPCToByteArray())
+        val arrowStreamReader = ArrowStreamReader(ipcInputStream, RootAllocator())
+        arrowStreamReader.toDataFrame() shouldBe expected
+    }
 }
2 changes: 1 addition & 1 deletion docs/StardustDocs/topics/read.md
@@ -445,7 +445,7 @@ val df = DataFrame.readArrowFeather(file)
 
 [`DataFrame`](DataFrame.md) supports reading [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format)
 and [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files)
-from raw Channel (ReadableByteChannel for streaming and SeekableByteChannel for random access), InputStream, File or ByteArray.
+from raw Channel (ReadableByteChannel for streaming and SeekableByteChannel for random access), ArrowReader, InputStream, File or ByteArray.
 
 > If you use Java 9+, follow the [Apache Arrow Java compatibility](https://arrow.apache.org/docs/java/install.html#java-compatibility) guide.
 >
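
A short sketch of the `ArrowReader` path that the amended docs sentence now advertises; the file name and channel setup are illustrative assumptions, not taken from the docs page:

import java.nio.channels.FileChannel
import java.nio.file.Path
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowFileReader
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.io.readArrow

// "data.feather" is a placeholder; any SeekableByteChannel source works
// for the random access format, and the reader is closed by readArrow.
val channel = FileChannel.open(Path.of("data.feather"))
val df = DataFrame.readArrow(ArrowFileReader(channel, RootAllocator()))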