Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate Poko function declarations in FIR #465

Merged
merged 12 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,24 @@ jobs:
-Dorg.gradle.project.pokoTests.jvmToolchainVersion=${{ matrix.poko_tests_jvm_toolchain_version }}
-Dorg.gradle.project.pokoTests.compileMode=WITHOUT_K2

test-fir-generation:
runs-on: ubuntu-latest
steps:
- name: Check out the repo
uses: actions/checkout@v4

- name: Install JDK ${{ matrix.poko_tests_jvm_toolchain_version }}
uses: actions/setup-java@v4
with:
distribution: zulu
java-version: 22

- name: Set up Gradle
uses: gradle/actions/setup-gradle@v4

- name: Test
run: ./gradlew :poko-tests:build --stacktrace -Dorg.gradle.project.pokoTests.compileMode=FIR_GENERATION_ENABLED

build-sample:
runs-on: ubuntu-latest
needs: build
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ public class PokoCompilerPluginRegistrar : CompilerPluginRegistrar() {

override val supportsK2: Boolean get() = true

private val knownPokoPluginArgs = emptySet<String>()
private val firDeclarationGenerationPluginArg =
"poko.experimental.enableFirDeclarationGeneration"
private val knownPokoPluginArgs = setOf(
firDeclarationGenerationPluginArg,
)

override fun ExtensionStorage.registerExtensions(configuration: CompilerConfiguration) {
if (!configuration.get(CompilerOptions.ENABLED, DEFAULT_POKO_ENABLED))
Expand All @@ -45,12 +49,29 @@ public class PokoCompilerPluginRegistrar : CompilerPluginRegistrar() {
}
}

val firDeclarationGenerationPluginValue = pokoPluginArgs[firDeclarationGenerationPluginArg]
val firDeclarationGenerationEnabled =
firDeclarationGenerationPluginValue?.toBoolean()?.also {
messageCollector.report(
severity = CompilerMessageSeverity.WARNING,
message = "<$firDeclarationGenerationPluginArg> resolved to $it. " +
"This experimental flag may disappear at any time.",
)
} ?: false

IrGenerationExtension.registerExtension(
PokoIrGenerationExtension(pokoAnnotationClassId, messageCollector)
PokoIrGenerationExtension(
pokoAnnotationName = pokoAnnotationClassId,
firDeclarationGeneration = firDeclarationGenerationEnabled,
messageCollector = messageCollector,
)
)

FirExtensionRegistrarAdapter.registerExtension(
PokoFirExtensionRegistrar(pokoAnnotationClassId)
PokoFirExtensionRegistrar(
pokoAnnotation = pokoAnnotationClassId,
declarationGeneration = firDeclarationGenerationEnabled,
)
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package dev.drewhamilton.poko

import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.util.OperatorNameConventions

/**
* Exhaustive representation of all functions Poko generates.
*/
internal enum class PokoFunction(
val functionName: Name,
) {
Equals(OperatorNameConventions.EQUALS),
HashCode(OperatorNameConventions.HASH_CODE),
ToString(OperatorNameConventions.TO_STRING),
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package dev.drewhamilton.poko.fir

import dev.drewhamilton.poko.PokoFunction
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction
import org.jetbrains.kotlin.fir.extensions.FirDeclarationGenerationExtension
import org.jetbrains.kotlin.fir.extensions.FirDeclarationPredicateRegistrar
import org.jetbrains.kotlin.fir.extensions.MemberGenerationContext
import org.jetbrains.kotlin.fir.extensions.predicate.LookupPredicate
import org.jetbrains.kotlin.fir.extensions.predicateBasedProvider
import org.jetbrains.kotlin.fir.plugin.createMemberFunction
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.isExtension
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.utils.addToStdlib.runIf

internal class PokoFirDeclarationGenerationExtension(
session: FirSession,
) : FirDeclarationGenerationExtension(session) {
private val pokoAnnotation by lazy {
session.pokoFirExtensionSessionComponent.pokoAnnotation
}

private val pokoAnnotationPredicate by lazy {
LookupPredicate.create {
annotated(pokoAnnotation.asSingleFqName())
}
}

/**
* Pairs of <Poko.Builder ClassId, outer class Symbol>.
*/
private val pokoClasses by lazy {
session.predicateBasedProvider.getSymbolsByPredicate(pokoAnnotationPredicate)
.filterIsInstance<FirRegularClassSymbol>()
}

override fun FirDeclarationPredicateRegistrar.registerPredicates() {
register(pokoAnnotationPredicate)
}

override fun getCallableNamesForClass(
classSymbol: FirClassSymbol<*>,
context: MemberGenerationContext,
): Set<Name> = when {
classSymbol in pokoClasses -> PokoFunction.entries.map { it.functionName }.toSet()
else -> emptySet()
}

override fun generateFunctions(
callableId: CallableId,
context: MemberGenerationContext?
): List<FirNamedFunctionSymbol> {
val owner = context?.owner ?: return emptyList()

val callableName = callableId.callableName
val function = when (callableName) {
PokoFunction.Equals.functionName -> runIf(!owner.hasDeclaredEqualsFunction()) {
createEqualsFunction(owner)
}

PokoFunction.HashCode.functionName -> runIf(!owner.hasDeclaredHashCodeFunction()) {
createHashCodeFunction(owner)
}

PokoFunction.ToString.functionName -> runIf(!owner.hasDeclaredToStringFunction()) {
createToStringFunction(owner)
}

else -> null
}
return function?.let { listOf(it.symbol) } ?: emptyList()
}

//region equals
private fun FirClassSymbol<*>.hasDeclaredEqualsFunction(): Boolean {
return declarationSymbols
.filterIsInstance<FirNamedFunctionSymbol>()
.any {
!it.isExtension &&
it.name == PokoFunction.Equals.functionName &&
it.valueParameterSymbols.size == 1 &&
it.valueParameterSymbols
.single()
.resolvedReturnType == session.builtinTypes.nullableAnyType.coneType
}
}

private fun createEqualsFunction(
owner: FirClassSymbol<*>,
): FirSimpleFunction = createMemberFunction(
owner = owner,
key = PokoKey,
name = PokoFunction.Equals.functionName,
returnType = session.builtinTypes.booleanType.coneType,
) {
valueParameter(
name = Name.identifier("other"),
type = session.builtinTypes.nullableAnyType.coneType,
key = PokoKey,
)
}
//endregion

//region hashCode
private fun FirClassSymbol<*>.hasDeclaredHashCodeFunction(): Boolean {
return declarationSymbols
.filterIsInstance<FirNamedFunctionSymbol>()
.any {
!it.isExtension &&
it.name == PokoFunction.HashCode.functionName &&
it.valueParameterSymbols.isEmpty()
}
}

private fun createHashCodeFunction(
owner: FirClassSymbol<*>,
): FirSimpleFunction = createMemberFunction(
owner = owner,
key = PokoKey,
name = PokoFunction.HashCode.functionName,
returnType = session.builtinTypes.intType.coneType,
)
//endregion

//region toString
private fun FirClassSymbol<*>.hasDeclaredToStringFunction(): Boolean {
return declarationSymbols
.filterIsInstance<FirNamedFunctionSymbol>()
.any {
!it.isExtension &&
it.name == PokoFunction.ToString.functionName &&
it.valueParameterSymbols.isEmpty()
}
}

private fun createToStringFunction(
owner: FirClassSymbol<*>,
): FirSimpleFunction = createMemberFunction(
owner = owner,
key = PokoKey,
name = PokoFunction.ToString.functionName,
returnType = session.builtinTypes.stringType.coneType,
)
//endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ import org.jetbrains.kotlin.name.ClassId

internal class PokoFirExtensionRegistrar(
private val pokoAnnotation: ClassId,
private val declarationGeneration: Boolean,
) : FirExtensionRegistrar() {
override fun ExtensionRegistrarContext.configurePlugin() {
+PokoFirExtensionSessionComponent.getFactory(pokoAnnotation)
+::PokoFirCheckersExtension
if (declarationGeneration) {
+::PokoFirDeclarationGenerationExtension
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package dev.drewhamilton.poko.fir

import org.jetbrains.kotlin.GeneratedDeclarationKey

internal object PokoKey : GeneratedDeclarationKey() {
override fun toString() = "FirPoko"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package dev.drewhamilton.poko.ir

import dev.drewhamilton.poko.PokoFunction
import dev.drewhamilton.poko.fir.PokoKey
import org.jetbrains.kotlin.GeneratedDeclarationKey
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
import org.jetbrains.kotlin.backend.common.lower.DeclarationIrBuilder
import org.jetbrains.kotlin.cli.common.messages.MessageCollector
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.builders.irBlockBody
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin.GeneratedByPlugin
import org.jetbrains.kotlin.ir.declarations.IrFile
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
import org.jetbrains.kotlin.ir.util.isEquals
import org.jetbrains.kotlin.ir.util.isHashCode
import org.jetbrains.kotlin.ir.util.isToString
import org.jetbrains.kotlin.ir.util.parentAsClass
import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
import org.jetbrains.kotlin.name.ClassId

@OptIn(UnsafeDuringIrConstructionAPI::class)
internal class PokoFunctionBodyFiller(
private val pokoAnnotation: ClassId,
private val context: IrPluginContext,
private val messageCollector: MessageCollector,
) : IrElementVisitorVoid {

override fun visitSimpleFunction(declaration: IrSimpleFunction) {
val origin = declaration.origin
if (origin !is GeneratedByPlugin || !interestedIn(origin.pluginKey)) {
return
}

require(declaration.body == null)

val pokoFunction = when {
declaration.isEquals() -> PokoFunction.Equals
declaration.isHashCode() -> PokoFunction.HashCode
declaration.isToString() -> PokoFunction.ToString
else -> return
}

val pokoClass = declaration.parentAsClass
val pokoProperties = pokoClass.pokoProperties(pokoAnnotation).also {
if (it.isEmpty()) {
messageCollector.log("No primary constructor properties")
messageCollector.reportErrorOnClass(
irClass = pokoClass,
message = "Poko class primary constructor must have at least one not-skipped property",
)
}
}

declaration.body = DeclarationIrBuilder(
generatorContext = context,
symbol = declaration.symbol,
).irBlockBody {
when (pokoFunction) {
PokoFunction.Equals -> generateEqualsMethodBody(
pokoAnnotation = pokoAnnotation,
context = this@PokoFunctionBodyFiller.context,
irClass = pokoClass,
functionDeclaration = declaration,
classProperties = pokoProperties,
messageCollector = messageCollector,
)

PokoFunction.HashCode -> generateHashCodeMethodBody(
pokoAnnotation = pokoAnnotation,
context = this@PokoFunctionBodyFiller.context,
functionDeclaration = declaration,
classProperties = pokoProperties,
messageCollector = messageCollector,
)

PokoFunction.ToString -> generateToStringMethodBody(
pokoAnnotation = pokoAnnotation,
context = this@PokoFunctionBodyFiller.context,
irClass = pokoClass,
functionDeclaration = declaration,
classProperties = pokoProperties,
messageCollector = messageCollector,
)
}
}
}

private fun interestedIn(
key: GeneratedDeclarationKey?,
): Boolean {
return key == PokoKey
}

override fun visitElement(element: IrElement) {
when (element) {
is IrDeclaration, is IrFile, is IrModuleFragment -> element.acceptChildrenVoid(this)
else -> Unit
}
}
}
Loading
Loading