Skip to content

[Compiler plugin] Setup call transformer pipeline to handle (...) -> DataRow functions #918

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ public fun DataFrame.Companion.readJsonStr(
* @param header Optional list of column names. If given, [text] will be read like an object with [header] being the keys.
* @return [DataRow] from the given [text].
*/
@Refine
@Interpretable("DataRowReadJsonStr")
public fun DataRow.Companion.readJsonStr(
@Language("json") text: String,
header: List<String> = emptyList(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class FunctionCallTransformer(
fun transformOrNull(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall?
}

private val transformers = listOf(GroupByCallTransformer(), DataFrameCallTransformer())
private val transformers = listOf(GroupByCallTransformer(), DataFrameCallTransformer(), DataRowCallTransformer())

override fun intercept(callInfo: CallInfo, symbol: FirNamedFunctionSymbol): CallReturnType? {
val callSiteAnnotations = (callInfo.callSite as? FirAnnotationContainer)?.annotations ?: emptyList()
Expand Down Expand Up @@ -156,14 +156,14 @@ class FunctionCallTransformer(
?: call
}

inner class DataFrameCallTransformer : CallTransformer {
inner class DataSchemaLikeCallTransformer(val classId: ClassId) : CallTransformer {
override fun interceptOrNull(callInfo: CallInfo, symbol: FirNamedFunctionSymbol, hash: String): CallReturnType? {
if (symbol.resolvedReturnType.fullyExpandedClassId(session) != Names.DF_CLASS_ID) return null
// possibly null if explicit receiver type is AnyFrame
if (symbol.resolvedReturnType.fullyExpandedClassId(session) != classId) return null
// possibly null if explicit receiver type is typealias
val argument = (callInfo.explicitReceiver?.resolvedType)?.typeArguments?.getOrNull(0)
val newDataFrameArgument = buildNewTypeArgument(argument, callInfo.name, hash)

val lookupTag = ConeClassLikeLookupTagImpl(Names.DF_CLASS_ID)
val lookupTag = ConeClassLikeLookupTagImpl(classId)
val typeRef = buildResolvedTypeRef {
type = ConeClassLikeTypeImpl(
lookupTag,
Expand All @@ -182,7 +182,7 @@ class FunctionCallTransformer(

@OptIn(SymbolInternals::class)
override fun transformOrNull(call: FirFunctionCall, originalSymbol: FirNamedFunctionSymbol): FirFunctionCall? {
val callResult = analyzeRefinedCallShape<PluginDataFrameSchema>(call, Names.DF_CLASS_ID, InterpretationErrorReporter.DEFAULT)
val callResult = analyzeRefinedCallShape<PluginDataFrameSchema>(call, classId, InterpretationErrorReporter.DEFAULT)
val (tokens, dataFrameSchema) = callResult ?: return null
val token = tokens[0]
val firstSchema = token.toClassSymbol(session)?.resolvedSuperTypes?.get(0)!!.toRegularClassSymbol(session)?.fir!!
Expand All @@ -195,6 +195,10 @@ class FunctionCallTransformer(
}
}

/**
 * [CallTransformer] for functions whose return type expands to [Names.DF_CLASS_ID].
 * Every member is delegated to a [DataSchemaLikeCallTransformer] created with that class id.
 */
inner class DataFrameCallTransformer : CallTransformer by DataSchemaLikeCallTransformer(Names.DF_CLASS_ID)

/**
 * [CallTransformer] for functions whose return type expands to [Names.DATA_ROW_CLASS_ID].
 * Every member is delegated to a [DataSchemaLikeCallTransformer] created with that class id.
 */
inner class DataRowCallTransformer : CallTransformer by DataSchemaLikeCallTransformer(Names.DATA_ROW_CLASS_ID)

inner class GroupByCallTransformer : CallTransformer {
override fun interceptOrNull(
callInfo: CallInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ReturnTypeBasedReceiverInjector(session: FirSession) : FirExpressionResolu
@OptIn(SymbolInternals::class)
override fun addNewImplicitReceivers(functionCall: FirFunctionCall): List<ConeKotlinType> {
val callReturnType = functionCall.resolvedType
return if (callReturnType.classId in setOf(Names.DF_CLASS_ID, Names.GROUP_BY_CLASS_ID)) {
return if (callReturnType.classId in setOf(Names.DF_CLASS_ID, Names.GROUP_BY_CLASS_ID, Names.DATA_ROW_CLASS_ID)) {
val typeArguments = callReturnType.typeArguments
typeArguments
.mapNotNull {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api
import kotlinx.serialization.decodeFromString
import kotlinx.serialization.json.Json
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
Expand Down Expand Up @@ -118,6 +119,15 @@ internal class ReadJsonStr : AbstractInterpreter<PluginDataFrameSchema>() {
}
}

/**
 * Interpreter selected by the `"DataRowReadJsonStr"` key.
 *
 * Parses the `text` argument with [DataRow.readJsonStr] and converts the schema of the
 * resulting row into a [PluginDataFrameSchema].
 */
internal class DataRowReadJsonStr : AbstractInterpreter<PluginDataFrameSchema>() {
    /** JSON source text taken from the interpreted call. */
    val Arguments.text: String by arg()

    /** Type clash tactic forwarded to the reader; declared with [ARRAY_AND_VALUE_COLUMNS] as its default value. */
    val Arguments.typeClashTactic: JSON.TypeClashTactic by arg(defaultValue = Present(ARRAY_AND_VALUE_COLUMNS))

    override fun Arguments.interpret(): PluginDataFrameSchema {
        // Read the row first, then derive the plugin-side schema from the runtime schema.
        val parsedRow = DataRow.readJsonStr(text, typeClashTactic = typeClashTactic)
        val rowSchema = parsedRow.schema()
        return rowSchema.toPluginDataFrameSchema()
    }
}

internal class ReadExcel : AbstractSchemaModificationInterpreter() {
val Arguments.fileOrUrl: String by arg()
val Arguments.sheetName: String? by arg(defaultValue = Present(null))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColsOf1
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataRowReadJsonStr
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault
Expand Down Expand Up @@ -218,6 +219,7 @@ internal inline fun <reified T> String.load(): T {
"DataFrameGroupBy" -> DataFrameGroupBy()
"GroupByInto" -> GroupByInto()
"ReadJsonStr" -> ReadJsonStr()
"DataRowReadJsonStr" -> DataRowReadJsonStr()
"ReadDelimStr" -> ReadDelimStr()
"GroupByToDataFrame" -> GroupByToDataFrame()
"ToDataFrameFrom0" -> ToDataFrameFrom()
Expand Down
13 changes: 13 additions & 0 deletions plugins/kotlin-dataframe/testData/box/readJsonStr_datarow.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

// Compile-time constant JSON object with a string field "a" and an integer field "b";
// it is the input passed to DataRow.readJsonStr in box().
const val text = """{"a":"abc", "b":1}"""

fun box(): String {
val row = DataRow.readJsonStr(text)
row.a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be so awesome for an ordinary Map :o I wonder if we could have a plugin which just does this for Maps. It's like js/python but safe :P

row.b
return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,12 @@ public void testReadJsonStr_const() {
runTest("testData/box/readJsonStr_const.kt");
}

/** Runs the test data file {@code testData/box/readJsonStr_datarow.kt}. */
@Test
@TestMetadata("readJsonStr_datarow.kt")
public void testReadJsonStr_datarow() {
runTest("testData/box/readJsonStr_datarow.kt");
}

@Test
@TestMetadata("readJsonStr_localProperty.kt")
public void testReadJsonStr_localProperty() {
Expand Down
Loading