Skip to content

w1s3one805:w1s3one805-cycle-fixes #453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package me.tatarka.inject.test
import me.tatarka.inject.annotations.Component
import me.tatarka.inject.annotations.Inject
import me.tatarka.inject.annotations.Provides
import me.tatarka.inject.annotations.Scope

interface Lvl1
interface Lvl2
Expand Down Expand Up @@ -35,4 +36,58 @@ abstract class ComplexCycleComponent {
fun lvl3(impl: Lvl3Impl): Lvl3 = impl
@Provides
fun lvl4(impl: Lvl4Impl): Lvl4 = impl
}
}

@Scope
annotation class Singleton

@Singleton
@Inject
class Interceptor(val client1: Lazy<Client1>, val client2: Lazy<Client2>)

@Inject
class Client1(val interceptor: Interceptor)

@Inject
class Client2(val interceptor: Interceptor)

@Inject
class Repository(val client1: Client1, val client2: Client2)

@Singleton
@Component
interface MyComponent {
val interceptor: Interceptor
val repository: Repository
}

//class MyComponentImpl : MyComponent, ScopedComponent {
// override val _scoped: LazyMap = LazyMap()
//
// override val repository: Repository
// get() {
// val interceptor = _scoped.get("me.tatarka.inject.test.Interceptor") {
// run<Interceptor> {
// lateinit var interceptor: Interceptor
// Interceptor(
// client1 = lazy {
// Client1(
// interceptor = interceptor
// )
// },
// client2 = lazy {
// Client2(
// interceptor = interceptor
// )
// }
// ).also {
// interceptor = it
// }
// }
// }
// return Repository(
// client1 = Client1(interceptor),
// client2 = Client2(interceptor)
// )
// }
//}
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package me.tatarka.inject.compiler

import me.tatarka.kotlin.ast.AstElement
import me.tatarka.kotlin.ast.AstProvider

