Skip to content

Improved SQL<->JDBC mapping #855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 16, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public object MariaDb : DbType("mariadb") {
get() = "org.mariadb.jdbc.Driver"

override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
if (tableColumnMetadata.sqlTypeName == "SMALLINT") {
if (tableColumnMetadata.sqlTypeName == "SMALLINT" && tableColumnMetadata.javaClassName == "java.lang.Short") {
val kType = Short::class.createType(nullable = tableColumnMetadata.isNullable)
return ColumnSchema.Value(kType)
}
Expand All @@ -35,7 +35,7 @@ public object MariaDb : DbType("mariadb") {
)

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
if (tableColumnMetadata.sqlTypeName == "SMALLINT") {
if (tableColumnMetadata.sqlTypeName == "SMALLINT" && tableColumnMetadata.javaClassName == "java.lang.Short") {
return Short::class.createType(nullable = tableColumnMetadata.isNullable)
}
return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@ import java.sql.SQLXML
import java.sql.Time
import java.sql.Timestamp
import java.sql.Types
import java.time.LocalDateTime
import java.time.OffsetDateTime
import java.time.OffsetTime
import java.util.Date
import java.util.UUID
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.full.createType
import kotlin.reflect.full.isSupertypeOf
import kotlin.reflect.full.safeCast
import kotlin.reflect.full.starProjectedType

private val logger = KotlinLogging.logger {}
Expand Down Expand Up @@ -775,6 +781,10 @@ private fun manageColumnNameDuplication(columnNameCounter: MutableMap<String, In
return name
}

// Utility function to cast arrays based on the type of elements
private fun <T : Any> castArray(array: Array<*>, elementType: KClass<T>): List<T> =
array.mapNotNull { elementType.safeCast(it) }

/**
* Fetches and converts data from a ResultSet into a mutable map.
*
Expand Down Expand Up @@ -816,9 +826,15 @@ private fun fetchAndConvertDataFromResultSet(
}

val dataFrame = data.mapIndexed { index, values ->
val correctedValues = if (kotlinTypesForSqlColumns[index]!!.classifier == Array::class) {
handleArrayValues(values)
} else {
values
}

DataColumn.createValueColumn(
name = tableColumns[index].name,
values = values,
values = correctedValues,
infer = convertNullabilityInference(inferNullability),
type = kotlinTypesForSqlColumns[index]!!,
)
Expand All @@ -831,6 +847,31 @@ private fun fetchAndConvertDataFromResultSet(
return dataFrame
}

private fun handleArrayValues(values: MutableList<Any?>): List<Any> {
// Intermediate variable for the first mapping
val sqlArrays = values.mapNotNull {
(it as? java.sql.Array)?.array?.let { array -> array as? Array<*> }
}

// Flatten the arrays to iterate through all elements and filter out null values, then map to component types
val allElementTypes = sqlArrays
.flatMap { array ->
(array.javaClass.componentType?.kotlin?.let { listOf(it) } ?: emptyList())
} // Get the component type of each array and convert it to a Kotlin class, if available

// Find distinct types and ensure there's only one distinct type
val commonElementType = allElementTypes
.distinct() // Get unique element types
.singleOrNull() // Ensure there's only one unique element type, otherwise return null
?: Any::class // Fallback to Any::class if multiple distinct types or no elements found

return if (commonElementType != Any::class) {
sqlArrays.map { castArray(it, commonElementType).toTypedArray() }
} else {
sqlArrays
}
}

private fun convertNullabilityInference(inferNullability: Boolean) = if (inferNullability) Infer.Nulls else Infer.None

private fun extractNewRowFromResultSetAndAddToData(
Expand All @@ -843,6 +884,7 @@ private fun extractNewRowFromResultSetAndAddToData(
data[i].add(
try {
rs.getObject(i + 1)
// TODO: add a special handler for Blob via Streams
} catch (_: Throwable) {
val kType = kotlinTypesForSqlColumns[i]!!
// TODO: expand for all the types like in generateKType function
Expand All @@ -868,7 +910,7 @@ private fun generateKType(dbType: DbType, tableColumnMetadata: TableColumnMetada
* Creates a mapping between common SQL types and their corresponding KTypes.
*
* @param tableColumnMetadata The metadata of the table column.
* @return The KType associated with the SQL type, or a default type if no mapping is found.
* @return The KType associated with the SQL type or a default type if no mapping is found.
*/
private fun makeCommonSqlToKTypeMapping(tableColumnMetadata: TableColumnMetadata): KType {
val jdbcTypeToKTypeMapping = mapOf(
Expand All @@ -882,7 +924,7 @@ private fun makeCommonSqlToKTypeMapping(tableColumnMetadata: TableColumnMetadata
Types.DOUBLE to Double::class,
Types.NUMERIC to BigDecimal::class,
Types.DECIMAL to BigDecimal::class,
Types.CHAR to Char::class,
Types.CHAR to String::class,
Types.VARCHAR to String::class,
Types.LONGVARCHAR to String::class,
Types.DATE to Date::class,
Expand All @@ -892,27 +934,67 @@ private fun makeCommonSqlToKTypeMapping(tableColumnMetadata: TableColumnMetadata
Types.VARBINARY to ByteArray::class,
Types.LONGVARBINARY to ByteArray::class,
Types.NULL to String::class,
Types.OTHER to Any::class,
Types.JAVA_OBJECT to Any::class,
Types.DISTINCT to Any::class,
Types.STRUCT to Any::class,
Types.ARRAY to Array<Any>::class,
Types.BLOB to Blob::class,
Types.ARRAY to Array::class,
Types.BLOB to ByteArray::class,
Types.CLOB to Clob::class,
Types.REF to Ref::class,
Types.DATALINK to Any::class,
Types.BOOLEAN to Boolean::class,
Types.ROWID to RowId::class,
Types.NCHAR to Char::class,
Types.NCHAR to String::class,
Types.NVARCHAR to String::class,
Types.LONGNVARCHAR to String::class,
Types.NCLOB to NClob::class,
Types.SQLXML to SQLXML::class,
Types.REF_CURSOR to Ref::class,
Types.TIME_WITH_TIMEZONE to Time::class,
Types.TIMESTAMP_WITH_TIMEZONE to Timestamp::class,
Types.TIME_WITH_TIMEZONE to OffsetTime::class,
Types.TIMESTAMP_WITH_TIMEZONE to OffsetDateTime::class,
)
// TODO: check mapping of JDBC types and classes correctly
val kClass = jdbcTypeToKTypeMapping[tableColumnMetadata.jdbcType] ?: String::class
return kClass.createType(nullable = tableColumnMetadata.isNullable)

fun determineKotlinClass(tableColumnMetadata: TableColumnMetadata): KClass<*> =
when {
tableColumnMetadata.jdbcType == Types.OTHER -> when (tableColumnMetadata.javaClassName) {
"[B" -> ByteArray::class
else -> Any::class
}

tableColumnMetadata.javaClassName == "[B" -> ByteArray::class

tableColumnMetadata.javaClassName == "java.sql.Blob" -> Blob::class

tableColumnMetadata.jdbcType == Types.TIMESTAMP &&
tableColumnMetadata.javaClassName == "java.time.LocalDateTime" -> LocalDateTime::class

tableColumnMetadata.jdbcType == Types.BINARY &&
tableColumnMetadata.javaClassName == "java.util.UUID" -> UUID::class

tableColumnMetadata.jdbcType == Types.REAL &&
tableColumnMetadata.javaClassName == "java.lang.Double" -> Double::class

tableColumnMetadata.jdbcType == Types.FLOAT &&
tableColumnMetadata.javaClassName == "java.lang.Double" -> Double::class

tableColumnMetadata.jdbcType == Types.NUMERIC &&
tableColumnMetadata.javaClassName == "java.lang.Double" -> Double::class

else -> jdbcTypeToKTypeMapping[tableColumnMetadata.jdbcType] ?: String::class
}

fun createArrayTypeIfNeeded(kClass: KClass<*>, isNullable: Boolean): KType =
if (kClass == Array::class) {
val typeParam = kClass.typeParameters[0].createType()
kClass.createType(
arguments = listOf(kotlin.reflect.KTypeProjection.invariant(typeParam)),
nullable = isNullable,
)
} else {
kClass.createType(nullable = isNullable)
}

val kClass: KClass<*> = determineKotlinClass(tableColumnMetadata)
val kType = createArrayTypeIfNeeded(kClass, tableColumnMetadata.isNullable)
return kType
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class JdbcTest {

val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
dataSchema.columns.size shouldBe 2
dataSchema.columns["characterCol"]!!.type shouldBe typeOf<Char?>()
dataSchema.columns["characterCol"]!!.type shouldBe typeOf<String?>()
}

@Test
Expand Down Expand Up @@ -291,6 +291,7 @@ class JdbcTest {

val schema = DataFrame.getSchemaForSqlTable(connection, tableName)

schema.columns["characterCol"]!!.type shouldBe typeOf<String?>()
schema.columns["tinyIntCol"]!!.type shouldBe typeOf<Int?>()
schema.columns["smallIntCol"]!!.type shouldBe typeOf<Int?>()
schema.columns["bigIntCol"]!!.type shouldBe typeOf<Long?>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ import org.junit.AfterClass
import org.junit.BeforeClass
import org.junit.Test
import java.math.BigDecimal
import java.sql.Blob
import java.sql.Connection
import java.sql.DriverManager
import java.sql.SQLException
import java.util.Date
import kotlin.reflect.typeOf

private const val URL = "jdbc:h2:mem:test1;DB_CLOSE_DELAY=-1;MODE=MariaDB;DATABASE_TO_LOWER=TRUE"
Expand Down Expand Up @@ -303,10 +305,21 @@ class MariadbH2Test {
val df1 = DataFrame.readSqlTable(connection, "table1").cast<Table1MariaDb>()
val result = df1.filter { it[Table1MariaDb::id] == 1 }
result[0][26] shouldBe "textValue1"
val byteArray = "tinyblobValue".toByteArray()
(result[0][22] as Blob).getBytes(1, byteArray.size) contentEquals byteArray

val schema = DataFrame.getSchemaForSqlTable(connection, "table1")
schema.columns["id"]!!.type shouldBe typeOf<Int>()
schema.columns["textcol"]!!.type shouldBe typeOf<String>()
schema.columns["varbinarycol"]!!.type shouldBe typeOf<ByteArray>()
schema.columns["binarycol"]!!.type shouldBe typeOf<ByteArray>()
schema.columns["longblobcol"]!!.type shouldBe typeOf<Blob>()
schema.columns["tinyblobcol"]!!.type shouldBe typeOf<Blob>()
schema.columns["datecol"]!!.type shouldBe typeOf<Date>()
schema.columns["datetimecol"]!!.type shouldBe typeOf<java.sql.Timestamp>()
schema.columns["timestampcol"]!!.type shouldBe typeOf<java.sql.Timestamp>()
schema.columns["timecol"]!!.type shouldBe typeOf<java.sql.Time>()
schema.columns["yearcol"]!!.type shouldBe typeOf<Int>()

val df2 = DataFrame.readSqlTable(connection, "table2").cast<Table2MariaDb>()
val result2 = df2.filter { it[Table2MariaDb::id] == 1 }
Expand Down Expand Up @@ -396,11 +409,11 @@ class MariadbH2Test {
val schema = DataFrame.getSchemaForSqlTable(connection, "table1")

schema.columns["tinyintcol"]!!.type shouldBe typeOf<Int>()
schema.columns["smallintcol"]!!.type shouldBe typeOf<Short?>()
schema.columns["smallintcol"]!!.type shouldBe typeOf<Int?>()
schema.columns["mediumintcol"]!!.type shouldBe typeOf<Int>()
schema.columns["mediumintunsignedcol"]!!.type shouldBe typeOf<Int>()
schema.columns["bigintcol"]!!.type shouldBe typeOf<Long>()
schema.columns["floatcol"]!!.type shouldBe typeOf<Float>()
schema.columns["floatcol"]!!.type shouldBe typeOf<Double>()
schema.columns["doublecol"]!!.type shouldBe typeOf<Double>()
schema.columns["decimalcol"]!!.type shouldBe typeOf<BigDecimal>()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,15 @@ class MSSQLH2Test {
schema.columns["bigintColumn"]!!.type shouldBe typeOf<Long?>()
schema.columns["binaryColumn"]!!.type shouldBe typeOf<ByteArray?>()
schema.columns["bitColumn"]!!.type shouldBe typeOf<Boolean?>()
schema.columns["charColumn"]!!.type shouldBe typeOf<Char?>()
schema.columns["charColumn"]!!.type shouldBe typeOf<String?>()
schema.columns["dateColumn"]!!.type shouldBe typeOf<Date?>()
schema.columns["datetime3Column"]!!.type shouldBe typeOf<java.sql.Timestamp?>()
schema.columns["datetime2Column"]!!.type shouldBe typeOf<java.sql.Timestamp?>()
schema.columns["decimalColumn"]!!.type shouldBe typeOf<BigDecimal?>()
schema.columns["floatColumn"]!!.type shouldBe typeOf<Double?>()
schema.columns["intColumn"]!!.type shouldBe typeOf<Int?>()
schema.columns["moneyColumn"]!!.type shouldBe typeOf<BigDecimal?>()
schema.columns["ncharColumn"]!!.type shouldBe typeOf<Char?>()
schema.columns["ncharColumn"]!!.type shouldBe typeOf<String?>()
schema.columns["ntextColumn"]!!.type shouldBe typeOf<String?>()
schema.columns["numericColumn"]!!.type shouldBe typeOf<BigDecimal?>()
schema.columns["nvarcharColumn"]!!.type shouldBe typeOf<String?>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.math.BigDecimal
import java.sql.Connection
import java.sql.DriverManager
import java.sql.SQLException
import java.util.Date
import kotlin.reflect.typeOf

// NOTE: the names of testing databases should be different to avoid collisions and should not contain the system names itself
Expand Down Expand Up @@ -306,6 +307,15 @@ class MySqlH2Test {
val schema = DataFrame.getSchemaForSqlTable(connection, "table1")
schema.columns["id"]!!.type shouldBe typeOf<Int>()
schema.columns["textcol"]!!.type shouldBe typeOf<String>()
schema.columns["datecol"]!!.type shouldBe typeOf<Date>()
schema.columns["datetimecol"]!!.type shouldBe typeOf<java.sql.Timestamp>()
schema.columns["timestampcol"]!!.type shouldBe typeOf<java.sql.Timestamp>()
schema.columns["timecol"]!!.type shouldBe typeOf<java.sql.Time>()
schema.columns["yearcol"]!!.type shouldBe typeOf<Int>()
schema.columns["varbinarycol"]!!.type shouldBe typeOf<ByteArray>()
schema.columns["binarycol"]!!.type shouldBe typeOf<ByteArray>()
schema.columns["longblobcol"]!!.type shouldBe typeOf<java.sql.Blob>()
schema.columns["tinyblobcol"]!!.type shouldBe typeOf<java.sql.Blob>()

val df2 = DataFrame.readSqlTable(connection, "table2").cast<Table2MySql>()
val result2 = df2.filter { it[Table2MySql::id] == 1 }
Expand Down Expand Up @@ -403,7 +413,7 @@ class MySqlH2Test {
schema.columns["mediumintcol"]!!.type shouldBe typeOf<Int>()
schema.columns["mediumintunsignedcol"]!!.type shouldBe typeOf<Int>()
schema.columns["bigintcol"]!!.type shouldBe typeOf<Long>()
schema.columns["floatcol"]!!.type shouldBe typeOf<Float>()
schema.columns["floatcol"]!!.type shouldBe typeOf<Double>()
schema.columns["doublecol"]!!.type shouldBe typeOf<Double>()
schema.columns["decimalcol"]!!.type shouldBe typeOf<BigDecimal>()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ class PostgresH2Test {
charCol char not null,
dateCol date not null,
doubleCol double precision not null,
integerCol integer
integerCol integer,
intArrayCol integer array,
doubleArrayCol double precision array,
dateArrayCol date array,
textArrayCol text array,
booleanArrayCol boolean array
)
"""
connection.createStatement().execute(createTableStatement.trimIndent())
Expand Down Expand Up @@ -120,8 +125,9 @@ class PostgresH2Test {
bigintCol, smallintCol, bigserialCol, booleanCol,
byteaCol, characterCol, characterNCol, charCol,
dateCol, doubleCol,
integerCol
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
integerCol, intArrayCol,
doubleArrayCol, dateArrayCol, textArrayCol, booleanArrayCol
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""

@Language("SQL")
Expand All @@ -135,6 +141,15 @@ class PostgresH2Test {
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""

val intArray = connection.createArrayOf("INTEGER", arrayOf(1, 2, 3))
val doubleArray = connection.createArrayOf("DOUBLE", arrayOf(1.1, 2.2, 3.3))
val dateArray = connection.createArrayOf(
"DATE",
arrayOf(java.sql.Date.valueOf("2023-08-01"), java.sql.Date.valueOf("2023-08-02")),
)
val textArray = connection.createArrayOf("TEXT", arrayOf("Hello", "World"))
val booleanArray = connection.createArrayOf("BOOLEAN", arrayOf(true, false, true))

connection.prepareStatement(insertData1).use { st ->
// Insert data into table1
for (i in 1..3) {
Expand All @@ -149,6 +164,11 @@ class PostgresH2Test {
st.setDate(9, java.sql.Date.valueOf("2023-08-01"))
st.setDouble(10, 12.34)
st.setInt(11, 12345 * i)
st.setArray(12, intArray)
st.setArray(13, doubleArray)
st.setArray(14, dateArray)
st.setArray(15, textArray)
st.setArray(16, booleanArray)
st.executeUpdate()
}
}
Expand Down Expand Up @@ -191,11 +211,21 @@ class PostgresH2Test {

result[0][0] shouldBe 1
result[0][8] shouldBe "A"
result[0][12] shouldBe arrayOf(1, 2, 3)
result[0][13] shouldBe arrayOf(1.1, 2.2, 3.3)
result[0][14] shouldBe arrayOf(java.sql.Date.valueOf("2023-08-01"), java.sql.Date.valueOf("2023-08-02"))
result[0][15] shouldBe arrayOf("Hello", "World")
result[0][16] shouldBe arrayOf(true, false, true)

val schema = DataFrame.getSchemaForSqlTable(connection, tableName1)
schema.columns["id"]!!.type shouldBe typeOf<Int>()
schema.columns["integercol"]!!.type shouldBe typeOf<Int?>()
schema.columns["smallintcol"]!!.type shouldBe typeOf<Int>()
schema.columns["intarraycol"]!!.type.classifier shouldBe kotlin.Array::class
schema.columns["doublearraycol"]!!.type.classifier shouldBe kotlin.Array::class
schema.columns["datearraycol"]!!.type.classifier shouldBe kotlin.Array::class
schema.columns["textarraycol"]!!.type.classifier shouldBe kotlin.Array::class
schema.columns["booleanarraycol"]!!.type.classifier shouldBe kotlin.Array::class

val tableName2 = "table2"
val df2 = DataFrame.readSqlTable(connection, tableName2).cast<Table2>()
Expand Down
Loading
Loading