Skip to content

Commit 6e5c029

Browse files
committed
Check the correct scope when enforcing assisted injection to be unscoped
Fixes #301
1 parent fca9d88 commit 6e5c029

File tree

5 files changed

+61
-35
lines changed

5 files changed

+61
-35
lines changed

gradle/libs.versions.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[versions]
2-
kotlin-inject = "0.7.0-SNAPSHOT"
2+
kotlin-inject = "0.6.3-SNAPSHOT"
33
kotlin = "1.9.0"
44
ksp = "1.9.0-1.0.11"
55
kotlinpoet = "1.14.2"

integration-tests/common/src/test/kotlin/me/tatarka/inject/test/InjectFunctionTest.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ typealias receiverFun = String.(arg: NamedFoo) -> String
3535
fun String.receiverFun(dep: Foo, @Assisted arg: NamedFoo): String = this
3636

3737
@Component
38+
@CustomScope
3839
abstract class ReceiverFunctionInjectionComponent {
3940
abstract val receiverFun: receiverFun
4041
}

kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeCollector.kt

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,21 +100,21 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
100100
}
101101

102102
for (method in typeInfo.providesMethods) {
103-
val scopeType = method.scopeType(options)
104-
if (scopeType != null && scopeType != typeInfo.elementScope) {
103+
val scope = method.scopeType(options)
104+
if (scope != null && scope != typeInfo.elementScope) {
105105
if (typeInfo.elementScope != null) {
106106
provider.error(
107-
"@Provides with scope: $scopeType must match component scope: ${typeInfo.elementScope}",
107+
"@Provides with scope: $scope must match component scope: ${typeInfo.elementScope}",
108108
method
109109
)
110110
} else {
111111
provider.error(
112-
"@Provides with scope: $scopeType cannot be provided in an unscoped component",
112+
"@Provides with scope: $scope cannot be provided in an unscoped component",
113113
method
114114
)
115115
}
116116
}
117-
val scopedComponent = if (scopeType != null) astClass else null
117+
val scopedComponent = if (scope != null) astClass else null
118118
if (method.hasAnnotation(INTO_MAP.packageName, INTO_MAP.simpleName)) {
119119
// Pair<A, B> -> Map<A, B>
120120
val returnType = method.returnTypeFor(astClass)
@@ -126,7 +126,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
126126
resolvedType.arguments[1],
127127
key.qualifier
128128
)
129-
addContainerType(provider, key, containerKey, method, accessor, scopedComponent)
129+
addContainerType(provider, key, containerKey, method, accessor, scope, scopedComponent)
130130
} else {
131131
provider.error("@IntoMap must have return type of type Pair", method)
132132
}
@@ -135,7 +135,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
135135
val returnType = method.returnTypeFor(astClass)
136136
val key = TypeKey(returnType, method.qualifier(options))
137137
val containerKey = ContainerKey.SetKey(returnType, key.qualifier)
138-
addContainerType(provider, key, containerKey, method, accessor, scopedComponent)
138+
addContainerType(provider, key, containerKey, method, accessor, scope, scopedComponent)
139139
} else {
140140
val returnType = method.returnTypeFor(astClass)
141141
val key = TypeKey(returnType, method.qualifier(options))
@@ -147,7 +147,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
147147
provider.error("@Provides method is not accessible", method)
148148
}
149149
}
150-
addMethod(key, method, accessor, scopedComponent)
150+
addMethod(key, method, accessor, scope, scopedComponent)
151151
}
152152
}
153153

@@ -190,6 +190,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
190190
containerKey: ContainerKey,
191191
method: AstMember,
192192
accessor: Accessor,
193+
scope: AstType?,
193194
scopedComponent: AstClass?,
194195
) {
195196
val current = type(containerKey.containerTypeKey(provider))
@@ -199,10 +200,16 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
199200
}
200201

201202
containerTypes.getOrPut(containerKey) { mutableListOf() }
202-
.add(method(method, accessor, scopedComponent))
203+
.add(method(method, accessor, scope, scopedComponent))
203204
}
204205

