Skip to content

Fix #21619: Refactor NotNullInfo to record every reference which is retracted once. #21624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add terminated info
  • Loading branch information
noti0na1 committed Dec 6, 2024
commit f859afe8e4ace35e026600bb784664dcbcdbda98
43 changes: 28 additions & 15 deletions compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,37 +53,45 @@ object Nullables:
TypeBoundsTree(lo, hiTree, alias)

/** A set of val or var references that are known to be not null
* after the tree finishes executing normally (non-exceptionally),
* after the tree finishes executing normally (non-exceptionally),
* plus a set of variable references that are ever assigned to null,
* and may therefore be null if execution of the tree is interrupted
* by an exception.
*/
case class NotNullInfo(asserted: Set[TermRef], retracted: Set[TermRef]):
case class NotNullInfo(asserted: Set[TermRef] | Null, retracted: Set[TermRef]):
def isEmpty = this eq NotNullInfo.empty

def retractedInfo = NotNullInfo(Set(), retracted)

def terminatedInfo = NotNullInfo(null, retracted)

/** The sequential combination with another not-null info */
def seq(that: NotNullInfo): NotNullInfo =
if this.isEmpty then that
else if that.isEmpty then this
else NotNullInfo(
this.asserted.diff(that.retracted).union(that.asserted),
this.retracted.union(that.retracted))
else
val newAsserted =
if this.asserted == null || that.asserted == null then null
else this.asserted.diff(that.retracted).union(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)

/** The alternative path combination with another not-null info. Used to merge
* the nullability info of the two branches of an if.
* the nullability info of the branches of an if or match.
*/
def alt(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted.intersect(that.asserted), this.retracted.union(that.retracted))

def withRetracted(that: NotNullInfo): NotNullInfo =
NotNullInfo(this.asserted, this.retracted.union(that.retracted))
val newAsserted =
if this.asserted == null then that.asserted
else if that.asserted == null then this.asserted
else this.asserted.intersect(that.asserted)
val newRetracted = this.retracted.union(that.retracted)
NotNullInfo(newAsserted, newRetracted)
end NotNullInfo

object NotNullInfo:
val empty = new NotNullInfo(Set(), Set())
def apply(asserted: Set[TermRef], retracted: Set[TermRef]): NotNullInfo =
if asserted.isEmpty && retracted.isEmpty then empty
def apply(asserted: Set[TermRef] | Null, retracted: Set[TermRef]): NotNullInfo =
if asserted != null && asserted.isEmpty && retracted.isEmpty then empty
else new NotNullInfo(asserted, retracted)
end NotNullInfo

Expand Down Expand Up @@ -227,7 +235,7 @@ object Nullables:
*/
@tailrec def impliesNotNull(ref: TermRef): Boolean = infos match
case info :: infos1 =>
if info.asserted.contains(ref) then true
if info.asserted != null && info.asserted.contains(ref) then true
else if info.retracted.contains(ref) then false
else infos1.impliesNotNull(ref)
case _ =>
Expand All @@ -243,7 +251,9 @@ object Nullables:
/** Retract all references to mutable variables */
def retractMutables(using Context) =
val mutables = infos.foldLeft(Set[TermRef]()):
(ms, info) => ms.union(info.asserted.filter(_.symbol.is(Mutable)))
(ms, info) => ms.union(
if info.asserted == null then Set.empty
else info.asserted.filter(_.symbol.is(Mutable)))
infos.extendWith(NotNullInfo(Set(), mutables))

end extension
Expand Down Expand Up @@ -516,7 +526,10 @@ object Nullables:
&& assignmentSpans.getOrElse(sym.span.start, Nil).exists(whileSpan.contains(_))
&& ctx.notNullInfos.impliesNotNull(ref)

val retractedVars = ctx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet
val retractedVars = ctx.notNullInfos.flatMap(info =>
if info.asserted == null then Set.empty
else info.asserted.filter(isRetracted)
).toSet
ctx.addNotNullInfo(NotNullInfo(Set(), retractedVars))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we lose knowledge about the loop being unreachable. The result always has asserted == Set(), even if it was null before.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is fine for the NotNullInfo in the ctx. Suppose there is a terminated info in the ctx, adding a new non-terminated info will not change its behaviour: still treating all symbols as non-nullable.

end whileContext

