Skip to content

Make erased capability-safe #23419

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,13 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
case New(_) | Closure(_, _, _) =>
Pure
case TypeApply(fn, _) =>
val sym = fn.symbol
if tree.tpe.isInstanceOf[MethodOrPoly] then exprPurity(fn)
else if fn.symbol == defn.QuotedTypeModule_of || fn.symbol == defn.Predef_classOf then Pure
else if fn.symbol == defn.Compiletime_erasedValue && tree.tpe.dealias.isInstanceOf[ConstantType] then Pure
else if sym == defn.QuotedTypeModule_of
|| sym == defn.Predef_classOf
|| sym == defn.Compiletime_erasedValue && tree.tpe.dealias.isInstanceOf[ConstantType]
|| defn.capsErasedValueMethods.contains(sym)
then Pure
else Impure
case Apply(fn, args) =>
val factorPurity = minOf(exprPurity(fn), args.map(exprPurity))
Expand Down Expand Up @@ -634,6 +638,15 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>

def isPureBinding(tree: Tree)(using Context): Boolean = statPurity(tree) >= Pure

def isPureSyntheticCaseApply(sym: Symbol)(using Context): Boolean =
sym.isAllOf(SyntheticMethod)
&& sym.name == nme.apply
&& sym.owner.is(Module)
&& {
val cls = sym.owner.companionClass
cls.is(Case) && cls.isNoInitsRealClass
}

/** Is the application `tree` with function part `fn` known to be pure?
* Function value and arguments can still be impure.
*/
Expand All @@ -645,6 +658,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>

tree.tpe.isInstanceOf[ConstantType] && tree.symbol != NoSymbol && isKnownPureOp(tree.symbol) // A constant expression with pure arguments is pure.
|| fn.symbol.isStableMember && fn.symbol.isConstructor // constructors of no-inits classes are stable
|| isPureSyntheticCaseApply(fn.symbol)

/** The purity level of this reference.
* @return
Expand All @@ -653,8 +667,6 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
* or its type is a constant type
* IdempotentPath if reference is lazy and stable
* Impure otherwise
* @DarkDimius: need to make sure that lazy accessor methods have Lazy and Stable
* flags set.
*/
def refPurity(tree: Tree)(using Context): PurityLevel = {
val sym = tree.symbol
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
(vparams.asInstanceOf[List[TermSymbol]], remaining1)
case nil =>
(tp.paramNames.lazyZip(tp.paramInfos).lazyZip(tp.erasedParams).map(valueParam), Nil)
(tp.paramNames.lazyZip(tp.paramInfos).lazyZip(tp.paramErasureStatuses).map(valueParam), Nil)
val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), remaining1)
(rtp, vparams :: paramss)
case _ =>
Expand Down
8 changes: 3 additions & 5 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -495,15 +495,13 @@ extension (sym: Symbol)

/** Does this symbol allow results carrying the universal capability?
* Currently this is true only for function type applies (since their
* results are unboxed) and `erasedValue` since this function is magic in
* that is allows to conjure global capabilies from nothing (aside: can we find a
* more controlled way to achieve this?).
* results are unboxed) and `caps.{$internal,unsafe}.erasedValue` since
* these function are magic in that they allow to conjure global capabilies from nothing.
* But it could be generalized to other functions that so that they can take capability
* classes as arguments.
*/
def allowsRootCapture(using Context): Boolean =
sym == defn.Compiletime_erasedValue
|| defn.isFunctionClass(sym.maybeOwner)
defn.capsErasedValueMethods.contains(sym) || defn.isFunctionClass(sym.maybeOwner)

/** When applying `sym`, would the result type be unboxed?
* This is the case if the result type contains a top-level reference to an enclosing
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ class CheckCaptures extends Recheck, SymTransformer:
* @param args the type arguments
*/
def disallowCapInTypeArgs(fn: Tree, sym: Symbol, args: List[Tree])(using Context): Unit =
def isExempt = sym.isTypeTestOrCast || sym == defn.Compiletime_erasedValue
def isExempt = sym.isTypeTestOrCast || defn.capsErasedValueMethods.contains(sym)
if !isExempt then
val paramNames = atPhase(thisPhase.prev):
fn.tpe.widenDealias match
Expand Down
6 changes: 2 additions & 4 deletions compiler/src/dotty/tools/dotc/core/CheckRealizable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ object CheckRealizable {

def boundsRealizability(tp: Type)(using Context): Realizability =
new CheckRealizable().boundsRealizability(tp)

private val LateInitializedFlags = Lazy | Erased
}

