Skip to content

Add a constructor to create a nested dataframe from columns inplace #1144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/api/core.api
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/ConstructorsKt {
public static final fun columnGroupTyped (Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;Lorg/jetbrains/kotlinx/dataframe/columns/ColumnPath;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnAccessor;
public static final fun columnOf (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun columnOf (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/columns/FrameColumn;
public static final fun columnOf ([Lkotlin/Pair;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnGroup;
public static final fun columnOf ([Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/columns/FrameColumn;
public static final fun columnOf ([Lorg/jetbrains/kotlinx/dataframe/columns/BaseColumn;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
public static final fun dataFrameOf (Ljava/lang/Iterable;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
Expand All @@ -1397,6 +1398,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/ConstructorsKt {
public static final fun dataFrameOf ([Lkotlin/Pair;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun dataFrameOf ([Lorg/jetbrains/kotlinx/dataframe/columns/BaseColumn;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun dataFrameOf ([Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)Lorg/jetbrains/kotlinx/dataframe/api/DataFrameBuilder;
public static final fun dataFrameOfColumns ([Lkotlin/Pair;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun emptyDataFrame ()Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun frameColumn ()Lorg/jetbrains/kotlinx/dataframe/api/ColumnDelegate;
public static final fun frameColumn (Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/columns/ColumnAccessor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.ColumnAccessor
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.ColumnPath
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
Expand Down Expand Up @@ -269,6 +270,15 @@ public inline fun <reified T> column(values: Iterable<T>): DataColumn<T> =
allColsMakesColGroup = true,
).forceResolve()

@Refine
@Interpretable("ColumnOfPairs")
public fun columnOf(vararg columns: Pair<String, AnyBaseCol>): ColumnGroup<*> =
dataFrameOf(
columns.map { (name, col) ->
col.rename(name)
},
).asColumnGroup()

// endregion

// region create DataFrame
Expand All @@ -290,6 +300,12 @@ public fun dataFrameOf(columns: Iterable<AnyBaseCol>): DataFrame<*> {
return DataFrameImpl<Unit>(cols, nrow)
}

@Refine
@JvmName("dataFrameOfColumns")
@Interpretable("DataFrameOfPairs")
public fun dataFrameOf(vararg columns: Pair<String, AnyBaseCol>): DataFrame<*> =
dataFrameOf(columns.map { (name, col) -> col.rename(name) })

public fun dataFrameOf(vararg header: ColumnReference<*>): DataFrameBuilder = DataFrameBuilder(header.map { it.name() })

public fun dataFrameOf(vararg columns: AnyBaseCol): DataFrame<*> = dataFrameOf(columns.asIterable())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,21 @@ class Create : TestBase() {
// SampleEnd
}

@Test
@TransformDataFrameExpressions
fun createNestedDataFrameInplace() {
// SampleStart
// DataFrame with 2 columns and 3 rows
val df = dataFrameOf(
"name" to columnOf(
"firstName" to columnOf("Alice", "Bob", "Charlie"),
"lastName" to columnOf("Cooper", "Dylan", "Daniels"),
),
"age" to columnOf(15, 20, 100),
)
// SampleEnd
}

@Test
@TransformDataFrameExpressions
fun createDataFrameWithFill() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1934,6 +1934,21 @@ class DataFrameTests : BaseTest() {
df.columns().forEach { col -> col.forEachIndexed { row, value -> value shouldBe row + 1 } }
}

@Test
fun `create nested dataframe inplace`() {
val df = dataFrameOf(
"a" to columnOf("1"),
"b" to columnOf(
"c" to columnOf("2"),
),
"d" to columnOf(dataFrameOf("a")(123)),
"gr" to listOf("1").toDataFrame().asColumnGroup(),
)

df.columnNames() shouldBe listOf("a", "b", "d", "gr")
df.getColumnGroup("gr")["value"].values() shouldBe listOf("1")
}

@Test
fun `get typed column by name`() {
val col = df.getColumn("name").cast<String>()
Expand Down
17 changes: 17 additions & 0 deletions docs/StardustDocs/topics/createDataFrame.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ val df = dataFrameOf(

<!---END-->

Create DataFrame with nested columns inplace:

<!---FUN createNestedDataFrameInplace-->

```kotlin
// DataFrame with 2 columns and 3 rows
val df = dataFrameOf(
"name" to columnOf(
"firstName" to columnOf("Alice", "Bob", "Charlie"),
"lastName" to columnOf("Cooper", "Dylan", "Daniels"),
),
"age" to columnOf(15, 20, 100),
)
```

<!---END-->

<!---FUN createDataFrameFromColumns-->

```kotlin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,13 @@ class FunctionCallTransformer(
fun transformOrNull(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall?
}

private val transformers = listOf(GroupByCallTransformer(), DataFrameCallTransformer(), DataRowCallTransformer())
// also update [ReturnTypeBasedReceiverInjector.SCHEMA_TYPES]
private val transformers = listOf(
GroupByCallTransformer(),
DataFrameCallTransformer(),
DataRowCallTransformer(),
ColumnGroupCallTransformer(),
)

override fun intercept(callInfo: CallInfo, symbol: FirNamedFunctionSymbol): CallReturnType? {
val callSiteAnnotations = (callInfo.callSite as? FirAnnotationContainer)?.annotations ?: emptyList()
Expand Down Expand Up @@ -194,6 +200,8 @@ class FunctionCallTransformer(

inner class DataRowCallTransformer : CallTransformer by DataSchemaLikeCallTransformer(Names.DATA_ROW_CLASS_ID)

inner class ColumnGroupCallTransformer : CallTransformer by DataSchemaLikeCallTransformer(Names.COLUM_GROUP_CLASS_ID)

inner class GroupByCallTransformer : CallTransformer {
override fun interceptOrNull(
callInfo: CallInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@ import org.jetbrains.kotlin.fir.types.toRegularClassSymbol
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names

class ReturnTypeBasedReceiverInjector(session: FirSession) : FirExpressionResolutionExtension(session) {
companion object {
private val SCHEMA_TYPES = setOf(
Names.DF_CLASS_ID,
Names.GROUP_BY_CLASS_ID,
Names.DATA_ROW_CLASS_ID,
Names.COLUM_GROUP_CLASS_ID,
)
}

@OptIn(SymbolInternals::class)
override fun addNewImplicitReceivers(functionCall: FirFunctionCall): List<ConeKotlinType> {
val callReturnType = functionCall.resolvedType
return if (callReturnType.classId in setOf(Names.DF_CLASS_ID, Names.GROUP_BY_CLASS_ID, Names.DATA_ROW_CLASS_ID)) {
return if (callReturnType.classId in SCHEMA_TYPES) {
val typeArguments = callReturnType.typeArguments
typeArguments
.mapNotNull {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api
import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
import org.jetbrains.kotlin.fir.plugin.createConeType
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.fir.types.commonSuperTypeOrNull
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlin.fir.types.type
Expand All @@ -15,6 +18,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.impl.api.withValuesImpl
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names

class DataFrameOf0 : AbstractInterpreter<DataFrameBuilderApproximation>() {
val Arguments.header: List<String> by arg()
Expand Down Expand Up @@ -53,3 +57,30 @@ class DataFrameOf3 : AbstractSchemaModificationInterpreter() {
return PluginDataFrameSchema(res)
}
}

abstract class SchemaConstructor : AbstractSchemaModificationInterpreter() {
val Arguments.columns: List<Interpreter.Success<Pair<*, *>>> by arg()

override fun Arguments.interpret(): PluginDataFrameSchema {
val res = columns.map {
val it = it.value
val name = (it.first as? FirLiteralExpression)?.value as? String
val resolvedType = (it.second as? FirExpression)?.resolvedType
val type: ConeKotlinType? = when (resolvedType?.classId) {
Names.COLUM_GROUP_CLASS_ID -> Names.DATA_ROW_CLASS_ID.createConeType(session, arrayOf(resolvedType.typeArguments[0]))
Names.FRAME_COLUMN_CLASS_ID -> Names.DF_CLASS_ID.createConeType(session, arrayOf(resolvedType.typeArguments[0]))
Names.DATA_COLUMN_CLASS_ID -> resolvedType.typeArguments[0] as? ConeKotlinType
Names.BASE_COLUMN_CLASS_ID -> resolvedType.typeArguments[0] as? ConeKotlinType
else -> null
}
if (name == null || type == null) return PluginDataFrameSchema(emptyList())
simpleColumnOf(name, type)
}
return PluginDataFrameSchema(res)
}
}

class DataFrameOfPairs : SchemaConstructor()

class ColumnOfPairs : SchemaConstructor()

Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,13 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsAtAnyDepth2
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf2
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnOfPairs
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnRange
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ConcatWithKeys
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOfPairs
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameUnfold
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameXs
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Drop0
Expand Down Expand Up @@ -409,6 +411,8 @@ internal inline fun <reified T> String.load(): T {
"toDataFrameDefault" -> ToDataFrameDefault()
"ToDataFrameDslStringInvoke" -> ToDataFrameDslStringInvoke()
"DataFrameOf0" -> DataFrameOf0()
"DataFrameOfPairs" -> DataFrameOfPairs()
"ColumnOfPairs" -> ColumnOfPairs()
"DataFrameBuilderInvoke0" -> DataFrameBuilderInvoke0()
"ToDataFrameColumn" -> ToDataFrameColumn()
"FillNulls0" -> FillNulls0()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,18 @@ object Names {

val COLUM_GROUP_CLASS_ID: ClassId
get() = ClassId(FqName("org.jetbrains.kotlinx.dataframe.columns"), Name.identifier("ColumnGroup"))
val FRAME_COLUMN_CLASS_ID: ClassId
get() = ClassId(FqName("org.jetbrains.kotlinx.dataframe.columns"), Name.identifier("FrameColumn"))
val DATA_COLUMN_CLASS_ID: ClassId
get() = ClassId(
FqName.fromSegments(listOf("org", "jetbrains", "kotlinx", "dataframe")),
Name.identifier("DataColumn")
)
val BASE_COLUMN_CLASS_ID: ClassId
get() = ClassId(
FqName.fromSegments(listOf("org", "jetbrains", "kotlinx", "dataframe", "columns")),
Name.identifier("BaseColumn")
)
val COLUMNS_CONTAINER_CLASS_ID: ClassId
get() = ClassId(
FqName.fromSegments(listOf("org", "jetbrains", "kotlinx", "dataframe")),
Expand Down
15 changes: 15 additions & 0 deletions plugins/kotlin-dataframe/testData/box/columnOf_nested.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val group = columnOf(
"c" to columnOf("2"),
"d" to columnOf(123),
)
val str: DataColumn<String> = group.c
val i: DataColumn<Int> = group.d

return "OK"
}
20 changes: 20 additions & 0 deletions plugins/kotlin-dataframe/testData/box/dataFrameOf_nested.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = dataFrameOf(
"a" to columnOf("1"),
"b" to columnOf(
"c" to columnOf("2"),
),
"d" to columnOf(dataFrameOf("a")(123)),
"gr" to listOf("1").toDataFrame().asColumnGroup(),
)
val str: DataColumn<String> = df.a
val str1: DataColumn<String> = df.b.c
val i: DataColumn<Int> = df.d[0].a
val str2: DataColumn<String> = df.gr.value
return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ public void testColumnName_invalidSymbol() {
runTest("testData/box/columnName_invalidSymbol.kt");
}

@Test
@TestMetadata("columnOf_nested.kt")
public void testColumnOf_nested() {
runTest("testData/box/columnOf_nested.kt");
}

@Test
@TestMetadata("columnWithStarProjection.kt")
public void testColumnWithStarProjection() {
Expand Down Expand Up @@ -118,6 +124,12 @@ public void testDataFrameOf() {
runTest("testData/box/dataFrameOf.kt");
}

@Test
@TestMetadata("dataFrameOf_nested.kt")
public void testDataFrameOf_nested() {
runTest("testData/box/dataFrameOf_nested.kt");
}

@Test
@TestMetadata("dataFrameOf_to.kt")
public void testDataFrameOf_to() {
Expand Down
Loading