Expand Down
27 changes: 8 additions & 19 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1551,13 +1551,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def thenPathInfo = cond1.notNullInfoIf(true).seq(result.thenp.notNullInfo)
def elsePathInfo = cond1.notNullInfoIf(false).seq(result.elsep.notNullInfo)
result.withNotNullInfo(
if result.thenp.tpe.isNothingType then
elsePathInfo.withRetracted(thenPathInfo)
else if result.elsep.tpe.isNothingType then
thenPathInfo.withRetracted(elsePathInfo)
else thenPathInfo.alt(elsePathInfo)
)
result.withNotNullInfo(thenPathInfo.alt(elsePathInfo))
end typedIf

/** Decompose function prototype into a list of parameter prototypes and a result
Expand Down Expand Up @@ -2154,14 +2148,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
}

private def notNullInfoFromCases(initInfo: NotNullInfo, cases: List[CaseDef])(using Context): NotNullInfo =
var nnInfo = initInfo
if cases.nonEmpty then
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what instances would the cases be empty? Depending on the answer, perhaps a comment would be useful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the cases can be empty. Just want to avoid the exception when calling reduce on empty list.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it seems the empty cases can happen after inlining.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. A question is, if empty cases does happen, what does it mean in terms of runtime semantics? If it definitely fails to match, it would be more precise to return terminatedInfo, but initInfo is sound and safer in case empty cases means something different in some special situations.

I think it's fine to leave it as it is, and you can merge this PR now.

val (nothingCases, normalCases) = cases.partition(_.body.tpe.isNothingType)
nnInfo = nothingCases.foldLeft(nnInfo):
(nni, c) => nni.withRetracted(c.notNullInfo)
if normalCases.nonEmpty then
nnInfo = nnInfo.seq(normalCases.map(_.notNullInfo).reduce(_.alt(_)))
nnInfo
initInfo.seq(cases.map(_.notNullInfo).reduce(_.alt(_)))
else initInfo

def typedCases(cases: List[untpd.CaseDef], sel: Tree, wideSelType0: Type, pt: Type)(using Context): List[CaseDef] =
var caseCtx = ctx
Expand Down Expand Up @@ -2251,7 +2240,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedLabeled(tree: untpd.Labeled)(using Context): Labeled = {
val bind1 = typedBind(tree.bind, WildcardType).asInstanceOf[Bind]
val expr1 = typed(tree.expr, bind1.symbol.info)
assignType(cpy.Labeled(tree)(bind1, expr1))
assignType(cpy.Labeled(tree)(bind1, expr1)).withNotNullInfo(expr1.notNullInfo.retractedInfo)
}

/** Type a case of a type match */
Expand Down Expand Up @@ -2301,7 +2290,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
// Hence no adaptation is possible, and we assume WildcardType as prototype.
(from, proto)
val expr1 = typedExpr(tree.expr orElse untpd.syntheticUnitLiteral.withSpan(tree.span), proto)
assignType(cpy.Return(tree)(expr1, from))
assignType(cpy.Return(tree)(expr1, from)).withNotNullInfo(expr1.notNullInfo.terminatedInfo)
end typedReturn

def typedWhileDo(tree: untpd.WhileDo)(using Context): Tree =
Expand Down Expand Up @@ -2388,15 +2377,15 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
def typedThrow(tree: untpd.Throw)(using Context): Tree =
val expr1 = typed(tree.expr, defn.ThrowableType)
val cap = checkCanThrow(expr1.tpe.widen, tree.span)
val res = Throw(expr1).withSpan(tree.span)
var res = Throw(expr1).withSpan(tree.span)
if Feature.ccEnabled && !cap.isEmpty && !ctx.isAfterTyper then
// Record access to the CanThrow capabulity recovered in `cap` by wrapping
// the type of the `throw` (i.e. Nothing) in a `@requiresCapability` annotation.
Typed(res,
res = Typed(res,
TypeTree(
AnnotatedType(res.tpe,
Annotation(defn.RequiresCapabilityAnnot, cap, tree.span))))
else res
res.withNotNullInfo(expr1.notNullInfo.terminatedInfo)

def typedSeqLiteral(tree: untpd.SeqLiteral, pt: Type)(using Context): SeqLiteral = {
val elemProto = pt.stripNull().elemType match {
Expand Down