Skip to content

[Compiler plugin] Support dataFrameOf(Pair<String, List<T>) #908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,8 @@ public inline fun <reified C> dataFrameOf(vararg header: String, fill: (String)

public fun dataFrameOf(header: Iterable<String>): DataFrameBuilder = DataFrameBuilder(header.asList())

@Refine
@Interpretable("DataFrameOf3")
public fun dataFrameOf(vararg columns: Pair<String, List<Any?>>): DataFrame<*> =
columns.map { it.second.toColumn(it.first, Infer.Type) }.toDataFrame()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class FunctionCallTransformer(
val tokenFir = token.toClassSymbol(session)!!.fir
tokenFir.callShapeData = CallShapeData.RefinedType(dataSchemaApis.map { it.scope.symbol })

return buildLetCall(call, originalSymbol, dataSchemaApis, listOf(tokenFir))
return buildScopeFunctionCall(call, originalSymbol, dataSchemaApis, listOf(tokenFir))
}
}

Expand Down Expand Up @@ -253,7 +253,7 @@ class FunctionCallTransformer(
val keyToken = groupMarker.toClassSymbol(session)!!.fir
keyToken.callShapeData = CallShapeData.RefinedType(groupApis.map { it.scope.symbol })

return buildLetCall(call, originalSymbol, keyApis + groupApis, additionalDeclarations = listOf(groupToken, keyToken))
return buildScopeFunctionCall(call, originalSymbol, keyApis + groupApis, additionalDeclarations = listOf(groupToken, keyToken))
}
}

