
Commit: snowflake-stagingclient-enh
gisripa committed Jun 5, 2024
1 parent 3d8f7ca commit 459c37c
Showing 6 changed files with 303 additions and 22 deletions.
build.gradle
@@ -3,7 +3,7 @@ plugins {
}

airbyteJavaConnector {
cdkVersionRequired = '0.35.14'
cdkVersionRequired = '0.35.15'
features = ['db-destinations', 's3-destinations', 'typing-deduping']
useLocalCdk = false
}
metadata.yaml
@@ -5,7 +5,7 @@ data:
connectorSubtype: database
connectorType: destination
definitionId: 424892c4-daac-4491-b35d-c6688ba547ba
dockerImageTag: 3.9.0
dockerImageTag: 3.9.1
dockerRepository: airbyte/destination-snowflake
documentationUrl: https://docs.airbyte.com/integrations/destinations/snowflake
githubIssueLabel: destination-snowflake
SnowflakeStagingClient.kt
@@ -3,6 +3,7 @@
*/
package io.airbyte.integrations.destination.snowflake.operation

import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.db.jdbc.JdbcDatabase
import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer
import io.airbyte.commons.string.Strings.join
@@ -18,6 +19,22 @@ private val log = KotlinLogging.logger {}
/** Client wrapper providing Snowflake Stage related operations. */
class SnowflakeStagingClient(private val database: JdbcDatabase) {

private data class CopyIntoTableResult(
val file: String,
val copyStatus: CopyStatus,
val rowsParsed: Int,
val rowsLoaded: Int,
val errorsSeen: Int,
val firstError: String?
)

private enum class CopyStatus {
UNKNOWN,
LOADED,
LOAD_FAILED,
PARTIALLY_LOADED
}

// Most of the code here is preserved from
// https://github.com/airbytehq/airbyte/blob/503b819b846663b0dff4c90322d0219a93e61d14/airbyte-integrations/connectors/destination-snowflake/src/main/java/io/airbyte/integrations/destination/snowflake/SnowflakeInternalStagingSqlOperations.java
@Throws(IOException::class)
@@ -63,8 +80,18 @@ class SnowflakeStagingClient(private val database: JdbcDatabase) {
recordsData: SerializableBuffer
) {
val query = getPutQuery(stageName, stagingPath, recordsData.file!!.absolutePath)
log.info { "Executing query: $query" }
database.execute(query)
val queryId = UUID.randomUUID()
log.info { "executing query $queryId, $query" }
val results = database.queryJsons(query)
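        // A successful PUT returns one result row describing the upload, with columns
        // including source, target, source_size, target_size, and status; a
        // source_size of 0 means the local file we just uploaded was empty.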
if (results.isNotEmpty() && (results.first().has("source_size"))) {
if (results.first().get("source_size").asLong() == 0L) {
                // TODO: Should we fail the sync rather than proceed with an empty file for COPY?
log.warn {
"query $queryId, uploaded an empty file, no new records will be inserted"
}
}
}
log.info { "query $queryId, completed with $results" }
if (!checkStageObjectExists(stageName, stagingPath, recordsData.filename)) {
log.error {
"Failed to upload data into stage, object @${
@@ -84,7 +111,8 @@ class SnowflakeStagingClient(private val database: JdbcDatabase) {
filePath,
stageName,
stagingPath,
Runtime.getRuntime().availableProcessors()
            // max allowed value for PARALLEL is 99; we don't need that many threads for a single file upload
minOf(Runtime.getRuntime().availableProcessors(), 4)
)
}
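    // For reference, the rendered PUT statement has roughly this shape (illustrative
    // values; the exact format template is collapsed in this diff):
    //   PUT file:///tmp/airbyte/file.csv.gz @"SCHEMA"."STAGE"/staging-path/ PARALLEL = 4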

@@ -144,14 +172,72 @@ class SnowflakeStagingClient(private val database: JdbcDatabase) {
streamId: StreamId
) {
try {
val queryId = UUID.randomUUID()
val query = getCopyQuery(stageName, stagingPath, stagedFiles, streamId)
log.info { "Executing query: $query" }
database.execute(query)
log.info { "query $queryId, $query" }
            // queryJsons is intentionally used here instead of execute, so that the
            // error message is surfaced in case of failure
val results = database.queryJsons(query)
if (results.isNotEmpty()) {
                // Only one row is returned as the result of a COPY INTO query
val copyResult = getCopyResult(results.first())
when (copyResult.copyStatus) {
CopyStatus.LOADED ->
log.info {
"query $queryId, successfully loaded ${copyResult.rowsLoaded} rows of data into table"
}
CopyStatus.LOAD_FAILED -> {
log.error {
"query $queryId, failed to load data into table, " +
"rows_parsed: ${copyResult.rowsParsed}, " +
"rows_loaded: ${copyResult.rowsLoaded} " +
"errors: ${copyResult.errorsSeen}, " +
"firstError: ${copyResult.firstError}"
}
throw Exception(
"COPY into table failed with ${copyResult.errorsSeen} errors, check logs"
)
}
else -> log.warn { "query $queryId, unrecognized result format, $results" }
}
} else {
log.warn { "query $queryId, no result returned" }
}
} catch (e: SQLException) {
throw SnowflakeDatabaseUtils.checkForKnownConfigExceptions(e).orElseThrow { e }
}
}

private fun getCopyResult(result: JsonNode): CopyIntoTableResult {
if (
result.has("file") &&
result.has("status") &&
result.has("rows_parsed") &&
result.has("rows_loaded") &&
result.has("errors_seen")
) {
val status =
when (result.get("status").asText()) {
"LOADED" -> CopyStatus.LOADED
"LOAD_FAILED" -> CopyStatus.LOAD_FAILED
"PARTIALLY_LOADED" -> CopyStatus.PARTIALLY_LOADED
else -> CopyStatus.UNKNOWN
}
return CopyIntoTableResult(
result.get("file").asText(),
status,
result.get("rows_parsed").asInt(),
result.get("rows_loaded").asInt(),
result.get("errors_seen").asInt(),
if (result.has("first_error")) result.get("first_error").asText() else null
)
} else {
            // Safety net in case Snowflake decides to change the response format:
            // instead of blowing up, we return a default object
return CopyIntoTableResult("", CopyStatus.UNKNOWN, 0, 0, 0, null)
}
}
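    // For reference, a result row parsed by getCopyResult looks roughly like this
    // (illustrative values):
    //   {"file": "staging-path/file.csv.gz", "status": "LOADED", "rows_parsed": 2,
    //    "rows_loaded": 2, "errors_seen": 0, "first_error": null}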

/**
* Creates a SQL query to bulk copy data into the fully qualified destination table. See
* https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html for more context
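* Illustrative shape of the generated statement (values assumed for this example):
*   COPY INTO "namespace"."table" FROM '@"namespace"."stage"/staging-path/'
*   FILE_FORMAT = (TYPE = CSV COMPRESSION = AUTO) FILES = ('file-name.csv.gz')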
SnowflakeStagingClientIntegrationTest.kt (new file)
@@ -0,0 +1,180 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination.snowflake.operation

import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.db.jdbc.JdbcDatabase
import io.airbyte.cdk.integrations.destination.record_buffer.FileBuffer
import io.airbyte.cdk.integrations.destination.s3.csv.CsvSerializedBuffer
import io.airbyte.cdk.integrations.destination.s3.csv.CsvSheetGenerator
import io.airbyte.commons.json.Jsons
import io.airbyte.commons.string.Strings
import io.airbyte.integrations.base.destination.typing_deduping.StreamId
import io.airbyte.integrations.destination.snowflake.OssCloudEnvVarConsts
import io.airbyte.integrations.destination.snowflake.SnowflakeDatabaseUtils
import io.airbyte.protocol.models.v0.AirbyteRecordMessage
import java.nio.file.Files
import java.nio.file.Paths
import java.time.Instant
import java.util.*
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test

class SnowflakeStagingClientIntegrationTest {

private lateinit var stagingClient: SnowflakeStagingClient
    // Not using lateinit, to keep SpotBugs happy: these vars are referenced within
    // setup, and for lateinit vars the generated bytecode adds a non-null check on
    // every access.
private var namespace: String = ""
private var tablename: String = ""

private lateinit var stageName: String
private val config =
Jsons.deserialize(Files.readString(Paths.get("secrets/1s1t_internal_staging_config.json")))
private val datasource =
SnowflakeDatabaseUtils.createDataSource(config, OssCloudEnvVarConsts.AIRBYTE_OSS)
private val database: JdbcDatabase = SnowflakeDatabaseUtils.getDatabase(datasource)
    // Intentionally not using actual column names, since the staging client should be
    // agnostic of them and only follow the order of the data.

@BeforeEach
fun setUp() {
namespace = Strings.addRandomSuffix("staging_client_test", "_", 5).uppercase()
tablename = "integration_test_raw".uppercase()
val createSchemaQuery = """
CREATE SCHEMA "$namespace"
""".trimIndent()
val createStagingTableQuery =
"""
CREATE TABLE IF NOT EXISTS "$namespace"."$tablename" (
"id" VARCHAR PRIMARY KEY,
"emitted_at" TIMESTAMP WITH TIME ZONE DEFAULT current_timestamp(),
"data" VARIANT
)
""".trimIndent()
stageName = """"$namespace"."${Strings.addRandomSuffix("stage", "_", 5)}""""
stagingClient = SnowflakeStagingClient(database)
database.execute(createSchemaQuery)
stagingClient.createStageIfNotExists(stageName)
database.execute(createStagingTableQuery)
}

@AfterEach
fun tearDown() {
stagingClient.dropStageIfExists(stageName)
database.execute("DROP SCHEMA IF EXISTS \"$namespace\" CASCADE")
}

@Test
fun verifyUploadAndCopyToTableSuccess() {
val csvSheetGenerator =
object : CsvSheetGenerator {
override fun getDataRow(formattedData: JsonNode): List<Any> {
throw NotImplementedError("This method should not be called in this test")
}

override fun getDataRow(id: UUID, recordMessage: AirbyteRecordMessage): List<Any> {
throw NotImplementedError("This method should not be called in this test")
}

override fun getDataRow(
id: UUID,
formattedString: String,
emittedAt: Long,
formattedAirbyteMetaString: String
): List<Any> {
return listOf(id, Instant.ofEpochMilli(emittedAt), formattedString)
}

override fun getHeaderRow(): List<String> {
throw NotImplementedError("This method should not be called in this test")
}
}
val writeBuffer =
CsvSerializedBuffer(
FileBuffer(CsvSerializedBuffer.CSV_GZ_SUFFIX),
csvSheetGenerator,
true,
)
val streamId = StreamId("unused", "unused", namespace, tablename, "unused", "unused")
val stagingPath = "${UUID.randomUUID()}/test/"
writeBuffer.use {
it.accept(""" {"dummyKey": "dummyValue"} """, "", System.currentTimeMillis())
it.accept(""" {"dummyKey": "dummyValue"} """, "", System.currentTimeMillis())
it.flush()
val fileName = stagingClient.uploadRecordsToStage(writeBuffer, stageName, stagingPath)
stagingClient.copyIntoTableFromStage(stageName, stagingPath, listOf(fileName), streamId)
}
val results =
database.queryJsons(
"SELECT * FROM \"${streamId.rawNamespace}\".\"${streamId.rawName}\""
)
assertTrue(results.size == 2)
assertNotNull(results.first().get("id"))
assertNotNull(results.first().get("emitted_at"))
assertNotNull(results.first().get("data"))
}

@Test
fun verifyUploadAndCopyToTableFailureOnMismatchedColumns() {
val mismatchedColumnsSheetGenerator =
object : CsvSheetGenerator {
override fun getDataRow(formattedData: JsonNode): List<Any> {
throw NotImplementedError("This method should not be called in this test")
}

override fun getDataRow(id: UUID, recordMessage: AirbyteRecordMessage): List<Any> {
throw NotImplementedError("This method should not be called in this test")
}

override fun getDataRow(
id: UUID,
formattedString: String,
emittedAt: Long,
formattedAirbyteMetaString: String
): List<Any> {
return listOf(
id,
Instant.ofEpochMilli(emittedAt),
formattedString,
"unknown_data_column"
)
}

override fun getHeaderRow(): List<String> {
throw NotImplementedError("This method should not be called in this test")
}
}
val writeBuffer =
CsvSerializedBuffer(
FileBuffer(CsvSerializedBuffer.CSV_GZ_SUFFIX),
mismatchedColumnsSheetGenerator,
true,
)
val streamId = StreamId("unused", "unused", namespace, tablename, "unused", "unused")
val stagingPath = "${UUID.randomUUID()}/test/"
writeBuffer.use {
it.accept(""" {"dummyKey": "dummyValue"} """, "", System.currentTimeMillis())
it.flush()
val fileName = stagingClient.uploadRecordsToStage(writeBuffer, stageName, stagingPath)
assertThrows(Exception::class.java) {
stagingClient.copyIntoTableFromStage(
stageName,
stagingPath,
listOf(fileName),
streamId
)
}
}
val results =
database.queryJsons(
"SELECT * FROM \"${streamId.rawNamespace}\".\"${streamId.rawName}\""
)
assertTrue(results.isEmpty())
}
}