Skip to content

Commit 8308b33

Browse files
committed
Implement @unboxed annotation exemption for reach capabilities of parameters
1 parent 801d0d3 commit 8308b33

File tree

4 files changed

+104
-11
lines changed

4 files changed

+104
-11
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ object ccConfig:
4040
*/
4141
inline val handleEtaExpansionsSpecially = false
4242

43+
val useUnboxedParams = true
44+
4345
/** If true, use existential capture set variables */
4446
def useExistentials(using Context) =
4547
Feature.sourceVersion.stable.isAtLeast(SourceVersion.`3.5`)

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ object CheckCaptures:
6262
val res = cur
6363
cur = cur.outer
6464
res
65+
66+
def ownerString(using Context): String =
67+
if owner.isAnonymousFunction then "enclosing function" else owner.show
6568
end Env
6669

6770
/** Similar normal substParams, but this is an approximating type map that
@@ -386,21 +389,68 @@ class CheckCaptures extends Recheck, SymTransformer:
386389
val included = cs.filter: c =>
387390
c.stripReach match
388391
case ref: TermRef =>
389-
val isVisible = isVisibleFromEnv(ref.symbol.owner)
390-
if !isVisible && c.isReach then
392+
//if c.isReach then println(i"REACH $c in ${env.owner}")
393+
//assert(!env.owner.isAnonymousFunction)
394+
val refSym = ref.symbol
395+
val refOwner = refSym.owner
396+
val isVisible = isVisibleFromEnv(refOwner)
397+
if !isVisible && c.isReach && refSym.is(Param) && refOwner == env.owner then
398+
if refSym.hasAnnotation(defn.UnboxedAnnot) then
399+
capt.println(i"exempt: $ref in $refOwner")
400+
else
391401
// Reach capabilities that go out of scope have to be approximated
392-
// by their underlyiong capture set. See i20503.scala.
393-
checkSubset(CaptureSet.ofInfo(c), env.captured, pos, provenance(env))
402+
// by their underlying capture set, which cannot be universal.
403+
// Reach capabilities of @unboxed parameters are exempted.
404+
val cs = CaptureSet.ofInfo(c)
405+
if ccConfig.useUnboxedParams then
406+
cs.disallowRootCapability: () =>
407+
report.error(em"Local reach capability $c leaks into capture scope of ${env.ownerString}", pos)
408+
checkSubset(cs, env.captured, pos, provenance(env))
394409
isVisible
395410
case ref: ThisType => isVisibleFromEnv(ref.cls)
396411
case _ => false
397-
capt.println(i"Include call or box capture $included from $cs in ${env.owner}")
398412
checkSubset(included, env.captured, pos, provenance(env))
413+
capt.println(i"Include call or box capture $included from $cs in ${env.owner} --> ${env.captured}")
414+
end markFree
399415

400416
/** Include references captured by the called method in the current environment stack */
401417
def includeCallCaptures(sym: Symbol, pos: SrcPos)(using Context): Unit =
402418
if sym.exists && curEnv.isOpen then markFree(capturedVars(sym), pos)
403419

420+
private val prefixCalls = util.EqHashSet[GenericApply]()
421+
private val unboxedArgs = util.EqHashSet[Tree]()
422+
423+
def handleCall(meth: Symbol, call: GenericApply, eval: () => Type)(using Context): Type =
424+
if prefixCalls.remove(call) then return eval()
425+
426+
val unboxedParamNames =
427+
meth.rawParamss.flatMap: params =>
428+
params.collect:
429+
case param if param.hasAnnotation(defn.UnboxedAnnot) =>
430+
param.name
431+
.toSet
432+
433+
def markUnboxedArgs(call: GenericApply): Unit = call.fun.tpe.widen match
434+
case MethodType(pnames) =>
435+
for (pname, arg) <- pnames.lazyZip(call.args) do
436+
if unboxedParamNames.contains(pname) then
437+
unboxedArgs.add(arg)
438+
case _ =>
439+
440+
def markPrefixCalls(tree: Tree): Unit = tree match
441+
case tree: GenericApply =>
442+
prefixCalls.add(tree)
443+
markUnboxedArgs(tree)
444+
markPrefixCalls(tree.fun)
445+
case _ =>
446+
447+
markUnboxedArgs(call)
448+
markPrefixCalls(call.fun)
449+
val res = eval()
450+
includeCallCaptures(meth, call.srcPos)
451+
res
452+
end handleCall
453+
404454
override def recheckIdent(tree: Ident, pt: Type)(using Context): Type =
405455
if tree.symbol.is(Method) then
406456
if tree.symbol.info.isParameterless then
@@ -470,7 +520,6 @@ class CheckCaptures extends Recheck, SymTransformer:
470520
*/
471521
override def recheckApply(tree: Apply, pt: Type)(using Context): Type =
472522
val meth = tree.fun.symbol
473-
includeCallCaptures(meth, tree.srcPos)
474523