Expand Down Expand Up @@ -305,18 +305,17 @@ class FunctionCallTransformer(
private fun Name.asTokenName() = identifierOrNullIfSpecial?.titleCase() ?: DEFAULT_NAME

@OptIn(SymbolInternals::class)
private fun buildLetCall(
private fun buildScopeFunctionCall(
call: FirFunctionCall,
originalSymbol: FirNamedFunctionSymbol,
dataSchemaApis: List<DataSchemaApi>,
additionalDeclarations: List<FirClass>
): FirFunctionCall {

val explicitReceiver = call.explicitReceiver ?: return call
val receiverType = explicitReceiver.resolvedType
val explicitReceiver = call.explicitReceiver
val receiverType = explicitReceiver?.resolvedType
val returnType = call.resolvedType
val resolvedLet = findLet()
val parameter = resolvedLet.valueParameterSymbols[0]
val scopeFunction = if (explicitReceiver != null) findLet() else findRun()

// original call is inserted later
call.transformCalleeReference(object : FirTransformer<Nothing?>() {
Expand Down Expand Up @@ -350,20 +349,23 @@ class FunctionCallTransformer(
returnTypeRef = buildResolvedTypeRef {
type = returnType
}
val itName = Name.identifier("it")
val parameterSymbol = FirValueParameterSymbol(itName)
valueParameters += buildValueParameter {
moduleData = session.moduleData
origin = FirDeclarationOrigin.Source
returnTypeRef = buildResolvedTypeRef {
type = receiverType
val parameterSymbol = receiverType?.let {
val itName = Name.identifier("it")
val parameterSymbol = FirValueParameterSymbol(itName)
valueParameters += buildValueParameter {
moduleData = session.moduleData
origin = FirDeclarationOrigin.Source
returnTypeRef = buildResolvedTypeRef {
type = receiverType
}
this.name = itName
this.symbol = parameterSymbol
containingFunctionSymbol = fSymbol
isCrossinline = false
isNoinline = false
isVararg = false
}
this.name = itName
this.symbol = parameterSymbol
containingFunctionSymbol = fSymbol
isCrossinline = false
isNoinline = false
isVararg = false
parameterSymbol
}
body = buildBlock {
this.coneTypeOrNull = returnType
Expand All @@ -375,20 +377,23 @@ class FunctionCallTransformer(
statements += additionalDeclarations

statements += buildReturnExpression {
val itPropertyAccess = buildPropertyAccessExpression {
coneTypeOrNull = receiverType
calleeReference = buildResolvedNamedReference {
name = itName
resolvedSymbol = parameterSymbol
if (parameterSymbol != null) {
val itPropertyAccess = buildPropertyAccessExpression {
coneTypeOrNull = receiverType
calleeReference = buildResolvedNamedReference {
name = parameterSymbol.name
resolvedSymbol = parameterSymbol
}
}
if (callDispatchReceiver != null) {
call.replaceDispatchReceiver(itPropertyAccess)
}
call.replaceExplicitReceiver(itPropertyAccess)
if (callExtensionReceiver != null) {
call.replaceExtensionReceiver(itPropertyAccess)
}
}
if (callDispatchReceiver != null) {
call.replaceDispatchReceiver(itPropertyAccess)
}
call.replaceExplicitReceiver(itPropertyAccess)
if (callExtensionReceiver != null) {
call.replaceExtensionReceiver(itPropertyAccess)
}

result = call
this.target = target
}
Expand All @@ -397,11 +402,19 @@ class FunctionCallTransformer(
isLambda = true
hasExplicitParameterList = false
typeRef = buildResolvedTypeRef {
type = ConeClassLikeTypeImpl(
ConeClassLikeLookupTagImpl(ClassId(FqName("kotlin"), Name.identifier("Function1"))),
typeArguments = arrayOf(receiverType, returnType),
isNullable = false
)
type = if (receiverType != null) {
ConeClassLikeTypeImpl(
ConeClassLikeLookupTagImpl(ClassId(FqName("kotlin"), Name.identifier("Function1"))),
typeArguments = arrayOf(receiverType, returnType),
isNullable = false
)
} else {
ConeClassLikeTypeImpl(
ConeClassLikeLookupTagImpl(ClassId(FqName("kotlin"), Name.identifier("Function0"))),
typeArguments = arrayOf(returnType),
isNullable = false
)
}
}
invocationKind = EventOccurrencesRange.EXACTLY_ONCE
inlineStatus = InlineStatus.Inline
Expand All @@ -413,11 +426,13 @@ class FunctionCallTransformer(
val newCall1 = buildFunctionCall {
source = call.source
this.coneTypeOrNull = returnType
typeArguments += buildTypeProjectionWithVariance {
typeRef = buildResolvedTypeRef {
type = receiverType
if (receiverType != null) {
typeArguments += buildTypeProjectionWithVariance {
typeRef = buildResolvedTypeRef {
type = receiverType
}
variance = Variance.INVARIANT
}
variance = Variance.INVARIANT
}

typeArguments += buildTypeProjectionWithVariance {
Expand All @@ -429,11 +444,14 @@ class FunctionCallTransformer(
dispatchReceiver = null
this.explicitReceiver = callExplicitReceiver
extensionReceiver = callExtensionReceiver ?: callDispatchReceiver
argumentList = buildResolvedArgumentList(original = null, linkedMapOf(argument to parameter.fir))
argumentList = buildResolvedArgumentList(
original = null,
linkedMapOf(argument to scopeFunction.valueParameterSymbols[0].fir)
)
calleeReference = buildResolvedNamedReference {
source = call.calleeReference.source
this.name = Name.identifier("let")
resolvedSymbol = resolvedLet
this.name = scopeFunction.name
resolvedSymbol = scopeFunction
}
}
return newCall1
Expand Down Expand Up @@ -565,5 +583,9 @@ class FunctionCallTransformer(
return session.symbolProvider.getTopLevelFunctionSymbols(FqName("kotlin"), Name.identifier("let")).single()
}

private fun findRun(): FirFunctionSymbol<*> {
return session.symbolProvider.getTopLevelFunctionSymbols(FqName("kotlin"), Name.identifier("run")).single { it.typeParameterSymbols.size == 1 }
}

private fun String.titleCase() = replaceFirstChar { it.uppercaseChar() }
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
@file:Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER")
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlin.fir.expressions.FirExpression
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
import org.jetbrains.kotlin.fir.expressions.FirVarargArgumentsExpression
import org.jetbrains.kotlin.fir.types.commonSuperTypeOrNull
import org.jetbrains.kotlin.fir.types.resolvedType
import org.jetbrains.kotlin.fir.types.type
import org.jetbrains.kotlin.fir.types.typeContext
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
Expand Down Expand Up @@ -36,3 +39,18 @@ class DataFrameBuilderInvoke0 : AbstractSchemaModificationInterpreter() {
return PluginDataFrameSchema(columns)
}
}

class DataFrameOf3 : AbstractSchemaModificationInterpreter() {
val Arguments.columns: List<Interpreter.Success<Pair<*, *>>> by arg()

override fun Arguments.interpret(): PluginDataFrameSchema {
val res = columns.map {
val it = it.value
val name = (it.first as? FirLiteralExpression)?.value as? String
val type = (it.second as? FirExpression)?.resolvedType?.typeArguments?.getOrNull(0)?.type
if (name == null || type == null) return PluginDataFrameSchema(emptyList())
simpleColumnOf(name, type)
}
return PluginDataFrameSchema(res)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.Interpreter

class PairConstructor : AbstractInterpreter<Pair<*, *>>() {
val Arguments.receiver: Any? by arg(lens = Interpreter.Id)
val Arguments.that: Any? by arg(lens = Interpreter.Id)
override fun Arguments.interpret(): Pair<*, *> {
return receiver to that
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ fun <T> KotlinTypeFacade.interpret(
is FirCallableReferenceAccess -> {
toKPropertyApproximation(it, session)
}

is FirFunctionCall -> {
it.loadInterpreter()?.let { processor ->
interpret(it, processor, emptyMap(), reporter)
}
}
else -> null
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,20 @@ import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirGetClassCall
import org.jetbrains.kotlin.fir.expressions.FirLiteralExpression
import org.jetbrains.kotlin.fir.expressions.FirResolvedQualifier
import org.jetbrains.kotlin.fir.expressions.UnresolvedExpressionTypeAccess
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.resolved
import org.jetbrains.kotlin.fir.references.symbol
import org.jetbrains.kotlin.fir.references.toResolvedNamedFunctionSymbol
import org.jetbrains.kotlin.fir.resolve.fqName
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import org.jetbrains.kotlin.fir.types.classId
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddDslStringInvoke
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.AddId
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Aggregate
Expand All @@ -76,12 +83,14 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Move0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.PairConstructor
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ReadExcel
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameColumn
Expand All @@ -91,8 +100,16 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToDataFrameFrom
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ToTop
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Update0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.UpdateWith0
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names

@OptIn(UnresolvedExpressionTypeAccess::class)
internal fun FirFunctionCall.loadInterpreter(session: FirSession): Interpreter<*>? {
if (
calleeReference.toResolvedNamedFunctionSymbol()?.callableId == Names.TO &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so.. does this mean it fails for dataFrameOf(Pair("name", listOf(1, 2, 3)))?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's probably fine, most people will use to anyway, but it might be something we should mention in a doc somewhere, for that one person who's going to be very confused as to why their Pair is not recognized by the plugin if the function only accepts pairs as the type argument.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Let's support Pair after i put in place more generic way for handling stdlib functions other than "if". And indeed with such things it's much easier to run into something plugin won't understand. One of the things that can be addressed with custom Checker that will highlight it as warning

coneTypeOrNull?.classId == Names.PAIR
) {
return PairConstructor()
}
val symbol =
(calleeReference as? FirResolvedNamedReference)?.resolvedSymbol as? FirCallableSymbol ?: return null
val argName = Name.identifier("interpreter")
Expand Down Expand Up @@ -208,6 +225,7 @@ internal inline fun <reified T> String.load(): T {
"ToTop" -> ToTop()
"Update0" -> Update0()
"Aggregate" -> Aggregate()
"DataFrameOf3" -> DataFrameOf3()
else -> error("$this")
} as T
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.jetbrains.kotlinx.dataframe.plugin.utils

import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
Expand Down Expand Up @@ -50,6 +51,9 @@ object Names {
val LOCAL_DATE_CLASS_ID = kotlinx.datetime.LocalDate::class.classId()
val LOCAL_DATE_TIME_CLASS_ID = kotlinx.datetime.LocalDateTime::class.classId()
val INSTANT_CLASS_ID = kotlinx.datetime.Instant::class.classId()

val PAIR = ClassId(FqName("kotlin"), Name.identifier("Pair"))
val TO = CallableId(FqName("kotlin"), Name.identifier("to"))
}

private fun KClass<*>.classId(): ClassId {
Expand Down
14 changes: 14 additions & 0 deletions plugins/kotlin-dataframe/testData/box/dataFrameOf_to.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = dataFrameOf(
"a" to listOf(1, 2),
"b" to listOf("str1", "str2"),
)
val i: Int = df.a[0]
val str: String = df.b[0]
return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ public void testDataFrameOf() {
runTest("testData/box/dataFrameOf.kt");
}

@Test
@TestMetadata("dataFrameOf_to.kt")
public void testDataFrameOf_to() {
runTest("testData/box/dataFrameOf_to.kt");
}

@Test
@TestMetadata("dataFrameOf_vararg.kt")
public void testDataFrameOf_vararg() {
Expand Down
Loading