/**
* Builds of a stack of elements visited to see if we hit a cycle. Normally a cycle will cause a compile error. However
* the cycle can be 'broken' by delaying construction.
*/
class CycleDetector {
class CycleDetector<K, V> {

private val entries = mutableListOf<Entry<V>>()
private val resolving = mutableSetOf<K>()

private val entries = mutableListOf<Entry>()
private val resolving = mutableMapOf<TypeKey, String>()
private sealed class Entry<out V> {
class Element<V>(val value: V) : Entry<V>()
data object Delayed : Entry<Nothing>()
}

/**
* Denote that construction is being delayed. A cycle that crosses a delayed element can be resolved.
Expand All @@ -20,26 +22,25 @@ class CycleDetector {
}

/**
* Returns the variable name if [CycleResult.Resolvable] was hit for the given type key lower in the tree. This
* Returns the variable name if [LegacyCycleResult.Resolvable] was hit for the given type key lower in the tree. This
* means you should create the variable that was referenced.
*/
fun hitResolvable(key: TypeKey): String? {
fun hitResolvable(key: K): Boolean {
return resolving.remove(key)
}

/**
* Checks that the given element with the given type will produce a cycle. The result is provided in the given block
* so that then state can be rest after recursing.
*
* @see CycleResult
* @see LegacyCycleResult
*/
fun <T> check(key: TypeKey, element: AstElement, block: (CycleResult) -> T): T {
fun <T> check(key: K, element: V, block: (CycleResult<K>) -> T): T {
val lastRepeatIndex = entries.indexOfLast { it is Entry.Element && it.value == element }
val cycleResult = if (lastRepeatIndex != -1) {
if (entries.indexOfLast { it is Entry.Delayed } > lastRepeatIndex) {
val name = key.type.toVariableName()
resolving[key] = name
CycleResult.Resolvable(name)
resolving.add(key)
CycleResult.Resolvable(key)
} else {
CycleResult.Cycle
}
Expand All @@ -60,34 +61,29 @@ class CycleDetector {
/**
* Produce a trace of visited elements.
*/
fun trace(provider: AstProvider): String = entries.mapNotNull {
fun trace(toTrace: (V) -> String): String = entries.mapNotNull {
// filter only elements with a source.
when (it) {
is Entry.Element -> it.value
else -> null
}
}.reversed().joinToString(separator = "\n") { with(provider) { it.toTrace() } }

private sealed class Entry {
class Element(val value: AstElement) : Entry()
data object Delayed : Entry()
}
}.reversed().joinToString(separator = "\n", transform = toTrace)
}

sealed class CycleResult {
sealed class CycleResult<out K> {
/**
* There was no cycle, you may proceed normally.
*/
data object None : CycleResult()
data object None : CycleResult<Nothing>()

/**
* There was a cycle, you should error out.
*/
data object Cycle : CycleResult()
data object Cycle : CycleResult<Nothing>()

/**
* There was a cycle but it was across a delayed construction so it can be resolved. Reference the variable here
* with the given name and call [CycleDetector.hitResolvable] higher up the tree to create it.
*/
class Resolvable(val name: String) : CycleResult()
}
data class Resolvable<K>(val key: K) : CycleResult<K>()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package me.tatarka.inject.compiler

import me.tatarka.kotlin.ast.AstElement
import me.tatarka.kotlin.ast.AstProvider

/**
* Builds of a stack of elements visited to see if we hit a cycle. Normally a cycle will cause a compile error. However
* the cycle can be 'broken' by delaying construction.
*/
class LegacyCycleDetector {
private val detector = CycleDetector<TypeKey, AstElement>()

/**
* Denote that construction is being delayed. A cycle that crosses a delayed element can be resolved.
*/
fun delayedConstruction() {
detector.delayedConstruction()
}

/**
* Returns the variable name if [LegacyCycleResult.Resolvable] was hit for the given type key lower in the tree. This
* means you should create the variable that was referenced.
*/
fun hitResolvable(key: TypeKey): String? {
return if (detector.hitResolvable(key)) {
key.type.toVariableName()
} else {
null
}
}

/**
* Checks that the given element with the given type will produce a cycle. The result is provided in the given block
* so that then state can be rest after recursing.
*
* @see LegacyCycleResult
*/
fun <T> check(key: TypeKey, element: AstElement, block: (LegacyCycleResult) -> T): T {
return detector.check(key, element) { result ->
block(when(result) {
CycleResult.None -> LegacyCycleResult.None
CycleResult.Cycle -> LegacyCycleResult.Cycle
is CycleResult.Resolvable -> LegacyCycleResult.Resolvable(result.key.type.toVariableName())
})
}
}

/**
* Produce a trace of visited elements.
*/
fun trace(provider: AstProvider): String = detector.trace { with(provider) { it.toTrace() } }
}

sealed class LegacyCycleResult {
/**
* There was no cycle, you may proceed normally.
*/
data object None : LegacyCycleResult()

/**
* There was a cycle, you should error out.
*/
data object Cycle : LegacyCycleResult()

/**
* There was a cycle but it was across a delayed construction so it can be resolved. Reference the variable here
* with the given name and call [LegacyCycleDetector.hitResolvable] higher up the tree to create it.
*/
class Resolvable(val name: String) : LegacyCycleResult()
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import me.tatarka.kotlin.ast.AstType
@Suppress("NAME_SHADOWING", "FunctionNaming", "FunctionName")
class TypeResultResolver(private val provider: AstProvider, private val options: Options) {

private val cycleDetector = CycleDetector()
private val cycleDetector = CycleDetector<TypeKey, AstElement>()
private val typeCache = mutableMapOf<TypeCacheKey, TypeResult>()

/**
Expand Down Expand Up @@ -246,7 +246,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
val astClass = key.type.toAstClass()
val injectCtor = astClass.findInjectConstructors(provider.messenger, options)
if (injectCtor != null) {
return constructor(key, injectCtor, astClass)
return constructor(key, injectCtor.asMemberOf(element), astClass)
}

if (astClass.isInject() && astClass.isObject) {
Expand Down Expand Up @@ -563,7 +563,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
when (cycleResult) {
is CycleResult.None -> f()
is CycleResult.Cycle -> throw FailedToGenerateException(trace("Cycle detected"))
is CycleResult.Resolvable -> TypeResult.LocalVar(cycleResult.name)
is CycleResult.Resolvable -> TypeResult.LocalVar(cycleResult.key.type.toVariableName())
}
}
return maybeLateInit(key, result)
Expand All @@ -575,15 +575,17 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
result !is TypeResult.LocalVar && result !is TypeResult.Lazy &&
result !is TypeResult.Function && result !is TypeResult.Scoped
if (!validResultType) return result
val name = cycleDetector.hitResolvable(key) ?: return result
return LateInit(name, key, result)
if (!cycleDetector.hitResolvable(key)) return result
return LateInit(key.type.toVariableName(), key, result)
}

/**
* Produce a trace with the given message prefix. This will show all the lines with
* elements that were traversed for this context.
*/
private fun trace(message: String): String = "$message\n" + cycleDetector.trace(provider)
private fun trace(message: String): String = "$message\n" + cycleDetector.trace {
with(provider) { it.toTrace() }
}

private fun cannotFind(key: TypeKey): String = trace("Cannot find an @Inject constructor or provider for: $key")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package me.tatarka.inject.compiler

import assertk.assertThat
import assertk.assertions.isEqualTo
import kotlin.test.Test

class CycleDetectorTest {
private val cycleDetector = CycleDetector<String, Element>()

@Test
fun detects_single_element_cycle() {
lateinit var element: Element
element = Element("name") { listOf(element) }

val result = checkForCycle(element)

assertThat(result).isEqualTo(CycleResult.Cycle)
}

@Test
fun detects_two_element_cycle() {
lateinit var element: Element
element = Element("name1") { listOf(Element("name2") { listOf(element) }) }

val result = checkForCycle(element)

assertThat(result).isEqualTo(CycleResult.Cycle)
}

@Test
fun delay_breaks_cycle() {
lateinit var element: Element
element = Element("name", delayed = true) { listOf(element) }

val result = checkForCycle(element)

assertThat(result).isEqualTo(CycleResult.Resolvable("name"))
}

private fun checkForCycle(element: Element): CycleResult<String> {
if (element.delayed) {
cycleDetector.delayedConstruction()
}
return cycleDetector.check(element.name, element) { result ->
if (result != CycleResult.None) return@check result
for (ref in element.references()) {
val refResult = checkForCycle(ref)
if (refResult != CycleResult.None) {
return@check refResult
}
}
CycleResult.None
}
}

private class Element(
val name: String,
val delayed: Boolean = false,
val references: () -> List<Element>,
) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as Element
return name == other.name
}

override fun hashCode(): Int {
return name.hashCode()
}

override fun toString(): String {
return name
}
}
}