475524
// Unsafe box/unbox handlng, only for versions < 3.3
476525
def mapArgUsing(f: Type => Type) =
@@ -503,7 +552,7 @@ class CheckCaptures extends Recheck, SymTransformer:
503552
tp.derivedCapturingType(forceBox(parent), refs)
504553
mapArgUsing(forceBox)
505554
else
506-
Existential.toCap(super.recheckApply(tree, pt)) match
555+
handleCall(meth, tree, () => Existential.toCap(super.recheckApply(tree, pt))) match
507556
case appType @ CapturingType(appType1, refs) =>
508557
tree.fun match
509558
case Select(qual, _)
@@ -521,6 +570,13 @@ class CheckCaptures extends Recheck, SymTransformer:
521570
case appType => appType
522571
end recheckApply
523572

573+
override def recheckArg(arg: Tree, formal: Type)(using Context): Type =
574+
val argType = recheck(arg, formal)
575+
if unboxedArgs.remove(arg) && ccConfig.useUnboxedParams then
576+
capt.println(i"charging deep capture set of $arg: ${argType} = ${CaptureSet.deepCaptureSet(argType)}")
577+
markFree(CaptureSet.deepCaptureSet(argType), arg.srcPos)
578+
argType
579+
524580
private def isDistinct(xs: List[Type]): Boolean = xs match
525581
case x :: xs1 => xs1.isEmpty || !xs1.contains(x) && isDistinct(xs1)
526582
case Nil => true
@@ -589,20 +645,21 @@ class CheckCaptures extends Recheck, SymTransformer:
589645
end instantiate
590646

591647
override def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type =
648+
val meth = tree.symbol
592649
if ccConfig.useSealed then
593650
val TypeApply(fn, args) = tree
594651
val polyType = atPhase(thisPhase.prev):
595652
fn.tpe.widen.asInstanceOf[TypeLambda]
596653
def isExempt(sym: Symbol) =
597654
sym.isTypeTestOrCast || sym == defn.Compiletime_erasedValue
598655
for case (arg: TypeTree, formal, pname) <- args.lazyZip(polyType.paramRefs).lazyZip((polyType.paramNames)) do
599-
if !isExempt(tree.symbol) then
600-
def where = if fn.symbol.exists then i" in an argument of ${fn.symbol}" else ""
656+
if !isExempt(meth) then
657+
def where = if meth.exists then i" in an argument of $meth" else ""
601658
disallowRootCapabilitiesIn(arg.knownType, NoSymbol,
602659
i"Sealed type variable $pname", "be instantiated to",
603660
i"This is often caused by a local capability$where\nleaking as part of its result.",
604661
tree.srcPos)
605-
Existential.toCap(super.recheckTypeApply(tree, pt))
662+
handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
606663

607664
override def recheckBlock(tree: Block, pt: Type)(using Context): Type =
608665
inNestedLevel(super.recheckBlock(tree, pt))

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ abstract class Recheck extends Phase, SymTransformer:
289289
/** A hook to massage the type of an applied method; currently not overridden */
290290
protected def prepareFunction(funtpe: MethodType, meth: Symbol)(using Context): MethodType = funtpe
291291

292+
protected def recheckArg(arg: Tree, formal: Type)(using Context): Type =
293+
recheck(arg, formal)
294+
292295
def recheckApply(tree: Apply, pt: Type)(using Context): Type =
293296
val funtpe0 = recheck(tree.fun)
294297
// reuse the tree's type on signature polymorphic methods, instead of using the (wrong) rechecked one
@@ -303,7 +306,7 @@ abstract class Recheck extends Phase, SymTransformer:
303306
else fntpe.paramInfos
304307
def recheckArgs(args: List[Tree], formals: List[Type], prefs: List[ParamRef]): List[Type] = args match
305308
case arg :: args1 =>
306-
val argType = recheck(arg, normalizeByName(formals.head))
309+
val argType = recheckArg(arg, normalizeByName(formals.head))
307310
val formals1 =
308311
if fntpe.isParamDependent
309312
then formals.tail.map(_.substParam(prefs.head, argType))

tests/neg/leak-problem.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import language.experimental.captureChecking
2+
3+
// Some capabilities that should be used locally
4+
trait Async:
5+
// some method
6+
def read(): Unit
7+
def usingAsync[X](op: Async^ => X): X = ???
8+
9+
case class Box[+T](get: T)
10+
11+
def useBoxedAsync(x: Box[Async^]): Unit =
12+
val t0 = x
13+
val t1 = t0.get // error
14+
t1.read()
15+
16+
def useBoxedAsync1(x: Box[Async^]): Unit = x.get.read() // error
17+
18+
def test(): Unit =
19+
val useBoxedAsync2 = (x: Box[Async^]) =>
20+
val t0 = x
21+
val t1 = x.get // error
22+
t1.read()
23+
24+
val f: Box[Async^] => Unit = (x: Box[Async^]) => useBoxedAsync(x)
25+
26+
def boom(x: Async^): () ->{f} Unit =
27+
() => f(Box(x))
28+
29+
val leaked = usingAsync[() ->{f} Unit](boom)
30+
31+
leaked() // scope violation

0 commit comments

Comments
 (0)