Skip to content

Fixes the schema generation for JDBC integration #470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public annotation class CsvOptions(
public annotation class JdbcOptions(
public val user: String = "", // TODO: I'm not sure about the default parameters
public val password: String = "", // TODO: I'm not sure about the default parameters)
public val tableName: String = "",
public val sqlQuery: String = ""
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public annotation class CsvOptions(
public annotation class JdbcOptions(
public val user: String = "", // TODO: I'm not sure about the default parameters
public val password: String = "", // TODO: I'm not sure about the default parameters)
public val tableName: String = "",
public val sqlQuery: String = ""
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ import org.jetbrains.kotlinx.dataframe.io.TableMetadata
* @property [dbTypeInJdbcUrl] The name of the database as specified in the JDBC URL.
*/
public abstract class DbType(public val dbTypeInJdbcUrl: String) {


/**
* Represents the JDBC driver class name for a given database type.
*
* NOTE: It's important for usage in dataframe-gradle-plugin for force class loading.
*
* @return The JDBC driver class name as a [String].
*/
public abstract val driverClassName: String

/**
* Converts the data from the given [ResultSet] into the specified [TableColumnMetadata] type.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import kotlin.reflect.typeOf
* NOTE: All date and timestamp related types are converted to String to avoid java.sql.* types.
*/
public object H2 : DbType("h2") {
override val driverClassName: String
get() = "org.h2.Driver"

override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? {
val name = tableColumnMetadata.name
return when (tableColumnMetadata.sqlTypeName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import kotlin.reflect.typeOf
* and to generate the corresponding column schema.
*/
public object MariaDb : DbType("mariadb") {
override val driverClassName: String
get() = "org.mariadb.jdbc.Driver"

override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? {
val name = tableColumnMetadata.name
return when (tableColumnMetadata.sqlTypeName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ import kotlin.reflect.typeOf
* and to generate the corresponding column schema.
*/
public object MySql : DbType("mysql") {
override val driverClassName: String
get() = "com.mysql.jdbc.Driver"

override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? {
val name = tableColumnMetadata.name
return when (tableColumnMetadata.sqlTypeName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ import kotlin.reflect.typeOf
* and to generate the corresponding column schema.
*/
public object PostgreSql : DbType("postgresql") {
override val driverClassName: String
get() = "org.postgresql.Driver"

override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? {
val name = tableColumnMetadata.name
return when (tableColumnMetadata.sqlTypeName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import kotlin.reflect.typeOf
* and to generate the corresponding column schema.
*/
public object Sqlite : DbType("sqlite") {
override val driverClassName: String
get() = "org.sqlite.JDBC"

override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? {
val name = tableColumnMetadata.name
return when (tableColumnMetadata.sqlTypeName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import java.sql.SQLException
* @return the corresponding [DbType].
* @throws RuntimeException if the url is null.
*/
public fun extractDBTypeFromURL(url: String?): DbType {
public fun extractDBTypeFromUrl(url: String?): DbType {
if (url != null) {
return when {
H2.dbTypeInJdbcUrl in url -> H2
Expand All @@ -24,3 +24,14 @@ public fun extractDBTypeFromURL(url: String?): DbType {
throw SQLException("Database URL could not be null. The existing value is $url")
}
}

/**
* Retrieves the driver class name from the given JDBC URL.
*
* @param [url] The JDBC URL to extract the driver class name from.
* @return The driver class name as a [String].
*/
public fun driverClassNameFromUrl(url: String): String {
val dbType = extractDBTypeFromUrl(url)
return dbType.driverClassName
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
import org.jetbrains.kotlinx.dataframe.io.db.DbType
import org.jetbrains.kotlinx.dataframe.io.db.H2
import org.jetbrains.kotlinx.dataframe.io.db.Sqlite
import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromURL
import org.jetbrains.kotlinx.dataframe.io.db.extractDBTypeFromUrl
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema

private val logger = KotlinLogging.logger {}
Expand Down Expand Up @@ -115,7 +113,7 @@ public fun DataFrame.Companion.readSqlTable(connection: Connection, tableName: S
if (limit > 0) preparedQuery += " LIMIT $limit"

val url = connection.metaData.url
val dbType = extractDBTypeFromURL(url)
val dbType = extractDBTypeFromUrl(url)

connection.createStatement().use { st ->
logger.debug { "Connection with url:${url} is established successfully." }
Expand Down Expand Up @@ -182,7 +180,7 @@ public fun DataFrame.Companion.readSqlQuery(connection: Connection, sqlQuery: St
*/
public fun DataFrame.Companion.readSqlQuery(connection: Connection, sqlQuery: String, limit: Int): AnyFrame {
val url = connection.metaData.url
val dbType = extractDBTypeFromURL(url)
val dbType = extractDBTypeFromUrl(url)

var internalSqlQuery = sqlQuery
if (limit > 0) internalSqlQuery += " LIMIT $limit"
Expand Down Expand Up @@ -247,7 +245,7 @@ public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, connection: C
*/
public fun DataFrame.Companion.readResultSet(resultSet: ResultSet, connection: Connection, limit: Int): AnyFrame {
val url = connection.metaData.url
val dbType = extractDBTypeFromURL(url)
val dbType = extractDBTypeFromUrl(url)

return readResultSet(resultSet, dbType, limit)
}
Expand Down Expand Up @@ -300,7 +298,7 @@ public fun DataFrame.Companion.readAllSqlTables(connection: Connection): List<An
public fun DataFrame.Companion.readAllSqlTables(connection: Connection, limit: Int): List<AnyFrame> {
val metaData = connection.metaData
val url = connection.metaData.url
val dbType = extractDBTypeFromURL(url)
val dbType = extractDBTypeFromUrl(url)

// exclude a system and other tables without data, but it looks like it supported badly for many databases
val tables = metaData.getTables(null, null, null, arrayOf("TABLE"))
Expand Down Expand Up @@ -348,7 +346,7 @@ public fun DataFrame.Companion.getSchemaForSqlTable(
tableName: String
): DataFrameSchema {
val url = connection.metaData.url
val dbType = extractDBTypeFromURL(url)
val dbType = extractDBTypeFromUrl(url)

connection.createStatement().use {
logger.debug { "Connection with url:${connection.metaData.url} is established successfully." }
Expand Down Expand Up @@ -383,7 +381,7 @@ public fun DataFrame.Companion.getSchemaForSqlQuery(dbConfig: DatabaseConfigurat
*/
public fun DataFrame.Companion.getSchemaForSqlQuery(connection: Connection, sqlQuery: String): DataFrameSchema {
val url = connection.metaData.url
val dbType = extractDBTypeFromURL(url)
val dbType = extractDBTypeFromUrl(url)

connection.createStatement().use { st ->
st.executeQuery(sqlQuery).use { rs ->
Expand Down Expand Up @@ -419,7 +417,7 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, dbTyp
*/
public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, connection: Connection): DataFrameSchema {
val url = connection.metaData.url
val dbType = extractDBTypeFromURL(url)
val dbType = extractDBTypeFromUrl(url)

val tableColumns = getTableColumnsMetadata(resultSet)
return buildSchemaByTableColumns(tableColumns, dbType)
Expand All @@ -446,7 +444,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfig
public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): List<DataFrameSchema> {
val metaData = connection.metaData
val url = connection.metaData.url
val dbType = extractDBTypeFromURL(url)
val dbType = extractDBTypeFromUrl(url)

val tableTypes = arrayOf("TABLE")
// exclude system and other tables without data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import org.jetbrains.kotlinx.dataframe.impl.codeGen.toStandaloneSnippet
import org.jetbrains.kotlinx.dataframe.impl.codeGen.urlCodeGenReader
import org.jetbrains.kotlinx.dataframe.impl.codeGen.urlDfReader
import java.io.File
import java.lang.RuntimeException
import java.net.URL
import java.nio.file.Paths
import java.sql.Connection
import java.sql.DriverManager
import org.jetbrains.kotlinx.dataframe.io.ArrowFeather
import org.jetbrains.kotlinx.dataframe.io.CSV
Expand All @@ -30,6 +32,7 @@ import org.jetbrains.kotlinx.dataframe.io.TSV
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlQuery
import org.jetbrains.kotlinx.dataframe.io.getSchemaForSqlTable
import org.jetbrains.kotlinx.dataframe.io.isURL
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema

abstract class GenerateDataSchemaTask : DefaultTask() {

Expand Down Expand Up @@ -84,9 +87,7 @@ abstract class GenerateDataSchemaTask : DefaultTask() {
if (rawUrl.startsWith("jdbc")) {
val connection = DriverManager.getConnection(rawUrl, jdbcOptions.user, jdbcOptions.password)
connection.use {
val schema = if(jdbcOptions.sqlQuery.isBlank())
DataFrame.getSchemaForSqlTable(connection, interfaceName.get())
else DataFrame.getSchemaForSqlQuery(connection, jdbcOptions.sqlQuery)
val schema = generateSchemaByJdbcOptions(jdbcOptions, connection)

val codeGenerator = CodeGenerator.create(useFqNames = false)

Expand Down Expand Up @@ -174,6 +175,23 @@ abstract class GenerateDataSchemaTask : DefaultTask() {
}
}

private fun generateSchemaByJdbcOptions(
jdbcOptions: JdbcOptionsDsl,
connection: Connection
): DataFrameSchema {
logger.debug("Table name: ${jdbcOptions.tableName}")
logger.debug("SQL query: ${jdbcOptions.sqlQuery}")

return if (jdbcOptions.tableName.isNotBlank())
DataFrame.getSchemaForSqlTable(connection, jdbcOptions.tableName)
else if(jdbcOptions.sqlQuery.isNotBlank())
DataFrame.getSchemaForSqlQuery(connection, jdbcOptions.sqlQuery)
else throw RuntimeException("Table name: ${jdbcOptions.tableName}, " +
"SQL query: ${jdbcOptions.sqlQuery} both are empty! " +
"Populate 'tableName' or 'sqlQuery' in jdbcOptions with value to generate schema " +
"for SQL table or result of SQL query!")
}

private fun stringOf(data: Any): String =
when (data) {
is File -> data.absolutePath
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,6 @@ data class JsonOptionsDsl(
data class JdbcOptionsDsl(
var user: String = "", // TODO: I'm not sure about the default parameters
var password: String = "", // TODO: I'm not sure about the default parameters
var tableName: String = "",
var sqlQuery: String = ""
) : Serializable
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,18 @@ import org.jetbrains.kotlinx.dataframe.impl.codeGen.urlCodeGenReader
import org.jetbrains.kotlinx.dataframe.impl.codeGen.urlDfReader
import org.jetbrains.kotlinx.dataframe.io.*
import java.io.File
import java.lang.RuntimeException
import java.net.MalformedURLException
import java.net.URL
import java.sql.Connection
import java.sql.DriverManager
import org.jetbrains.kotlinx.dataframe.io.db.H2
import org.jetbrains.kotlinx.dataframe.io.db.MariaDb
import org.jetbrains.kotlinx.dataframe.io.db.MySql
import org.jetbrains.kotlinx.dataframe.io.db.PostgreSql
import org.jetbrains.kotlinx.dataframe.io.db.Sqlite
import org.jetbrains.kotlinx.dataframe.io.db.driverClassNameFromUrl
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema

@OptIn(KspExperimental::class)
class DataSchemaGenerator(
Expand Down Expand Up @@ -159,7 +168,8 @@ class DataSchemaGenerator(
if (importStatement.isJdbc) {
val url = importStatement.dataSource.pathRepresentation

if(url.contains("h2")) Class.forName("org.h2.Driver")
// Force classloading
Class.forName(driverClassNameFromUrl(url))

val connection = DriverManager.getConnection(
url,
Expand All @@ -168,9 +178,7 @@ class DataSchemaGenerator(
)

connection.use {
val schema = if(importStatement.jdbcOptions.sqlQuery.isBlank())
DataFrame.getSchemaForSqlTable(connection, importStatement.name)
else DataFrame.getSchemaForSqlQuery(connection, importStatement.jdbcOptions.sqlQuery)
val schema = generateSchemaForImport(importStatement, connection)

val codeGenerator = CodeGenerator.create(useFqNames = false)

Expand Down Expand Up @@ -257,5 +265,22 @@ class DataSchemaGenerator(
it.write(code)
}
}

private fun generateSchemaForImport(
importStatement: ImportDataSchemaStatement,
connection: Connection
): DataFrameSchema {
logger.info("Table name: ${importStatement.jdbcOptions.tableName}")
logger.info("SQL query: ${importStatement.jdbcOptions.sqlQuery}")

return if (importStatement.jdbcOptions.tableName.isNotBlank())
DataFrame.getSchemaForSqlTable(connection, importStatement.jdbcOptions.tableName)
else if(importStatement.jdbcOptions.sqlQuery.isNotBlank())
DataFrame.getSchemaForSqlQuery(connection, importStatement.jdbcOptions.sqlQuery)
else throw RuntimeException("Table name: ${importStatement.jdbcOptions.tableName}, " +
"SQL query: ${importStatement.jdbcOptions.sqlQuery} both are empty! " +
"Populate 'tableName' or 'sqlQuery' in jdbcOptions with value to generate schema " +
"for SQL table or result of SQL query!")
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class DataFrameJdbcSymbolProcessorTest {
package test

import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema
import org.jetbrains.kotlinx.dataframe.annotations.JdbcOptions
import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.cast
Expand All @@ -126,11 +127,16 @@ class DataFrameJdbcSymbolProcessorTest {
SourceFile.kotlin(
"MySources.kt",
"""
@file:ImportDataSchema(name = "Customer", path = "$CONNECTION_URL")

@file:ImportDataSchema(
"Customer",
"$CONNECTION_URL",
jdbcOptions = JdbcOptions("", "", tableName = "Customer")
)

package test

import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema
import org.jetbrains.kotlinx.dataframe.annotations.JdbcOptions
import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.cast
Expand Down Expand Up @@ -161,11 +167,16 @@ class DataFrameJdbcSymbolProcessorTest {
SourceFile.kotlin(
"MySources.kt",
"""
@file:ImportDataSchema(name = "Customer", path = "$CONNECTION_URL")
@file:ImportDataSchema(
"Customer",
"$CONNECTION_URL",
jdbcOptions = JdbcOptions("", "", tableName = "Customer")
)

package test

import org.jetbrains.kotlinx.dataframe.annotations.ImportDataSchema
import org.jetbrains.kotlinx.dataframe.annotations.JdbcOptions
import org.jetbrains.kotlinx.dataframe.api.filter
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.cast
Expand Down