Skip to content

Commit

Permalink
Destinations CDK: CatalogParser sets default namespace (#38121)
Browse files Browse the repository at this point in the history
  • Loading branch information
edgao authored Jun 10, 2024
1 parent c019d32 commit a78647e
Show file tree
Hide file tree
Showing 18 changed files with 123 additions and 177 deletions.
1 change: 1 addition & 0 deletions airbyte-cdk/java/airbyte-cdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ corresponds to that version.

| Version | Date | Pull Request | Subject |
|:--------|:-----------|:-----------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.37.0 | 2024-06-10 | [\#38121](https://github.com/airbytehq/airbyte/pull/38121) | Destinations: Set default namespace via CatalogParser |
| 0.36.8 | 2024-06-07 | [\#38763](https://github.com/airbytehq/airbyte/pull/38763) | Increase Jackson message length limit |
| 0.36.7 | 2024-06-06 | [\#39220](https://github.com/airbytehq/airbyte/pull/39220) | Handle null messages in ConnectorExceptionUtil |
| 0.36.6 | 2024-06-05 | [\#39106](https://github.com/airbytehq/airbyte/pull/39106) | Skip write to storage with 0 byte file |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.integrations.destination.async

import com.google.common.base.Preconditions
import com.google.common.base.Strings
import io.airbyte.cdk.integrations.base.SerializedAirbyteMessageConsumer
import io.airbyte.cdk.integrations.destination.StreamSyncSummary
import io.airbyte.cdk.integrations.destination.async.buffers.BufferEnqueue
Expand All @@ -28,7 +27,6 @@ import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicLong
import java.util.function.Consumer
import kotlin.jvm.optionals.getOrNull
import org.jetbrains.annotations.VisibleForTesting

private val logger = KotlinLogging.logger {}
Expand All @@ -51,7 +49,6 @@ constructor(
onFlush: DestinationFlushFunction,
private val catalog: ConfiguredAirbyteCatalog,
private val bufferManager: BufferManager,
private val defaultNamespace: Optional<String>,
private val flushFailure: FlushFailure = FlushFailure(),
workerPool: ExecutorService = Executors.newFixedThreadPool(5),
private val airbyteMessageDeserializer: AirbyteMessageDeserializer =
Expand Down Expand Up @@ -79,28 +76,6 @@ constructor(
private var hasClosed = false
private var hasFailed = false

internal constructor(
outputRecordCollector: Consumer<AirbyteMessage>,
onStart: OnStartFunction,
onClose: OnCloseFunction,
flusher: DestinationFlushFunction,
catalog: ConfiguredAirbyteCatalog,
bufferManager: BufferManager,
flushFailure: FlushFailure,
defaultNamespace: Optional<String>,
) : this(
outputRecordCollector,
onStart,
onClose,
flusher,
catalog,
bufferManager,
defaultNamespace,
flushFailure,
Executors.newFixedThreadPool(5),
AirbyteMessageDeserializer(),
)

@Throws(Exception::class)
override fun start() {
Preconditions.checkState(!hasStarted, "Consumer has already been started.")
Expand Down Expand Up @@ -129,9 +104,6 @@ constructor(
message,
)
if (AirbyteMessage.Type.RECORD == partialAirbyteMessage.type) {
if (Strings.isNullOrEmpty(partialAirbyteMessage.record?.namespace)) {
partialAirbyteMessage.record?.namespace = defaultNamespace.getOrNull()
}
validateRecord(partialAirbyteMessage)

partialAirbyteMessage.record?.streamDescriptor?.let {
Expand All @@ -141,7 +113,6 @@ constructor(
bufferEnqueue.addRecord(
partialAirbyteMessage,
sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES,
defaultNamespace,
)
}

Expand All @@ -159,10 +130,14 @@ constructor(
bufferManager.close()

val streamSyncSummaries =
streamNames.associateWith { streamDescriptor: StreamDescriptor ->
StreamSyncSummary(
Optional.of(getRecordCounter(streamDescriptor).get()),
)
streamNames.associate { streamDescriptor ->
StreamDescriptorUtils.withDefaultNamespace(
streamDescriptor,
bufferManager.defaultNamespace,
) to
StreamSyncSummary(
Optional.of(getRecordCounter(streamDescriptor).get()),
)
}
onClose.accept(hasFailed, streamSyncSummaries)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,11 @@ object StreamDescriptorUtils {

return pairs
}

fun withDefaultNamespace(sd: StreamDescriptor, defaultNamespace: String) =
if (sd.namespace.isNullOrEmpty()) {
StreamDescriptor().withName(sd.name).withNamespace(defaultNamespace)
} else {
sd
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ package io.airbyte.cdk.integrations.destination.async.buffers
import io.airbyte.cdk.integrations.destination.async.GlobalMemoryManager
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.cdk.integrations.destination.async.state.GlobalAsyncStateManager
import io.airbyte.commons.json.Jsons
import io.airbyte.protocol.models.v0.AirbyteMessage
import io.airbyte.protocol.models.v0.StreamDescriptor
import java.util.Optional
import java.util.concurrent.ConcurrentMap

/**
Expand All @@ -20,6 +20,7 @@ class BufferEnqueue(
private val memoryManager: GlobalMemoryManager,
private val buffers: ConcurrentMap<StreamDescriptor, StreamAwareQueue>,
private val stateManager: GlobalAsyncStateManager,
private val defaultNamespace: String,
) {
/**
* Buffer a record. Contains memory management logic to dynamically adjust queue size based via
Expand All @@ -31,12 +32,11 @@ class BufferEnqueue(
fun addRecord(
message: PartialAirbyteMessage,
sizeInBytes: Int,
defaultNamespace: Optional<String>,
) {
if (message.type == AirbyteMessage.Type.RECORD) {
handleRecord(message, sizeInBytes)
} else if (message.type == AirbyteMessage.Type.STATE) {
stateManager.trackState(message, sizeInBytes.toLong(), defaultNamespace.orElse(""))
stateManager.trackState(message, sizeInBytes.toLong())
}
}

Expand All @@ -53,15 +53,28 @@ class BufferEnqueue(
}
val stateId = stateManager.getStateIdAndIncrementCounter(streamDescriptor)

var addedToQueue = queue.offer(message, sizeInBytes.toLong(), stateId)
// We don't set the default namespace until after putting this message into the state
// manager/etc.
// All our internal handling is on the true (null) namespace,
// we just set the default namespace when handing off to destination-specific code.
val mangledMessage =
if (message.record!!.namespace.isNullOrEmpty()) {
val clone = Jsons.clone(message)
clone.record!!.namespace = defaultNamespace
clone
} else {
message
}

var addedToQueue = queue.offer(mangledMessage, sizeInBytes.toLong(), stateId)

var i = 0
while (!addedToQueue) {
val newlyAllocatedMemory = memoryManager.requestMemory()
if (newlyAllocatedMemory > 0) {
queue.addMaxMemory(newlyAllocatedMemory)
}
addedToQueue = queue.offer(message, sizeInBytes.toLong(), stateId)
addedToQueue = queue.offer(mangledMessage, sizeInBytes.toLong(), stateId)
i++
if (i > 5) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ private val logger = KotlinLogging.logger {}
class BufferManager
@JvmOverloads
constructor(
/**
* This probably doesn't belong here, but it's the easiest place where both [BufferEnqueue] and
* [io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer] can both get to it.
*/
public val defaultNamespace: String,
maxMemory: Long = (Runtime.getRuntime().maxMemory() * MEMORY_LIMIT_RATIO).toLong(),
) {
@get:VisibleForTesting val buffers: ConcurrentMap<StreamDescriptor, StreamAwareQueue>
Expand All @@ -46,7 +51,7 @@ constructor(
memoryManager = GlobalMemoryManager(maxMemory)
this.stateManager = GlobalAsyncStateManager(memoryManager)
buffers = ConcurrentHashMap()
bufferEnqueue = BufferEnqueue(memoryManager, buffers, stateManager)
bufferEnqueue = BufferEnqueue(memoryManager, buffers, stateManager, defaultNamespace)
bufferDequeue = BufferDequeue(memoryManager, buffers, stateManager)
debugLoop = Executors.newSingleThreadScheduledExecutor()
debugLoop.scheduleAtFixedRate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package io.airbyte.cdk.integrations.destination.async.state

import com.google.common.base.Preconditions
import com.google.common.base.Strings
import io.airbyte.cdk.integrations.destination.async.GlobalMemoryManager
import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage
import io.airbyte.commons.json.Jsons
Expand Down Expand Up @@ -104,7 +103,6 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
fun trackState(
message: PartialAirbyteMessage,
sizeInBytes: Long,
defaultNamespace: String,
) {
if (preState) {
convertToGlobalIfNeeded(message)
Expand All @@ -113,7 +111,7 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
// stateType should not change after a conversion.
Preconditions.checkArgument(stateType == extractStateType(message))

closeState(message, sizeInBytes, defaultNamespace)
closeState(message, sizeInBytes)
}

/**
Expand Down Expand Up @@ -323,10 +321,9 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
private fun closeState(
message: PartialAirbyteMessage,
sizeInBytes: Long,
defaultNamespace: String,
) {
val resolvedDescriptor: StreamDescriptor =
extractStream(message, defaultNamespace)
extractStream(message)
.orElse(
SENTINEL_GLOBAL_DESC,
)
Expand Down Expand Up @@ -424,38 +421,14 @@ class GlobalAsyncStateManager(private val memoryManager: GlobalMemoryManager) {
UUID.randomUUID().toString(),
)

/**
* If the user has selected the Destination Namespace as the Destination default while
* setting up the connector, the platform sets the namespace as null in the StreamDescriptor
* in the AirbyteMessages (both record and state messages). The destination checks that if
* the namespace is empty or null, if yes then re-populates it with the defaultNamespace.
* See [io.airbyte.cdk.integrations.destination.async.AsyncStreamConsumer.accept] But
* destination only does this for the record messages. So when state messages arrive without
* a namespace and since the destination doesn't repopulate it with the default namespace,
* there is a mismatch between the StreamDescriptor from record messages and state messages.
* That breaks the logic of the state management class as [descToStateIdQ] needs to have
* consistent StreamDescriptor. This is why while trying to extract the StreamDescriptor
* from state messages, we check if the namespace is null, if yes then replace it with
* defaultNamespace to keep it consistent with the record messages.
*/
private fun extractStream(
message: PartialAirbyteMessage,
defaultNamespace: String,
): Optional<StreamDescriptor> {
if (
message.state?.type != null &&
message.state?.type == AirbyteStateMessage.AirbyteStateType.STREAM
) {
val streamDescriptor: StreamDescriptor? = message.state?.stream?.streamDescriptor
if (Strings.isNullOrEmpty(streamDescriptor?.namespace)) {
return Optional.of(
StreamDescriptor()
.withName(
streamDescriptor?.name,
)
.withNamespace(defaultNamespace),
)
}
return streamDescriptor?.let { Optional.of(it) } ?: Optional.empty()
}
return Optional.empty()
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.36.8
version=0.37.0
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import io.airbyte.protocol.models.v0.StreamDescriptor
import java.io.IOException
import java.math.BigDecimal
import java.time.Instant
import java.util.Optional
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
Expand Down Expand Up @@ -60,7 +59,7 @@ class AsyncStreamConsumerTest {
private val CATALOG: ConfiguredAirbyteCatalog =
ConfiguredAirbyteCatalog()
.withStreams(
java.util.List.of(
listOf(
CatalogHelpers.createConfiguredAirbyteStream(
STREAM_NAME,
SCHEMA_NAME,
Expand Down Expand Up @@ -145,9 +144,8 @@ class AsyncStreamConsumerTest {
onClose = onClose,
onFlush = flushFunction,
catalog = CATALOG,
bufferManager = BufferManager(),
bufferManager = BufferManager("default_ns"),
flushFailure = flushFailure,
defaultNamespace = Optional.of("default_ns"),
airbyteMessageDeserializer = airbyteMessageDeserializer,
workerPool = Executors.newFixedThreadPool(5),
)
Expand Down Expand Up @@ -264,9 +262,8 @@ class AsyncStreamConsumerTest {
Mockito.mock(OnCloseFunction::class.java),
flushFunction,
CATALOG,
BufferManager((1024 * 10).toLong()),
BufferManager("default_ns", (1024 * 10).toLong()),
flushFailure,
Optional.of("default_ns"),
)
Mockito.`when`(flushFunction.optimalBatchSizeBytes).thenReturn(0L)

Expand Down
Loading

0 comments on commit a78647e

Please sign in to comment.