Skip to content

Commit 63fdb46

Browse files
authored
Merge pull request #516 from fb64/master
Add read of Arrow TimeStamp without timezone as LocalDatetime #515
2 parents f36912c + e36b006 commit 63fdb46

File tree

2 files changed

+135
-0
lines changed

2 files changed

+135
-0
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ import org.apache.arrow.vector.TimeMicroVector
1818
import org.apache.arrow.vector.TimeMilliVector
1919
import org.apache.arrow.vector.TimeNanoVector
2020
import org.apache.arrow.vector.TimeSecVector
21+
import org.apache.arrow.vector.TimeStampMicroVector
22+
import org.apache.arrow.vector.TimeStampMilliVector
23+
import org.apache.arrow.vector.TimeStampNanoVector
24+
import org.apache.arrow.vector.TimeStampSecVector
2125
import org.apache.arrow.vector.TinyIntVector
2226
import org.apache.arrow.vector.UInt1Vector
2327
import org.apache.arrow.vector.UInt2Vector
@@ -130,6 +134,39 @@ private fun TimeMilliVector.values(range: IntRange): List<LocalTime?> = range.ma
130134

131135
/**
 * Reads the rows selected by [range] from this second-resolution time vector.
 *
 * Each stored value is interpreted as a second-of-day and converted to a [LocalTime];
 * rows for which the vector yields no value are returned as `null`.
 */
private fun TimeSecVector.values(range: IntRange): List<LocalTime?> =
    range.map { row ->
        val secondOfDay = getObject(row) ?: return@map null
        LocalTime.ofSecondOfDay(secondOfDay.toLong())
    }
137+
138+
/**
 * Reads the rows selected by [range] from this nanosecond-resolution timestamp vector
 * (no timezone) as [LocalDateTime] values.
 *
 * @param range absolute row indices to read from the vector.
 * @return one entry per index in [range], in order; `null` for rows that are null in the vector.
 */
private fun TimeStampNanoVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { row ->
        // The null check must use the same absolute row index as the read: a positional
        // counter over the range only coincides with the row index when the range starts at 0.
        if (isNull(row)) null else getObject(row)
    }
145+
146+
/**
 * Reads the rows selected by [range] from this microsecond-resolution timestamp vector
 * (no timezone) as [LocalDateTime] values.
 *
 * @param range absolute row indices to read from the vector.
 * @return one entry per index in [range], in order; `null` for rows that are null in the vector.
 */
private fun TimeStampMicroVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { row ->
        // The null check must use the same absolute row index as the read: a positional
        // counter over the range only coincides with the row index when the range starts at 0.
        if (isNull(row)) null else getObject(row)
    }
153+
154+
/**
 * Reads the rows selected by [range] from this millisecond-resolution timestamp vector
 * (no timezone) as [LocalDateTime] values.
 *
 * @param range absolute row indices to read from the vector.
 * @return one entry per index in [range], in order; `null` for rows that are null in the vector.
 */
private fun TimeStampMilliVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { row ->
        // The null check must use the same absolute row index as the read: a positional
        // counter over the range only coincides with the row index when the range starts at 0.
        if (isNull(row)) null else getObject(row)
    }
161+
162+
/**
 * Reads the rows selected by [range] from this second-resolution timestamp vector
 * (no timezone) as [LocalDateTime] values.
 *
 * @param range absolute row indices to read from the vector.
 * @return one entry per index in [range], in order; `null` for rows that are null in the vector.
 */
private fun TimeStampSecVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { row ->
        // The null check must use the same absolute row index as the read: a positional
        // counter over the range only coincides with the row index when the range starts at 0.
        if (isNull(row)) null else getObject(row)
    }
169+
133170
/**
 * Reads the rows selected by [range] from this struct vector, returning for each row
 * whatever [StructVector.getObject] yields for it (a field-name-to-value map, or `null`).
 */
private fun StructVector.values(range: IntRange): List<Map<String, Any?>?> =
    range.map { row -> getObject(row) }