205-
private fun addMethod(key: TypeKey, method: AstMember, accessor: Accessor, scopedComponent: AstClass?) {
206+
private fun addMethod(
207+
key: TypeKey,
208+
method: AstMember,
209+
accessor: Accessor,
210+
scope: AstType?,
211+
scopedComponent: AstClass?,
212+
) {
206213
val oldValue = types[key]
207214
if (oldValue != null) {
208215
duplicate(key, newValue = method, oldValue = oldValue.method)
@@ -218,7 +225,7 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
218225
}
219226
}
220227

221-
types[key] = method(method, accessor, scopedComponent)
228+
types[key] = method(method, accessor, scope, scopedComponent)
222229
}
223230

224231
private fun addProviderMethod(key: TypeKey, method: AstMember, accessor: Accessor) {
@@ -228,9 +235,10 @@ class TypeCollector(private val provider: AstProvider, private val options: Opti
228235
}
229236
}
230237

231-
private fun method(method: AstMember, accessor: Accessor, scopedComponent: AstClass?) = Method(
238+
private fun method(method: AstMember, accessor: Accessor, scope: AstType?, scopedComponent: AstClass?) = Method(
232239
method = method,
233240
accessor = accessor,
241+
scope = scope,
234242
scopedComponent = scopedComponent
235243
)
236244

@@ -347,6 +355,7 @@ class ProviderMethod(
347355
class Method(
348356
val method: AstMember,
349357
val accessor: Accessor = Accessor.Empty,
358+
val scope: AstType? = null,
350359
val scopedComponent: AstClass? = null,
351360
)
352361

kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResult.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ sealed class TypeResult {
6666
*/
6767
class Constructor(
6868
val type: AstType,
69+
val scope: AstType?,
6970
val parameters: Map<String, TypeResultRef>,
7071
val supportsNamedArguments: Boolean
7172
) : TypeResult() {

kotlin-inject-compiler/core/src/main/kotlin/me/tatarka/inject/compiler/TypeResultResolver.kt

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
6161
private fun resolveParams(
6262
context: Context,
6363
element: AstElement,
64+
scope: AstType?,
6465
params: List<AstParam>,
6566
): Map<String, TypeResultRef> {
6667
return if (params.any { it.isAssisted() }) {
67-
resolveParamsNew(context, element, params)
68+
resolveParamsNew(context, element, scope, params)
6869
} else {
69-
resolveParamsLegacy(context, element, params)
70+
resolveParamsLegacy(context, element, scope, params)
7071
}
7172
}
7273

@@ -75,12 +76,12 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
7576
private fun resolveParamsNew(
7677
context: Context,
7778
element: AstElement,
79+
scope: AstType?,
7880
params: List<AstParam>,
7981
): Map<String, TypeResultRef> {
80-
if (context.scopeComponent != null) {
81-
val scopeType = context.scopeComponent.scopeType(options)
82+
if (scope != null) {
8283
throw FailedToGenerateException(
83-
"Cannot apply scope: @${scopeType!!.simpleName} to type with @Assisted parameters: [${
84+
"Cannot apply scope: @${scope.simpleName} to type with @Assisted parameters: [${
8485
params.filter { it.isAssisted() }.joinToString()
8586
}]"
8687
)
@@ -137,6 +138,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
137138
private fun resolveParamsLegacy(
138139
context: Context,
139140
element: AstElement,
141+
scope: AstType?,
140142
params: List<AstParam>,
141143
): Map<String, TypeResultRef> {
142144
val size = params.size
@@ -171,10 +173,9 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
171173
""".trimIndent(),
172174
element
173175
)
174-
if (context.scopeComponent != null) {
175-
val scopeType = context.scopeComponent.scopeType(options)
176+
if (scope != null) {
176177
throw FailedToGenerateException(
177-
"Cannot apply scope: @${scopeType!!.simpleName} to type with assisted parameters: [${
178+
"Cannot apply scope: @${scope.simpleName} to type with @Assisted parameters: [${
178179
resolvedImplicitly.joinToString()
179180
}]"
180181
)
@@ -199,6 +200,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
199200
context = withTypes(types),
200201
accessor = method.accessor,
201202
method = method.method,
203+
scope = null,
202204
key = key,
203205
)
204206
}
@@ -255,6 +257,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
255257
context = this,
256258
accessor = creator.accessor,
257259
method = creator.method,
260+
scope = creator.scope,
258261
key = key,
259262
)
260263
}
@@ -270,7 +273,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
270273
creator = containerKey.creator,
271274
args = args,
272275
mapArg = { key, arg, types ->
273-
Provides(withTypes(types), arg.accessor, arg.method, key)
276+
Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key)
274277
}
275278
)
276279
}
@@ -284,7 +287,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
284287
args = args,
285288
mapArg = { key, arg, types ->
286289
Function(withTypes(types), args = innerType.arguments.dropLast(1)) { context ->
287-
TypeResultRef(key, Provides(context, arg.accessor, arg.method, key))
290+
TypeResultRef(key, Provides(context, arg.accessor, arg.method, arg.scope, key))
288291
}
289292
}
290293
)
@@ -299,7 +302,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
299302
args = args,
300303
mapArg = { key, arg, types ->
301304
Lazy(key) {
302-
TypeResultRef(key, Provides(withTypes(types), arg.accessor, arg.method, key))
305+
TypeResultRef(key, Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key))
303306
}
304307
}
305308
)
@@ -316,7 +319,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
316319
creator = containerKey.creator,
317320
args = args,
318321
mapArg = { key, arg, types ->
319-
Provides(withTypes(types), arg.accessor, arg.method, key)
322+
Provides(withTypes(types), arg.accessor, arg.method, arg.scope, key)
320323
}
321324
)
322325
}
@@ -376,6 +379,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
376379
Constructor(
377380
context = this,
378381
constructor = injectCtor,
382+
scope = scope,
379383
key = key,
380384
)
381385
}
@@ -406,6 +410,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
406410
context: Context,
407411
accessor: Accessor,
408412
method: AstMember,
413+
scope: AstType?,
409414
key: TypeKey,
410415
) = withCycleDetection(key, method) {
411416
TypeResult.Provides(
@@ -418,7 +423,7 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
418423
},
419424
isProperty = method is AstProperty,
420425
parameters = (method as? AstFunction)?.let {
421-
resolveParams(context, method, it.parameters)
426+
resolveParams(context, method, scope, it.parameters)
422427
} ?: emptyMap(),
423428
)
424429
}
@@ -435,14 +440,19 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
435440
result = resolve(context.withoutScoped(key.type, scopedComponent), element, key)
436441
)
437442

438-
private fun Constructor(context: Context, constructor: AstConstructor, key: TypeKey) =
439-
withCycleDetection(key, constructor) {
440-
TypeResult.Constructor(
441-
type = constructor.type,
442-
parameters = resolveParams(context, constructor, constructor.parameters),
443-
supportsNamedArguments = constructor.supportsNamedArguments
444-
)
445-
}
443+
private fun Constructor(
444+
context: Context,
445+
constructor: AstConstructor,
446+
scope: AstType?,
447+
key: TypeKey
448+
) = withCycleDetection(key, constructor) {
449+
TypeResult.Constructor(
450+
type = constructor.type,
451+
scope = scope,
452+
parameters = resolveParams(context, constructor, scope, constructor.parameters),
453+
supportsNamedArguments = constructor.supportsNamedArguments
454+
)
455+
}
446456

447457
private fun Container(
448458
creator: String,
@@ -492,7 +502,12 @@ class TypeResultResolver(private val provider: AstProvider, private val options:
492502
TypeResult.NamedFunction(
493503
name = function.toMemberName(),
494504
args = namedArgs.map { it.second },
495-
parameters = resolveParams(context.withArgs(namedArgs), function, function.parameters),
505+
parameters = resolveParams(
506+
context = context.withArgs(namedArgs),
507+
element = function,
508+
scope = null,
509+
params = function.parameters
510+
),
496511
)
497512
}
498513

0 commit comments

Comments
 (0)