Skip to content

Ported the fix for JDBC integration from the 0.12.1 branch #538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ This table shows the mapping between main library component versions and minimum
| 0.11.0 | 8 | 1.8.20 | 0.11.0-358 | 3.0.0 | 11.0.0 |
| 0.11.1 | 8 | 1.8.20 | 0.11.0-358 | 3.0.0 | 11.0.0 |
| 0.12.0 | 8 | 1.9.0 | 0.11.0-358 | 3.0.0 | 11.0.0 |
| 0.12.1 | 8 | 1.9.0 | 0.11.0-358 | 3.0.0 | 11.0.0 |

## Usage example

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ internal fun <T, C> RenameClause<T, C>.renameImpl(newNames: Array<out String>):
internal fun <T, C> RenameClause<T, C>.renameImpl(transform: (ColumnWithPath<C>) -> String): DataFrame<T> {
// get all selected columns and their paths
val selectedColumnsWithPath = df.getColumnsWithPaths(columns)
.associateBy { it.data }
.associateBy { it.path }
// gather a tree of all columns where the nodes will be renamed
val tree = df.getColumnsWithPaths { all().rec() }.collectTree()

// perform rename in nodes
tree.allChildrenNotNull().forEach { node ->
tree.allChildrenNotNull().map { it to it.pathFromRoot() }.forEach { (node, originalPath) ->
// Check if the current node/column is a selected column and, if so, get its ColumnWithPath
val column = selectedColumnsWithPath[node.data] ?: return@forEach
val column = selectedColumnsWithPath[originalPath] ?: return@forEach
// Use the found selected ColumnWithPath to query for the new name
val newColumnName = transform(column)
node.name = newColumnName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import kotlin.reflect.KType

/**
* The `DbType` class represents a database type used for reading dataframe from the database.
Expand All @@ -22,19 +23,10 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) {
*/
public abstract val driverClassName: String

/**
* Converts the data from the given [ResultSet] into the specified [TableColumnMetadata] type.
*
* @param rs The [ResultSet] containing the data to be converted.
* @param tableColumnMetadata The [TableColumnMetadata] representing the target type of the conversion.
* @return The converted data as an instance of [Any].
*/
public abstract fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any?

/**
* Returns a [ColumnSchema] produced from [tableColumnMetadata].
*/
public abstract fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema
public abstract fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema?

/**
* Checks if the given table name is a system table for the specified database type.
Expand All @@ -52,4 +44,12 @@ public abstract class DbType(public val dbTypeInJdbcUrl: String) {
* @return the TableMetadata object representing the table metadata.
*/
public abstract fun buildTableMetadata(tables: ResultSet): TableMetadata

/**
* Converts SQL data type to a Kotlin data type.
*
* @param [tableColumnMetadata] The metadata of the table column.
* @return The corresponding Kotlin data type, or null if no mapping is found.
*/
public abstract fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType?
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import java.util.Locale
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import kotlin.reflect.typeOf
import kotlin.reflect.KType

/**
* Represents the H2 database type.
Expand All @@ -21,71 +19,8 @@ public object H2 : DbType("h2") {
override val driverClassName: String
get() = "org.h2.Driver"

override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? {
val name = tableColumnMetadata.name
return when (tableColumnMetadata.sqlTypeName) {
"CHARACTER", "CHAR" -> rs.getString(name)
"CHARACTER VARYING", "CHAR VARYING", "VARCHAR" -> rs.getString(name)
"CHARACTER LARGE OBJECT", "CHAR LARGE OBJECT", "CLOB" -> rs.getString(name)
"MEDIUMTEXT" -> rs.getString(name)
"VARCHAR_IGNORECASE" -> rs.getString(name)
"BINARY" -> rs.getBytes(name)
"BINARY VARYING", "VARBINARY" -> rs.getBytes(name)
"BINARY LARGE OBJECT", "BLOB" -> rs.getBytes(name)
"BOOLEAN" -> rs.getBoolean(name)
"TINYINT" -> rs.getByte(name)
"SMALLINT" -> rs.getShort(name)
"INTEGER", "INT" -> rs.getInt(name)
"BIGINT" -> rs.getLong(name)
"NUMERIC", "DECIMAL", "DEC" -> rs.getFloat(name) // not a BigDecimal
"REAL", "FLOAT" -> rs.getFloat(name)
"DOUBLE PRECISION" -> rs.getDouble(name)
"DECFLOAT" -> rs.getDouble(name)
"DATE" -> rs.getDate(name).toString()
"TIME" -> rs.getTime(name).toString()
"TIME WITH TIME ZONE" -> rs.getTime(name).toString()
"TIMESTAMP" -> rs.getTimestamp(name).toString()
"TIMESTAMP WITH TIME ZONE" -> rs.getTimestamp(name).toString()
"INTERVAL" -> rs.getObject(name).toString()
"JAVA_OBJECT" -> rs.getObject(name)
"ENUM" -> rs.getString(name)
"JSON" -> rs.getString(name) // TODO: https://github.com/Kotlin/dataframe/issues/462
"UUID" -> rs.getString(name)
else -> throw IllegalArgumentException("Unsupported H2 type: ${tableColumnMetadata.sqlTypeName}")
}
}

override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema {
return when (tableColumnMetadata.sqlTypeName) {
"CHARACTER", "CHAR" -> ColumnSchema.Value(typeOf<String>())
"CHARACTER VARYING", "CHAR VARYING", "VARCHAR" -> ColumnSchema.Value(typeOf<String>())
"CHARACTER LARGE OBJECT", "CHAR LARGE OBJECT", "CLOB" -> ColumnSchema.Value(typeOf<String>())
"MEDIUMTEXT" -> ColumnSchema.Value(typeOf<String>())
"VARCHAR_IGNORECASE" -> ColumnSchema.Value(typeOf<String>())
"BINARY" -> ColumnSchema.Value(typeOf<ByteArray>())
"BINARY VARYING", "VARBINARY" -> ColumnSchema.Value(typeOf<ByteArray>())
"BINARY LARGE OBJECT", "BLOB" -> ColumnSchema.Value(typeOf<ByteArray>())
"BOOLEAN" -> ColumnSchema.Value(typeOf<Boolean>())
"TINYINT" -> ColumnSchema.Value(typeOf<Byte>())
"SMALLINT" -> ColumnSchema.Value(typeOf<Short>())
"INTEGER", "INT" -> ColumnSchema.Value(typeOf<Int>())
"BIGINT" -> ColumnSchema.Value(typeOf<Long>())
"NUMERIC", "DECIMAL", "DEC" -> ColumnSchema.Value(typeOf<Float>())
"REAL", "FLOAT" -> ColumnSchema.Value(typeOf<Float>())
"DOUBLE PRECISION" -> ColumnSchema.Value(typeOf<Double>())
"DECFLOAT" -> ColumnSchema.Value(typeOf<Double>())
"DATE" -> ColumnSchema.Value(typeOf<String>())
"TIME" -> ColumnSchema.Value(typeOf<String>())
"TIME WITH TIME ZONE" -> ColumnSchema.Value(typeOf<String>())
"TIMESTAMP" -> ColumnSchema.Value(typeOf<String>())
"TIMESTAMP WITH TIME ZONE" -> ColumnSchema.Value(typeOf<String>())
"INTERVAL" -> ColumnSchema.Value(typeOf<String>())
"JAVA_OBJECT" -> ColumnSchema.Value(typeOf<Any>())
"ENUM" -> ColumnSchema.Value(typeOf<String>())
"JSON" -> ColumnSchema.Value(typeOf<String>()) // TODO: https://github.com/Kotlin/dataframe/issues/462
"UUID" -> ColumnSchema.Value(typeOf<String>())
else -> throw IllegalArgumentException("Unsupported H2 type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}")
}
override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
return null
}

override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
Expand All @@ -99,4 +34,8 @@ public object H2 : DbType("h2") {
tables.getString("TABLE_SCHEM"),
tables.getString("TABLE_CAT"))
}

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
return null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import kotlin.reflect.typeOf
import kotlin.reflect.KType

/**
* Represents the MariaDb database type.
Expand All @@ -16,73 +16,8 @@ public object MariaDb : DbType("mariadb") {
override val driverClassName: String
get() = "org.mariadb.jdbc.Driver"

override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? {
val name = tableColumnMetadata.name
return when (tableColumnMetadata.sqlTypeName) {
"BIT" -> rs.getBytes(name)
"TINYINT" -> rs.getInt(name)
"SMALLINT" -> rs.getInt(name)
"MEDIUMINT"-> rs.getInt(name)
"MEDIUMINT UNSIGNED" -> rs.getLong(name)
"INTEGER", "INT" -> rs.getInt(name)
"INTEGER UNSIGNED", "INT UNSIGNED" -> rs.getLong(name)
"BIGINT" -> rs.getLong(name)
"FLOAT" -> rs.getFloat(name)
"DOUBLE" -> rs.getDouble(name)
"DECIMAL" -> rs.getBigDecimal(name)
"DATE" -> rs.getDate(name).toString()
"DATETIME" -> rs.getTimestamp(name).toString()
"TIMESTAMP" -> rs.getTimestamp(name).toString()
"TIME"-> rs.getTime(name).toString()
"YEAR" -> rs.getDate(name).toString()
"VARCHAR", "CHAR" -> rs.getString(name)
"BINARY" -> rs.getBytes(name)
"VARBINARY" -> rs.getBytes(name)
"TINYBLOB"-> rs.getBytes(name)
"BLOB"-> rs.getBytes(name)
"MEDIUMBLOB" -> rs.getBytes(name)
"LONGBLOB" -> rs.getBytes(name)
"TEXT" -> rs.getString(name)
"MEDIUMTEXT" -> rs.getString(name)
"LONGTEXT" -> rs.getString(name)
"ENUM" -> rs.getString(name)
"SET" -> rs.getString(name)
else -> throw IllegalArgumentException("Unsupported MariaDB type: ${tableColumnMetadata.sqlTypeName}")
}
}

override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema {
return when (tableColumnMetadata.sqlTypeName) {
"BIT" -> ColumnSchema.Value(typeOf<ByteArray>())
"TINYINT" -> ColumnSchema.Value(typeOf<Int>())
"SMALLINT" -> ColumnSchema.Value(typeOf<Int>())
"MEDIUMINT"-> ColumnSchema.Value(typeOf<Int>())
"MEDIUMINT UNSIGNED" -> ColumnSchema.Value(typeOf<Long>())
"INTEGER", "INT" -> ColumnSchema.Value(typeOf<Int>())
"INTEGER UNSIGNED", "INT UNSIGNED" -> ColumnSchema.Value(typeOf<Long>())
"BIGINT" -> ColumnSchema.Value(typeOf<Long>())
"FLOAT" -> ColumnSchema.Value(typeOf<Float>())
"DOUBLE" -> ColumnSchema.Value(typeOf<Double>())
"DECIMAL" -> ColumnSchema.Value(typeOf<Double>())
"DATE" -> ColumnSchema.Value(typeOf<String>())
"DATETIME" -> ColumnSchema.Value(typeOf<String>())
"TIMESTAMP" -> ColumnSchema.Value(typeOf<String>())
"TIME"-> ColumnSchema.Value(typeOf<String>())
"YEAR" -> ColumnSchema.Value(typeOf<String>())
"VARCHAR", "CHAR" -> ColumnSchema.Value(typeOf<String>())
"BINARY" -> ColumnSchema.Value(typeOf<ByteArray>())
"VARBINARY" -> ColumnSchema.Value(typeOf<ByteArray>())
"TINYBLOB"-> ColumnSchema.Value(typeOf<ByteArray>())
"BLOB"-> ColumnSchema.Value(typeOf<ByteArray>())
"MEDIUMBLOB" -> ColumnSchema.Value(typeOf<ByteArray>())
"LONGBLOB" -> ColumnSchema.Value(typeOf<ByteArray>())
"TEXT" -> ColumnSchema.Value(typeOf<String>())
"MEDIUMTEXT" -> ColumnSchema.Value(typeOf<String>())
"LONGTEXT" -> ColumnSchema.Value(typeOf<String>())
"ENUM" -> ColumnSchema.Value(typeOf<String>())
"SET" -> ColumnSchema.Value(typeOf<String>())
else -> throw IllegalArgumentException("Unsupported MariaDB type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}")
}
override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
return null
}

override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
Expand All @@ -95,4 +30,8 @@ public object MariaDb : DbType("mariadb") {
tables.getString("table_schem"),
tables.getString("table_cat"))
}

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
return null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import org.jetbrains.kotlinx.dataframe.io.TableColumnMetadata
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import java.sql.ResultSet
import java.util.Locale
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.io.TableMetadata
import kotlin.reflect.typeOf
import kotlin.reflect.KType

/**
* Represents the MySql database type.
Expand All @@ -19,79 +17,8 @@ public object MySql : DbType("mysql") {
override val driverClassName: String
get() = "com.mysql.jdbc.Driver"

override fun convertDataFromResultSet(rs: ResultSet, tableColumnMetadata: TableColumnMetadata): Any? {
val name = tableColumnMetadata.name
return when (tableColumnMetadata.sqlTypeName) {
"BIT" -> rs.getBytes(name)
"TINYINT" -> rs.getInt(name)
"SMALLINT" -> rs.getInt(name)
"MEDIUMINT"-> rs.getInt(name)
"MEDIUMINT UNSIGNED" -> rs.getLong(name)
"INTEGER", "INT" -> rs.getInt(name)
"INTEGER UNSIGNED", "INT UNSIGNED" -> rs.getLong(name)
"BIGINT" -> rs.getLong(name)
"FLOAT" -> rs.getFloat(name)
"DOUBLE" -> rs.getDouble(name)
"DECIMAL" -> rs.getBigDecimal(name)
"DATE" -> rs.getDate(name).toString()
"DATETIME" -> rs.getTimestamp(name).toString()
"TIMESTAMP" -> rs.getTimestamp(name).toString()
"TIME"-> rs.getTime(name).toString()
"YEAR" -> rs.getDate(name).toString()
"VARCHAR", "CHAR" -> rs.getString(name)
"BINARY" -> rs.getBytes(name)
"VARBINARY" -> rs.getBytes(name)
"TINYBLOB"-> rs.getBytes(name)
"BLOB"-> rs.getBytes(name)
"MEDIUMBLOB" -> rs.getBytes(name)
"LONGBLOB" -> rs.getBytes(name)
"TEXT" -> rs.getString(name)
"MEDIUMTEXT" -> rs.getString(name)
"LONGTEXT" -> rs.getString(name)
"ENUM" -> rs.getString(name)
"SET" -> rs.getString(name)
// special mysql types
"JSON" -> rs.getString(name) // TODO: https://github.com/Kotlin/dataframe/issues/462
"GEOMETRY" -> rs.getBytes(name)
else -> throw IllegalArgumentException("Unsupported MySQL type: ${tableColumnMetadata.sqlTypeName}")
}
}

override fun toColumnSchema(tableColumnMetadata: TableColumnMetadata): ColumnSchema {
return when (tableColumnMetadata.sqlTypeName) {
"BIT" -> ColumnSchema.Value(typeOf<ByteArray>())
"TINYINT" -> ColumnSchema.Value(typeOf<Int>())
"SMALLINT" -> ColumnSchema.Value(typeOf<Int>())
"MEDIUMINT"-> ColumnSchema.Value(typeOf<Int>())
"MEDIUMINT UNSIGNED" -> ColumnSchema.Value(typeOf<Long>())
"INTEGER", "INT" -> ColumnSchema.Value(typeOf<Int>())
"INTEGER UNSIGNED", "INT UNSIGNED" -> ColumnSchema.Value(typeOf<Long>())
"BIGINT" -> ColumnSchema.Value(typeOf<Long>())
"FLOAT" -> ColumnSchema.Value(typeOf<Float>())
"DOUBLE" -> ColumnSchema.Value(typeOf<Double>())
"DECIMAL" -> ColumnSchema.Value(typeOf<Double>())
"DATE" -> ColumnSchema.Value(typeOf<String>())
"DATETIME" -> ColumnSchema.Value(typeOf<String>())
"TIMESTAMP" -> ColumnSchema.Value(typeOf<String>())
"TIME"-> ColumnSchema.Value(typeOf<String>())
"YEAR" -> ColumnSchema.Value(typeOf<String>())
"VARCHAR", "CHAR" -> ColumnSchema.Value(typeOf<String>())
"BINARY" -> ColumnSchema.Value(typeOf<ByteArray>())
"VARBINARY" -> ColumnSchema.Value(typeOf<ByteArray>())
"TINYBLOB"-> ColumnSchema.Value(typeOf<ByteArray>())
"BLOB"-> ColumnSchema.Value(typeOf<ByteArray>())
"MEDIUMBLOB" -> ColumnSchema.Value(typeOf<ByteArray>())
"LONGBLOB" -> ColumnSchema.Value(typeOf<ByteArray>())
"TEXT" -> ColumnSchema.Value(typeOf<String>())
"MEDIUMTEXT" -> ColumnSchema.Value(typeOf<String>())
"LONGTEXT" -> ColumnSchema.Value(typeOf<String>())
"ENUM" -> ColumnSchema.Value(typeOf<String>())
"SET" -> ColumnSchema.Value(typeOf<String>())
// special mysql types
"JSON" -> ColumnSchema.Value(typeOf<ColumnGroup<DataRow<String>>>()) // TODO: https://github.com/Kotlin/dataframe/issues/462
"GEOMETRY" -> ColumnSchema.Value(typeOf<ByteArray>())
else -> throw IllegalArgumentException("Unsupported MySQL type: ${tableColumnMetadata.sqlTypeName} for column ${tableColumnMetadata.name}")
}
override fun convertSqlTypeToColumnSchemaValue(tableColumnMetadata: TableColumnMetadata): ColumnSchema? {
return null
}

override fun isSystemTable(tableMetadata: TableMetadata): Boolean {
Expand All @@ -116,4 +43,8 @@ public object MySql : DbType("mysql") {
tables.getString("table_schem"),
tables.getString("table_cat"))
}

override fun convertSqlTypeToKType(tableColumnMetadata: TableColumnMetadata): KType? {
return null
}
}
Loading