/** Compute realizability status.
Expand All @@ -72,7 +70,7 @@ class CheckRealizable(using Context) {
/** Is symbol's definitition a lazy or erased val?
* (note we exclude modules here, because their realizability is ensured separately)
*/
private def isLateInitialized(sym: Symbol) = sym.isOneOf(LateInitializedFlags, butNot = Module)
private def isLateInitialized(sym: Symbol) = sym.is(Lazy, butNot = Module)

/** The realizability status of given type `tp`*/
def realizability(tp: Type): Realizability = tp.dealias match {
Expand Down Expand Up @@ -184,7 +182,7 @@ class CheckRealizable(using Context) {
private def memberRealizability(tp: Type) = {
def checkField(sofar: Realizability, fld: SingleDenotation): Realizability =
sofar andAlso {
if (checkedFields.contains(fld.symbol) || fld.symbol.isOneOf(Private | Mutable | LateInitializedFlags))
if (checkedFields.contains(fld.symbol) || fld.symbol.isOneOf(Private | Mutable | Lazy))
// if field is private it cannot be part of a visible path
// if field is mutable it cannot be part of a path
// if field is lazy or erased it does not need to be initialized when the owning object is
Expand Down
11 changes: 10 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1004,9 +1004,11 @@ class Definitions {
@tu lazy val Caps_Capability: ClassSymbol = requiredClass("scala.caps.Capability")
@tu lazy val Caps_CapSet: ClassSymbol = requiredClass("scala.caps.CapSet")
@tu lazy val CapsInternalModule: Symbol = requiredModule("scala.caps.internal")
@tu lazy val Caps_erasedValue: Symbol = CapsInternalModule.requiredMethod("erasedValue")
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
@tu lazy val Caps_unsafeAssumeSeparate: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumeSeparate")
@tu lazy val Caps_unsafeErasedValue: Symbol = CapsUnsafeModule.requiredMethod("unsafeErasedValue")
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
@tu lazy val Caps_ContainsModule: Symbol = requiredModule("scala.caps.Contains")
@tu lazy val Caps_containsImpl: TermSymbol = Caps_ContainsModule.requiredMethod("containsImpl")
Expand Down Expand Up @@ -1558,6 +1560,11 @@ class Definitions {
@tu lazy val pureSimpleClasses =
Set(StringClass, NothingClass, NullClass) ++ ScalaValueClasses()

@tu lazy val capsErasedValueMethods =
Set(Caps_erasedValue, Caps_unsafeErasedValue)
@tu lazy val erasedValueMethods =
capsErasedValueMethods + Compiletime_erasedValue

@tu lazy val AbstractFunctionType: Array[TypeRef] = mkArityArray("scala.runtime.AbstractFunction", MaxImplementedFunctionArity, 0).asInstanceOf[Array[TypeRef]]
val AbstractFunctionClassPerRun: PerRun[Array[Symbol]] = new PerRun(AbstractFunctionType.map(_.symbol.asClass))
def AbstractFunctionClass(n: Int)(using Context): Symbol = AbstractFunctionClassPerRun()(using ctx)(n)
Expand Down Expand Up @@ -2001,7 +2008,9 @@ class Definitions {

/** A allowlist of Scala-2 classes that are known to be pure */
def isAssuredNoInits(sym: Symbol): Boolean =
(sym `eq` SomeClass) || isTupleClass(sym)
(sym `eq` SomeClass)
|| isTupleClass(sym)
|| sym.is(Module) && isAssuredNoInits(sym.companionClass)

/** If `cls` is Tuple1..Tuple22, add the corresponding *: type as last parent to `parents` */
def adjustForTuple(cls: ClassSymbol, tparams: List[TypeSymbol], parents: List[Type]): List[Type] = {
Expand Down
1 change: 0 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ object Flags {
val EnumCase: FlagSet = Case | Enum
val CovariantLocal: FlagSet = Covariant | Local // A covariant type parameter
val ContravariantLocal: FlagSet = Contravariant | Local // A contravariant type parameter
val EffectivelyErased = PhantomSymbol | Erased
val ConstructorProxyModule: FlagSet = PhantomSymbol | Module
val CaptureParam: FlagSet = PhantomSymbol | StableRealizable | Synthetic
val DefaultParameter: FlagSet = HasDefault | Param // A Scala 2x default parameter
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1051,8 +1051,13 @@ object SymDenotations {
&& owner.ne(defn.StringContextClass)

/** An erased value or an erased inline method or field */
def isErased(using Context): Boolean =
is(Erased) || defn.erasedValueMethods.contains(symbol)

/** An erased value, a phantom symbol or an erased inline method or field */
def isEffectivelyErased(using Context): Boolean =
isOneOf(EffectivelyErased)
isErased
|| is(PhantomSymbol)
|| is(Inline) && !isRetainedInline && !hasAnnotation(defn.ScalaStaticAnnot)

/** Is this a member that will become public in the generated binary */
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2387,7 +2387,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
formals2.isEmpty
}
// If methods have erased parameters, then the erased parameters must match
val erasedValid = (!tp1.hasErasedParams && !tp2.hasErasedParams) || (tp1.erasedParams == tp2.erasedParams)
val erasedValid = (!tp1.hasErasedParams && !tp2.hasErasedParams) || (tp1.paramErasureStatuses == tp2.paramErasureStatuses)

erasedValid && loop(tp1.paramInfos, tp2.paramInfos)
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
val (names, formals0) = if tp.hasErasedParams then
tp.paramNames
.zip(tp.paramInfos)
.zip(tp.erasedParams)
.zip(tp.paramErasureStatuses)
.collect{ case (param, isErased) if !isErased => param }
.unzip
else (tp.paramNames, tp.paramInfos)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3931,7 +3931,7 @@ object Types extends TypeUtils {
case tp: MethodType =>
val params = if (hasErasedParams)
tp.paramInfos
.zip(tp.erasedParams)
.zip(tp.paramErasureStatuses)
.collect { case (param, isErased) if !isErased => param }
else tp.paramInfos
resultSignature.prependTermParams(params, sourceLanguage)
Expand Down Expand Up @@ -4163,7 +4163,7 @@ object Types extends TypeUtils {
final override def isContextualMethod: Boolean =
companion.eq(ContextualMethodType)

def erasedParams(using Context): List[Boolean] =
def paramErasureStatuses(using Context): List[Boolean] =
paramInfos.map(p => p.hasAnnotation(defn.ErasedParamAnnot))

def nonErasedParamCount(using Context): Int =
Expand Down
54 changes: 50 additions & 4 deletions compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,19 @@ class InlineReducer(inliner: Inliner)(using Context):

val isImplicit = scrutinee.isEmpty

val unusable: util.EqHashSet[Symbol] = util.EqHashSet()

/** Adjust internaly generated value definitions;
* - If the RHS refers to an erased symbol, mark the val as erased
* - If the RHS refers to an erased symbol, mark the val as unsuable
*/
def adjustErased(sym: TermSymbol, rhs: Tree): Unit =
rhs.foreachSubTree:
case id: Ident if id.symbol.isErased =>
sym.setFlag(Erased)
if unusable.contains(id.symbol) then unusable += sym
case _ =>

/** Try to match pattern `pat` against scrutinee reference `scrut`. If successful add
* bindings for variables bound in this pattern to `caseBindingMap`.
*/
Expand All @@ -184,10 +197,11 @@ class InlineReducer(inliner: Inliner)(using Context):
/** Create a binding of a pattern bound variable with matching part of
* scrutinee as RHS and type that corresponds to RHS.
*/
def newTermBinding(sym: TermSymbol, rhs: Tree): Unit = {
val copied = sym.copy(info = rhs.tpe.widenInlineScrutinee, coord = sym.coord, flags = sym.flags &~ Case).asTerm
def newTermBinding(sym: TermSymbol, rhs: Tree): Unit =
val copied = sym.copy(info = rhs.tpe.widenInlineScrutinee, coord = sym.coord,
flags = sym.flags &~ Case).asTerm
adjustErased(copied, rhs)
caseBindingMap += ((sym, ValDef(copied, constToLiteral(rhs)).withSpan(sym.span)))
}

def newTypeBinding(sym: TypeSymbol, alias: Type): Unit = {
val copied = sym.copy(info = TypeAlias(alias), coord = sym.coord).asType
Expand Down Expand Up @@ -306,6 +320,7 @@ class InlineReducer(inliner: Inliner)(using Context):
case (Nil, Nil) => true
case (pat :: pats1, selector :: selectors1) =>
val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenInlineScrutinee).asTerm
adjustErased(elem, selector)
val rhs = constToLiteral(selector)
elem.defTree = rhs
caseBindingMap += ((NoSymbol, ValDef(elem, rhs).withSpan(elem.span)))
Expand Down Expand Up @@ -341,6 +356,19 @@ class InlineReducer(inliner: Inliner)(using Context):
val scrutineeSym = newSym(InlineScrutineeName.fresh(), Synthetic, scrutType).asTerm
val scrutineeBinding = normalizeBinding(ValDef(scrutineeSym, scrutinee))

// If scrutinee has embedded references to `compiletime.erasedValue` or to
// other erased values, mark scrutineeSym as Erased. In addition, if scrutinee
// is not a pure expression, mark scrutineeSym as unusable. The reason is that
// scrutinee would then fail the tests in erasure that demand that the RHS of
// an erased val is a pure expression. At the end of the inline match reduction
// we throw out all unusable vals and check that the remaining code does not refer
// to unusable symbols.
// Note that compiletime.erasedValue is treated as erased but not pure, so scrutinees
// containing references to it becomes unusable.
if scrutinee.existsSubTree(_.symbol.isErased) then
scrutineeSym.setFlag(Erased)
if !tpd.isPureExpr(scrutinee) then unusable += scrutineeSym

def reduceCase(cdef: CaseDef): MatchReduxWithGuard = {
val caseBindingMap = new mutable.ListBuffer[(Symbol, MemberDef)]()

Expand Down Expand Up @@ -382,7 +410,25 @@ class InlineReducer(inliner: Inliner)(using Context):
case _ => None
}

recur(cases)
for (bindings, expr) <- recur(cases) yield
// drop unusable vals and check that no referenes to unusable symbols remain
val cleanupUnusable = new TreeMap:
override def transform(tree: Tree)(using Context): Tree =
tree match
case tree: ValDef if unusable.contains(tree.symbol) => EmptyTree
case id: Ident if unusable.contains(id.symbol) =>
report.error(
em"""${id.symbol} is unusable in ${ctx.owner} because it refers to an erased expression
|in the selector of an inline match that reduces to
|
|${Block(bindings, expr)}""",
tree.srcPos)
tree
case _ => super.transform(tree)

val bindings1 = bindings.mapConserve(cleanupUnusable.transform).collect:
case mdef: MemberDef => mdef
(bindings1, cleanupUnusable.transform(expr))
}
end InlineReducer

4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/quoted/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ class Interpreter(pos: SrcPos, classLoader0: ClassLoader)(using Context):
case fnType: MethodType =>
val argTypes = fnType.paramInfos
assert(argss.head.size == argTypes.size)
val nonErasedArgs = argss.head.lazyZip(fnType.erasedParams).collect { case (arg, false) => arg }.toList
val nonErasedArgTypes = fnType.paramInfos.lazyZip(fnType.erasedParams).collect { case (arg, false) => arg }.toList
val nonErasedArgs = argss.head.lazyZip(fnType.paramErasureStatuses).collect { case (arg, false) => arg }.toList
val nonErasedArgTypes = fnType.paramInfos.lazyZip(fnType.paramErasureStatuses).collect { case (arg, false) => arg }.toList
assert(nonErasedArgs.size == nonErasedArgTypes.size)
interpretArgsGroup(nonErasedArgs, nonErasedArgTypes) ::: interpretArgs(argss.tail, fnType.resType)
case fnType: AppliedType if defn.isContextFunctionType(fnType) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ enum ErrorMessageID(val isActive: Boolean = true) extends java.lang.Enum[ErrorMe
case PointlessAppliedConstructorTypeID // errorNumber: 213
case IllegalContextBoundsID // errorNumber: 214
case NamedPatternNotApplicableID // errorNumber: 215
case ErasedNotPureID // errornumber 216

def errorNumber = ordinal - 1

Expand Down
34 changes: 33 additions & 1 deletion compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3538,4 +3538,36 @@ final class NamedPatternNotApplicable(selectorType: Type)(using Context) extends
override protected def msg(using Context): String =
i"Named patterns cannot be used with $selectorType, because it is not a named tuple or case class"

override protected def explain(using Context): String = ""
override protected def explain(using Context): String = ""

final class ErasedNotPure(tree: tpd.Tree, isArgument: Boolean, isImplicit: Boolean)(using Context) extends TypeMsg(ErasedNotPureID):
def what =
if isArgument then s"${if isImplicit then "implicit " else ""}argument to an erased parameter"
else "right-hand-side of an erased value"
override protected def msg(using Context): String =
i"$what fails to be a pure expression"

override protected def explain(using Context): String =
def alternatives =
if tree.symbol == defn.Compiletime_erasedValue then
i"""An accepted (but unsafe) alternative for this expression uses function
|
| caps.unsafe.unsafeErasedValue
|
|instead."""
else
"""A pure expression is an expression that is clearly side-effect free and terminating.
|Some examples of pure expressions are:
| - literals,
| - references to values,
| - side-effect-free instance creations,
| - applications of inline functions to pure arguments."""

i"""The $what must be a pure expression, but I found:
|
| $tree
|
|This expression is not classified to be pure.
|$alternatives"""
end ErasedNotPure

Loading
Loading