Skip to content

Adding contracts for AnyCol.isValueColumn etc. for smart-casting #882

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,23 +1,39 @@
@file:OptIn(ExperimentalContracts::class)

package org.jetbrains.kotlinx.dataframe.api

import org.jetbrains.kotlinx.dataframe.AnyCol
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import org.jetbrains.kotlinx.dataframe.impl.isNothing
import org.jetbrains.kotlinx.dataframe.impl.projectTo
import org.jetbrains.kotlinx.dataframe.type
import org.jetbrains.kotlinx.dataframe.typeClass
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.KTypeProjection
import kotlin.reflect.full.isSubclassOf
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.typeOf

// NOTE(review): this one-line form and the block-bodied form below declare the same signature.
// This looks like the pre-change side of a diff whose +/- markers were lost — confirm before compiling.
public fun AnyCol.isColumnGroup(): Boolean = kind() == ColumnKind.Group
/**
 * Returns `true` when this column's [kind] is [ColumnKind.Group].
 *
 * The contract tells the compiler that a `true` result implies the receiver is a
 * [ColumnGroup], so callers can use the column as one without an explicit cast.
 */
public fun AnyCol.isColumnGroup(): Boolean {
// Contract: returns(true) implies the receiver is ColumnGroup<*>.
contract { returns(true) implies (this@isColumnGroup is ColumnGroup<*>) }
return kind() == ColumnKind.Group
}

// NOTE(review): this one-line form and the block-bodied form below declare the same signature.
// This looks like the pre-change side of a diff whose +/- markers were lost — confirm before compiling.
public fun AnyCol.isFrameColumn(): Boolean = kind() == ColumnKind.Frame
/**
 * Returns `true` when this column's [kind] is [ColumnKind.Frame].
 *
 * The contract tells the compiler that a `true` result implies the receiver is a
 * [FrameColumn], so callers can use the column as one without an explicit cast.
 */
public fun AnyCol.isFrameColumn(): Boolean {
// Contract: returns(true) implies the receiver is FrameColumn<*>.
contract { returns(true) implies (this@isFrameColumn is FrameColumn<*>) }
return kind() == ColumnKind.Frame
}

// NOTE(review): this one-line form and the block-bodied form below declare the same signature.
// This looks like the pre-change side of a diff whose +/- markers were lost — confirm before compiling.
public fun AnyCol.isValueColumn(): Boolean = kind() == ColumnKind.Value
/**
 * Returns `true` when this column's [kind] is [ColumnKind.Value].
 *
 * The contract tells the compiler that a `true` result implies the receiver is a
 * [ValueColumn], so callers can use the column as one without an explicit cast.
 */
public fun AnyCol.isValueColumn(): Boolean {
// Contract: returns(true) implies the receiver is ValueColumn<*>.
contract { returns(true) implies (this@isValueColumn is ValueColumn<*>) }
return kind() == ColumnKind.Value
}

/**
 * Checks whether this column's element type can be treated as [type].
 *
 * The result is `true` only when the column's [KType] is a subtype of [type] and,
 * in addition, the column's type is not marked nullable unless [type] is marked
 * nullable as well.
 *
 * @param type the type to test this column's type against.
 * @return `true` if both conditions above hold, `false` otherwise.
 */