@@ -202,6 +239,10 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: Nullabi
202239
is TimeMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
203240
is TimeMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
204241
is TimeSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
242+
is TimeStampNanoVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
243+
is TimeStampMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
244+
is TimeStampMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
245+
is TimeStampSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
205246
is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
206247
else -> {
207248
throw NotImplementedError("reading from ${vector.javaClass.canonicalName} is not implemented")

dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@ package org.jetbrains.kotlinx.dataframe.io
33
import io.kotest.assertions.throwables.shouldThrow
44
import io.kotest.matchers.collections.shouldContain
55
import io.kotest.matchers.shouldBe
6+
import org.apache.arrow.memory.RootAllocator
7+
import org.apache.arrow.vector.TimeStampMicroVector
8+
import org.apache.arrow.vector.TimeStampMilliVector
9+
import org.apache.arrow.vector.TimeStampNanoVector
10+
import org.apache.arrow.vector.TimeStampSecVector
11+
import org.apache.arrow.vector.VectorSchemaRoot
12+
import org.apache.arrow.vector.ipc.ArrowFileWriter
13+
import org.apache.arrow.vector.ipc.ArrowStreamWriter
614
import org.apache.arrow.vector.types.FloatingPointPrecision
15+
import org.apache.arrow.vector.types.TimeUnit
716
import org.apache.arrow.vector.types.pojo.ArrowType
817
import org.apache.arrow.vector.types.pojo.Field
918
import org.apache.arrow.vector.types.pojo.FieldType
@@ -23,10 +32,13 @@ import org.jetbrains.kotlinx.dataframe.api.remove
2332
import org.jetbrains.kotlinx.dataframe.api.toColumn
2433
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
2534
import org.junit.Test
35+
import java.io.ByteArrayOutputStream
2636
import java.io.File
2737
import java.net.URL
38+
import java.nio.channels.Channels
2839
import java.time.LocalDate
2940
import java.time.LocalDateTime
41+
import java.time.ZoneOffset
3042
import java.util.Locale
3143
import kotlin.reflect.typeOf
3244

@@ -459,4 +471,86 @@ internal class ArrowKtTest {
459471
val data = dataFrame.saveArrowFeatherToByteArray()
460472
DataFrame.readArrowFeather(data) shouldBe dataFrame
461473
}
474+
475+
/**
 * Round-trips timezone-less Arrow timestamps of every unit through both the Feather (file)
 * and IPC (stream) readers and expects them to come back as [LocalDateTime] columns.
 */
@Test
fun testTimeStamp() {
    val dates = listOf(
        LocalDateTime.of(2023, 11, 23, 9, 30, 25),
        LocalDateTime.of(2015, 5, 25, 14, 20, 13),
        LocalDateTime.of(2013, 6, 19, 11, 20, 13)
    )

    // Column order mirrors the schema built by writeArrowTimestamp.
    val expected = dataFrameOf(
        "ts_nano" to dates,
        "ts_micro" to dates,
        "ts_milli" to dates,
        "ts_sec" to dates
    )

    val featherBytes = writeArrowTimestamp(dates)
    DataFrame.readArrowFeather(featherBytes) shouldBe expected

    val ipcBytes = writeArrowTimestamp(dates, streaming = true)
    DataFrame.readArrowIPC(ipcBytes) shouldBe expected
}
493+
494+
/**
 * Serializes [dates] into an in-memory Arrow payload containing four nullable, timezone-less
 * timestamp columns (`ts_nano`, `ts_micro`, `ts_milli`, `ts_sec`), one per time unit.
 *
 * Each [LocalDateTime] is interpreted at UTC and converted to the epoch offset in the
 * column's unit; precision finer than the unit is dropped.
 *
 * @param dates values to write; every column receives the same values, one row per element.
 * @param streaming when `true` the payload is written with [ArrowStreamWriter] (IPC stream),
 *   otherwise with [ArrowFileWriter] (file format).
 * @return the complete serialized payload.
 */
private fun writeArrowTimestamp(dates: List<LocalDateTime>, streaming: Boolean = false): ByteArray {
    // All four columns share the same shape; only the name and the time unit differ.
    fun timestampField(name: String, unit: TimeUnit) =
        Field(name, FieldType.nullable(ArrowType.Timestamp(unit, null)), null)

    val schemaTimeStamp = Schema(
        listOf(
            timestampField("ts_nano", TimeUnit.NANOSECOND),
            timestampField("ts_micro", TimeUnit.MICROSECOND),
            timestampField("ts_milli", TimeUnit.MILLISECOND),
            timestampField("ts_sec", TimeUnit.SECOND)
        )
    )

    return RootAllocator().use { allocator ->
        VectorSchemaRoot.create(schemaTimeStamp, allocator).use { vectorSchemaRoot ->
            val timeStampNanoVector = vectorSchemaRoot.getVector("ts_nano") as TimeStampNanoVector
            val timeStampMicroVector = vectorSchemaRoot.getVector("ts_micro") as TimeStampMicroVector
            val timeStampMilliVector = vectorSchemaRoot.getVector("ts_milli") as TimeStampMilliVector
            val timeStampSecVector = vectorSchemaRoot.getVector("ts_sec") as TimeStampSecVector

            timeStampNanoVector.allocateNew(dates.size)
            timeStampMicroVector.allocateNew(dates.size)
            timeStampMilliVector.allocateNew(dates.size)
            timeStampSecVector.allocateNew(dates.size)

            dates.forEachIndexed { index, localDateTime ->
                val instant = localDateTime.toInstant(ZoneOffset.UTC)
                // Instant.nano is the full nano-of-second (it already contains the millisecond
                // part), so it must be combined with epochSecond, never with toEpochMilli().
                timeStampNanoVector[index] = instant.epochSecond * 1_000_000_000L + instant.nano
                timeStampMicroVector[index] = instant.epochSecond * 1_000_000L + instant.nano / 1_000L
                timeStampMilliVector[index] = instant.toEpochMilli()
                timeStampSecVector[index] = instant.epochSecond
            }
            vectorSchemaRoot.setRowCount(dates.size)

            val bos = ByteArrayOutputStream()
            bos.use { out ->
                val channel = Channels.newChannel(out)
                val arrowWriter = if (streaming) {
                    ArrowStreamWriter(vectorSchemaRoot, null, channel)
                } else {
                    ArrowFileWriter(vectorSchemaRoot, null, channel)
                }
                // The writer is closed before the bytes are read back from the stream.
                arrowWriter.use { writer ->
                    writer.start()
                    writer.writeBatch()
                }
            }
            bos.toByteArray()
        }
    }
}
462556
}

0 commit comments

Comments
 (0)