@@ -497,22 +497,33 @@ object CheckDslScopeViolation : ResolutionStage() {
497
497
context : ResolutionContext ,
498
498
): Set <ClassId > {
499
499
return buildSet {
500
- (boundSymbol.containingDeclarationIfParameter() as ? FirAnonymousFunctionSymbol )?.fir?.matchingParameterFunctionType?.let {
501
- // collect annotations in the function type at declaration site. For example, the `@A` and `@B` in the following code.
500
+ (boundSymbol.containingDeclarationIfParameter() as ? FirAnonymousFunctionSymbol )?.let { anonymousFunctionSymbol ->
501
+ val matchingParameterFunctionType = anonymousFunctionSymbol.fir.matchingParameterFunctionType ? : return @let
502
+
503
+ // Collect annotations in the function type at declaration site. For example, the `@A`, `@B` and `@C in the following code.
502
504
// ```
503
- // fun <T> body(block: @A ((@B T).() -> Unit)) { ... }
505
+ // fun <T, R > body(block: @A (context (@B T) (@C R ).() -> Unit)) { ... }
504
506
// ```
507
+ // @A should be collected unconditionally.
508
+ // @B should only be collected if `boundSymbol` resolves to the respective context parameter of the anonymous function.
509
+ // @C should only be collected if `boundSymbol` resolves to the receiver parameter of the anonymous function.
505
510
506
511
// Collect the annotation on the function type, or `@A` in the example above.
507
- collectDslMarkerAnnotations(context, it .customAnnotations)
512
+ collectDslMarkerAnnotations(context, matchingParameterFunctionType .customAnnotations)
508
513
509
- // Collect the annotation on the extension receiver, or `@B` in the example above.
510
- it.receiverType(context.session)?.let { receiverType ->
511
- collectDslMarkerAnnotations(context, receiverType)
514
+ // Collect the annotation on the context parameter, or `@B` in the example above.
515
+ if (boundSymbol is FirValueParameterSymbol ) {
516
+ val index = anonymousFunctionSymbol.contextParameterSymbols.indexOf(boundSymbol)
517
+ matchingParameterFunctionType.contextParameterTypes(context.session).elementAtOrNull(index)?.let { contextType ->
518
+ collectDslMarkerAnnotations(context, contextType)
519
+ }
512
520
}
513
521
514
- it.contextParameterTypes(context.session).forEach { contextType ->
515
- collectDslMarkerAnnotations(context, contextType)
522
+ // Collect the annotation on the extension receiver, or `@C` in the example above.
523
+ if (boundSymbol is FirReceiverParameterSymbol ) {
524
+ matchingParameterFunctionType.receiverType(context.session)?.let { receiverType ->
525
+ collectDslMarkerAnnotations(context, receiverType)
526
+ }
516
527
}
517
528
}
518
529
0 commit comments