public fun AnyCol.isSubtypeOf(type: KType): Boolean {
    val columnType = this.type
    // Guard: without a subtype relation the nullability check is irrelevant.
    if (!columnType.isSubtypeOf(type)) return false
    // A nullable column type is accepted only by a nullable target type.
    return !columnType.isMarkedNullable || type.isMarkedNullable
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ public interface ColGroupsColumnsSelectionDsl {
@Suppress("UNCHECKED_CAST")
internal fun ColumnsResolver<*>.columnGroupsInternal(
filter: (ColumnGroup<*>) -> Boolean,
): TransformableColumnSet<AnyRow> =
colsInternal { it.isColumnGroup() && filter(it.asColumnGroup()) } as TransformableColumnSet<AnyRow>
): TransformableColumnSet<AnyRow> = colsInternal { it.isColumnGroup() && filter(it) } as TransformableColumnSet<AnyRow>

// endregion
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,7 @@ internal fun <C, R> SingleColumn<DataRow<C>>.selectInternal(selector: ColumnsSel
"Column ${col.path} is not a ColumnGroup and can thus not be selected from."
}

col.asColumnGroup()
.getColumnsWithPaths(selector as ColumnsSelector<*, R>)
col.getColumnsWithPaths(selector as ColumnsSelector<*, R>)
.map { it.changePath(col.path + it.path) }
} ?: emptyList()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,6 @@ public interface ValueColsColumnsSelectionDsl {
* @return A [TransformableColumnSet] containing the value columns that satisfy the filter.
*/
internal fun ColumnsResolver<*>.valueColumnsInternal(filter: (ValueColumn<*>) -> Boolean): TransformableColumnSet<*> =
colsInternal { it.isValueColumn() && filter(it.asValueColumn()) }
colsInternal { it.isValueColumn() && filter(it) }

// endregion
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import org.jetbrains.kotlinx.dataframe.ColumnSelector
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.asDataColumn
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
Expand Down Expand Up @@ -78,7 +77,9 @@ internal open class DataFrameReceiver<T>(
DataColumn.createColumnGroup("", df).addPath(emptyPath())

override fun columns() =
df.columns().map { if (it.isColumnGroup()) ColumnGroupWithParent(null, it.asColumnGroup()) else it }
df.columns().map {
if (it.isColumnGroup()) ColumnGroupWithParent(null, it) else it
}

override fun columnNames() = df.columnNames()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import org.jetbrains.kotlinx.dataframe.api.Infer
import org.jetbrains.kotlinx.dataframe.api.add
import org.jetbrains.kotlinx.dataframe.api.all
import org.jetbrains.kotlinx.dataframe.api.allNulls
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.concat
import org.jetbrains.kotlinx.dataframe.api.convertTo
import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
import org.jetbrains.kotlinx.dataframe.api.isEmpty
import org.jetbrains.kotlinx.dataframe.api.map
import org.jetbrains.kotlinx.dataframe.api.name
Expand Down Expand Up @@ -197,14 +197,12 @@ internal fun AnyFrame.convertToImpl(

else -> originalColumn
}
require(column.kind == ColumnKind.Group) {
require(column.isColumnGroup()) {
"Column `${column.name}` is ${column.kind} and can not be converted to `ColumnGroup`"
}
val columnGroup = column.asColumnGroup()

DataColumn.createColumnGroup(
name = column.name(),
df = columnGroup.convertToSchema(
df = column.convertToSchema(
schema = (targetSchema as ColumnSchema.Group).schema,
path = columnPath,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import org.jetbrains.kotlinx.dataframe.api.JoinDsl
import org.jetbrains.kotlinx.dataframe.api.JoinType
import org.jetbrains.kotlinx.dataframe.api.allowLeftNulls
import org.jetbrains.kotlinx.dataframe.api.allowRightNulls
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.getColumnsWithPaths
import org.jetbrains.kotlinx.dataframe.api.indices
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
Expand Down Expand Up @@ -86,9 +85,11 @@ internal fun <A, B> DataFrame<A>.joinImpl(
val leftCol = leftJoinColumns[i]
val rightCol = rightJoinColumns[i]
if (leftCol.isColumnGroup() && rightCol.isColumnGroup()) {
val leftColumns = getColumnsWithPaths { leftCol.asColumnGroup().colsAtAnyDepth { !it.isColumnGroup() } }
val leftColumns = getColumnsWithPaths {
leftCol.colsAtAnyDepth { !it.isColumnGroup() }
}
val rightColumns = other.getColumnsWithPaths {
rightCol.asColumnGroup().colsAtAnyDepth { !it.isColumnGroup() }
rightCol.colsAtAnyDepth { !it.isColumnGroup() }
}

val leftPrefixLength = leftCol.path.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,7 @@ internal fun <T> DataFrame<T>.parseImpl(options: ParserOptions?, columns: Column
it.isFrameColumn() -> it.cast<AnyFrame?>().parse(options)

it.isColumnGroup() ->
it.asColumnGroup()
.parse(options) { all() }
it.parse(options) { all() }
.asColumnGroup(it.name())
.asDataColumn()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.jetbrains.kotlinx.dataframe.impl.columns

import org.jetbrains.kotlinx.dataframe.AnyRow
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
import org.jetbrains.kotlinx.dataframe.api.toPath
Expand Down Expand Up @@ -29,7 +28,7 @@ internal class ColumnAccessorImpl<T>(val path: ColumnPath) : ColumnAccessor<T> {
"Column '${path.subList(0, i + 1).joinToString(".")}' is not a column group.",
)
} else {
col.asColumnGroup()
col
}
}
// resolve the last column of the path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,10 @@ internal fun Iterable<ColumnWithPath<*>>.flattenRecursively(): List<ColumnWithPa
cols.forEach {
result.add(it)
val path = it.path
if (it.data.isColumnGroup()) {
val data = it.data
if (data.isColumnGroup()) {
flattenRecursively(
it.data.asColumnGroup()
.columns()
.map { it.addPath(path + it.name()) },
data.columns().map { it.addPath(path + it.name()) },
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.jetbrains.kotlinx.dataframe.api.ReplaceClause
import org.jetbrains.kotlinx.dataframe.api.Split
import org.jetbrains.kotlinx.dataframe.api.SplitWithTransform
import org.jetbrains.kotlinx.dataframe.api.Update
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.asDataFrame
import org.jetbrains.kotlinx.dataframe.api.columnsCount
import org.jetbrains.kotlinx.dataframe.api.isColumnGroup
Expand Down Expand Up @@ -163,7 +162,7 @@ internal class Integration(private val notebook: Notebook, private val options:
codeGen: ReplCodeGenerator,
): VariableName? =
if (col.isColumnGroup()) {
val codeWithConverter = codeGen.process(col.asColumnGroup().asDataFrame(), property).let { c ->
val codeWithConverter = codeGen.process(col.asDataFrame(), property).let { c ->
CodeWithConverter(c.declarations) { c.converter("$it.asColumnGroup()") }
}
execute(
Expand Down
Loading