Skip to content

Commit

Permalink
Bulk Load CDK: Cleanup: Files/Objects no longer Batches (#46960)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Oct 18, 2024
1 parent 4a2e7c8 commit 7b12647
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package io.airbyte.cdk.load.message
import com.google.common.collect.Range
import com.google.common.collect.RangeSet
import com.google.common.collect.TreeRangeSet
import io.airbyte.cdk.load.file.LocalFile

/**
* Represents an accumulated batch of records in some stage of processing.
Expand Down Expand Up @@ -47,7 +46,6 @@ import io.airbyte.cdk.load.file.LocalFile
*/
interface Batch {
enum class State {
SPILLED,
LOCAL,
PERSISTED,
COMPLETE
Expand All @@ -66,23 +64,6 @@ interface Batch {
/**
 * Simple batch: use if you need no other metadata for processing.
 *
 * Carries nothing beyond the [Batch.State] it is constructed with.
 *
 * @property state the processing state reported for this batch
 */
data class SimpleBatch(override val state: Batch.State) : Batch

/** Represents a file of records locally staged. */
abstract class StagedLocalFile() : Batch {
abstract val localFile: LocalFile
abstract val totalSizeBytes: Long
override val state: Batch.State = Batch.State.LOCAL
}

/**
* Represents a file of raw records staged to disk for pre-processing. Used internally by the
* framework
*/
data class SpilledRawMessagesLocalFile(
override val localFile: LocalFile,
override val totalSizeBytes: Long,
override val state: Batch.State = Batch.State.SPILLED
) : StagedLocalFile()

/**
* Internally-used wrapper for tracking the association between a batch and the range of records it
* contains.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import io.airbyte.cdk.load.command.DestinationCatalog
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.implementor.CloseStreamTaskFactory
import io.airbyte.cdk.load.task.implementor.OpenStreamTaskFactory
Expand All @@ -20,6 +19,7 @@ import io.airbyte.cdk.load.task.implementor.TeardownTaskFactory
import io.airbyte.cdk.load.task.internal.FlushCheckpointsTaskFactory
import io.airbyte.cdk.load.task.internal.InputConsumerTask
import io.airbyte.cdk.load.task.internal.SpillToDiskTaskFactory
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.task.internal.TimedForcedCheckpointFlushTask
import io.airbyte.cdk.load.task.internal.UpdateCheckpointsTask
import io.airbyte.cdk.load.util.setOnce
Expand All @@ -34,11 +34,7 @@ import kotlinx.coroutines.sync.withLock
interface DestinationTaskLauncher : TaskLauncher {
suspend fun handleSetupComplete()
suspend fun handleStreamStarted(stream: DestinationStream)
suspend fun handleNewSpilledFile(
stream: DestinationStream,
wrapped: BatchEnvelope<SpilledRawMessagesLocalFile>,
endOfStream: Boolean
)
suspend fun handleNewSpilledFile(stream: DestinationStream, file: SpilledRawMessagesLocalFile)
suspend fun handleNewBatch(stream: DestinationStream, wrapped: BatchEnvelope<*>)
suspend fun handleStreamClosed(stream: DestinationStream)
suspend fun handleTeardownComplete()
Expand Down Expand Up @@ -168,13 +164,12 @@ class DefaultDestinationTaskLauncher(
/** Called for each new spilled file. */
override suspend fun handleNewSpilledFile(
stream: DestinationStream,
wrapped: BatchEnvelope<SpilledRawMessagesLocalFile>,
endOfStream: Boolean
file: SpilledRawMessagesLocalFile
) {
log.info { "Starting process records task for ${stream.descriptor}, file ${wrapped.batch}" }
val task = processRecordsTaskFactory.make(this, stream, wrapped)
log.info { "Starting process records task for ${stream.descriptor}, file $file" }
val task = processRecordsTaskFactory.make(this, stream, file)
enqueue(task)
if (!endOfStream) {
if (!file.endOfStream) {
log.info { "End-of-stream not reached, restarting spill-to-disk task for $stream" }
val spillTask = spillToDiskTaskFactory.make(this, stream)
enqueue(spillTask)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.DestinationStreamAffinedMessage
import io.airbyte.cdk.load.message.DestinationStreamComplete
import io.airbyte.cdk.load.message.DestinationStreamIncomplete
import io.airbyte.cdk.load.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.DestinationTaskLauncher
import io.airbyte.cdk.load.task.ImplementorScope
import io.airbyte.cdk.load.task.StreamLevel
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.write.StreamLoader
import io.github.oshai.kotlinlogging.KotlinLogging
import io.micronaut.context.annotation.Secondary
Expand All @@ -35,7 +35,7 @@ interface ProcessRecordsTask : StreamLevel, ImplementorScope
class DefaultProcessRecordsTask(
override val stream: DestinationStream,
private val taskLauncher: DestinationTaskLauncher,
private val fileEnvelope: BatchEnvelope<SpilledRawMessagesLocalFile>,
private val file: SpilledRawMessagesLocalFile,
private val deserializer: Deserializer<DestinationMessage>,
private val syncManager: SyncManager,
) : ProcessRecordsTask {
Expand All @@ -45,10 +45,10 @@ class DefaultProcessRecordsTask(
log.info { "Fetching stream loader for ${stream.descriptor}" }
val streamLoader = syncManager.getOrAwaitStreamLoader(stream.descriptor)

log.info { "Processing records from ${fileEnvelope.batch.localFile}" }
val nextBatch =
log.info { "Processing records from $file" }
val batch =
try {
fileEnvelope.batch.localFile.toFileReader().use { reader ->
file.localFile.toFileReader().use { reader ->
val records =
reader
.lines()
Expand All @@ -67,14 +67,14 @@ class DefaultProcessRecordsTask(
}
.map { it as DestinationRecord }
.iterator()
streamLoader.processRecords(records, fileEnvelope.batch.totalSizeBytes)
streamLoader.processRecords(records, file.totalSizeBytes)
}
} finally {
log.info { "Processing completed, deleting ${fileEnvelope.batch.localFile}" }
fileEnvelope.batch.localFile.delete()
log.info { "Processing completed, deleting $file" }
file.localFile.delete()
}

val wrapped = fileEnvelope.withBatch(nextBatch)
val wrapped = BatchEnvelope(batch, file.indexRange)
taskLauncher.handleNewBatch(stream, wrapped)
}
}
Expand All @@ -83,7 +83,7 @@ interface ProcessRecordsTaskFactory {
fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream,
fileEnvelope: BatchEnvelope<SpilledRawMessagesLocalFile>,
file: SpilledRawMessagesLocalFile,
): ProcessRecordsTask
}

Expand All @@ -96,14 +96,8 @@ class DefaultProcessRecordsTaskFactory(
override fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream,
fileEnvelope: BatchEnvelope<SpilledRawMessagesLocalFile>,
file: SpilledRawMessagesLocalFile,
): ProcessRecordsTask {
return DefaultProcessRecordsTask(
stream,
taskLauncher,
fileEnvelope,
deserializer,
syncManager
)
return DefaultProcessRecordsTask(stream, taskLauncher, file, deserializer, syncManager)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ package io.airbyte.cdk.load.task.internal
import com.google.common.collect.Range
import io.airbyte.cdk.load.command.DestinationConfiguration
import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.file.LocalFile
import io.airbyte.cdk.load.file.TempFileProvider
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.DestinationRecordWrapped
import io.airbyte.cdk.load.message.MessageQueueSupplier
import io.airbyte.cdk.load.message.QueueReader
import io.airbyte.cdk.load.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.message.StreamCompleteWrapped
import io.airbyte.cdk.load.message.StreamRecordWrapped
import io.airbyte.cdk.load.state.FlushStrategy
Expand Down Expand Up @@ -98,9 +97,8 @@ class DefaultSpillToDiskTask(
return
}

val batch = SpilledRawMessagesLocalFile(tmpFile, sizeBytes)
val wrapped = BatchEnvelope(batch, range)
launcher.handleNewSpilledFile(stream, wrapped, endOfStream)
val file = SpilledRawMessagesLocalFile(tmpFile, sizeBytes, range, endOfStream)
launcher.handleNewSpilledFile(stream, file)
}
}

Expand Down Expand Up @@ -130,3 +128,10 @@ class DefaultSpillToDiskTaskFactory(
)
}
}

/**
 * A local file of spilled raw messages, created by the spill-to-disk task and handed to the task
 * launcher, which passes it on to a process-records task.
 *
 * This is a plain data holder; it does not implement `Batch`.
 *
 * @property localFile the file holding the spilled messages; the process-records task reads it and
 * deletes it when processing finishes
 * @property totalSizeBytes size associated with the file; forwarded to
 * `StreamLoader.processRecords`
 * @property indexRange range used to build the `BatchEnvelope` around the batch produced from this
 * file (presumably the record indexes the file covers -- TODO confirm)
 * @property endOfStream whether end-of-stream was reached when this file was written; when false
 * (the default) the launcher starts another spill-to-disk task for the stream
 */
data class SpilledRawMessagesLocalFile(
val localFile: LocalFile,
val totalSizeBytes: Long,
val indexRange: Range<Long>,
val endOfStream: Boolean = false
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import io.airbyte.cdk.load.command.MockDestinationCatalogFactory
import io.airbyte.cdk.load.file.DefaultLocalFile
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.implementor.CloseStreamTask
import io.airbyte.cdk.load.task.implementor.CloseStreamTaskFactory
Expand All @@ -38,6 +37,7 @@ import io.airbyte.cdk.load.task.internal.FlushCheckpointsTaskFactory
import io.airbyte.cdk.load.task.internal.InputConsumerTask
import io.airbyte.cdk.load.task.internal.SpillToDiskTask
import io.airbyte.cdk.load.task.internal.SpillToDiskTaskFactory
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.task.internal.TimedForcedCheckpointFlushTask
import io.airbyte.cdk.load.task.internal.UpdateCheckpointsTask
import io.micronaut.context.annotation.Primary
Expand Down Expand Up @@ -167,7 +167,7 @@ class DestinationTaskLauncherTest<T> where T : LeveledTask, T : ScopedTask {
override fun make(
taskLauncher: DestinationTaskLauncher,
stream: DestinationStream,
fileEnvelope: BatchEnvelope<SpilledRawMessagesLocalFile>
file: SpilledRawMessagesLocalFile
): ProcessRecordsTask {
return object : ProcessRecordsTask {
override val stream: DestinationStream = stream
Expand Down Expand Up @@ -354,10 +354,11 @@ class DestinationTaskLauncherTest<T> where T : LeveledTask, T : ScopedTask {
fun testHandleSpilledFileCompleteNotEndOfStream() = runTest {
taskLauncher.handleNewSpilledFile(
MockDestinationCatalogFactory.stream1,
BatchEnvelope(
SpilledRawMessagesLocalFile(DefaultLocalFile(Path("not/a/real/file")), 100L)
),
false
SpilledRawMessagesLocalFile(
DefaultLocalFile(Path("not/a/real/file")),
100L,
Range.singleton(0)
)
)

processRecordsTaskFactory.hasRun.receive()
Expand All @@ -371,10 +372,12 @@ class DestinationTaskLauncherTest<T> where T : LeveledTask, T : ScopedTask {
launch {
taskLauncher.handleNewSpilledFile(
MockDestinationCatalogFactory.stream1,
BatchEnvelope(
SpilledRawMessagesLocalFile(DefaultLocalFile(Path("not/a/real/file")), 100L)
),
true
SpilledRawMessagesLocalFile(
DefaultLocalFile(Path("not/a/real/file")),
100L,
Range.singleton(0),
true
)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package io.airbyte.cdk.load.task

import io.airbyte.cdk.load.command.DestinationStream
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Requires
import jakarta.inject.Singleton
Expand All @@ -15,7 +15,7 @@ import jakarta.inject.Singleton
@Primary
@Requires(env = ["MockTaskLauncher"])
class MockTaskLauncher : DestinationTaskLauncher {
val spilledFiles = mutableListOf<BatchEnvelope<SpilledRawMessagesLocalFile>>()
val spilledFiles = mutableListOf<SpilledRawMessagesLocalFile>()
val batchEnvelopes = mutableListOf<BatchEnvelope<*>>()

override suspend fun handleSetupComplete() {
Expand All @@ -28,10 +28,9 @@ class MockTaskLauncher : DestinationTaskLauncher {

override suspend fun handleNewSpilledFile(
stream: DestinationStream,
wrapped: BatchEnvelope<SpilledRawMessagesLocalFile>,
endOfStream: Boolean
file: SpilledRawMessagesLocalFile
) {
spilledFiles.add(wrapped)
spilledFiles.add(file)
}

override suspend fun handleNewBatch(stream: DestinationStream, wrapped: BatchEnvelope<*>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@ import io.airbyte.cdk.load.command.MockDestinationCatalogFactory
import io.airbyte.cdk.load.data.IntegerValue
import io.airbyte.cdk.load.file.MockTempFileProvider
import io.airbyte.cdk.load.message.Batch
import io.airbyte.cdk.load.message.BatchEnvelope
import io.airbyte.cdk.load.message.Deserializer
import io.airbyte.cdk.load.message.DestinationMessage
import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.message.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.state.SyncManager
import io.airbyte.cdk.load.task.MockTaskLauncher
import io.airbyte.cdk.load.task.internal.SpilledRawMessagesLocalFile
import io.airbyte.cdk.load.write.StreamLoader
import io.micronaut.context.annotation.Primary
import io.micronaut.context.annotation.Requires
Expand Down Expand Up @@ -100,12 +99,13 @@ class ProcessRecordsTaskTest {
SpilledRawMessagesLocalFile(
localFile = mockFile,
totalSizeBytes = byteSize,
indexRange = Range.closed(0, recordCount)
)
val task =
processRecordsTaskFactory.make(
taskLauncher = launcher,
stream = MockDestinationCatalogFactory.stream1,
fileEnvelope = BatchEnvelope(file, Range.closed(0, 1024))
file = file
)
mockFile.linesToRead = (0 until recordCount).map { "$it" }.toMutableList()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,16 @@ class SpillToDiskTaskTest {
.execute()
Assertions.assertEquals(2, mockTaskLauncher.spilledFiles.size)

Assertions.assertEquals(1024, mockTaskLauncher.spilledFiles[0].batch.totalSizeBytes)
Assertions.assertEquals(512, mockTaskLauncher.spilledFiles[1].batch.totalSizeBytes)
Assertions.assertEquals(1024, mockTaskLauncher.spilledFiles[0].totalSizeBytes)
Assertions.assertEquals(512, mockTaskLauncher.spilledFiles[1].totalSizeBytes)

val env1 = mockTaskLauncher.spilledFiles[0]
val env2 = mockTaskLauncher.spilledFiles[1]
Assertions.assertEquals(1024, env1.batch.totalSizeBytes)
Assertions.assertEquals(512, env2.batch.totalSizeBytes)
val spilled1 = mockTaskLauncher.spilledFiles[0]
val spilled2 = mockTaskLauncher.spilledFiles[1]
Assertions.assertEquals(1024, spilled1.totalSizeBytes)
Assertions.assertEquals(512, spilled2.totalSizeBytes)

val file1 = env1.batch.localFile as MockTempFileProvider.MockLocalFile
val file2 = env2.batch.localFile as MockTempFileProvider.MockLocalFile
val file1 = spilled1.localFile as MockTempFileProvider.MockLocalFile
val file2 = spilled2.localFile as MockTempFileProvider.MockLocalFile
Assertions.assertTrue(file1.writersCreated[0].isClosed)
Assertions.assertTrue(file2.writersCreated[0].isClosed)

Expand Down

0 comments on commit 7b12647

Please sign in to comment.