Skip to content

Inline traits for specialization in Scala 3 (v2) #20254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP inner traits
  • Loading branch information
nicolasstucki committed Apr 26, 2024
commit 54e69595873b11c95e9bdebcdc43a50fa455ac01
144 changes: 100 additions & 44 deletions compiler/src/dotty/tools/dotc/inlines/InlineTraits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Decorators.*
import dotty.tools.dotc.core.Flags.*
import dotty.tools.dotc.core.NameOps.*
import dotty.tools.dotc.core.Names.TypeName
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.core.Scopes.newScope
import dotty.tools.dotc.report
import dotty.tools.dotc.util.SrcPos
import dotty.tools.dotc.util.Spans.Span
Expand All @@ -29,70 +31,124 @@ object InlineTraits:
*/
def inlinedMemberSymbols(cls: ClassSymbol)(using Context): List[Symbol] =
assert(!cls.isInlineTrait, cls)
val parents = cls.info.parents

def parentTargs(inlinableDecl: Symbol): List[Type] =
val baseClass = inlinableDecl.owner.asClass
mixinParentTypeOf(cls, baseClass).baseType(baseClass) match
case AppliedType(_, targs) => targs
case _ => Nil

def inlinedSymbol(inlinableDecl: Symbol, traitTargs: List[Type]): Symbol =
val flags = inlinableDecl.flags | Override | Synthetic
val info = inlinableDecl.info
.substThis(inlinableDecl.owner.asClass, ThisType.raw(cls.typeRef))
.subst(inlinableDecl.owner.typeParams, traitTargs)
val privateWithin = inlinableDecl.privateWithin // TODO what should `privateWithin` be?
newSymbol(cls, inlinableDecl.name, flags, info, privateWithin, cls.span)

def needsInlinedDecl(sym: Symbol): Boolean =
sym.isTerm && !sym.isConstructor && !sym.is(ParamAccessor)
&& sym.owner.isInlineTrait

for
denot <- cls.typeRef.allMembers.toList
inlinableDecl = denot.symbol
if needsInlinedDecl(inlinableDecl)
sym = denot.symbol
if isInlinableMember(sym)
yield
inlinedSymbol(inlinableDecl, parentTargs(inlinableDecl))

val traitTargs = parentTargs(cls, sym)
if sym.isClass then inlinedSymbolClassDef(cls, sym.asClass, traitTargs)
else inlinedSymbolValOrDef(cls, sym, traitTargs)
end inlinedMemberSymbols

private def isInlinableMember(sym: Symbol)(using Context): Boolean =
(sym.isTerm || sym.isClass)
&& !sym.isConstructor && !sym.is(ParamAccessor)
&& sym.owner.isInlineTrait

private def parentTargs(cls: ClassSymbol, inlinableDecl: Symbol)(using Context): List[Type] =
val baseClass = inlinableDecl.owner.asClass
mixinParentTypeOf(cls, baseClass).baseType(baseClass) match
case AppliedType(_, targs) => targs
case _ => Nil

private def inlinedSymbolValOrDef(cls: ClassSymbol, inlinableDecl: Symbol, traitTargs: List[Type])(using Context): Symbol =
val flags = inlinableDecl.flags | Override | Synthetic
val info = inlinableDecl.info
.substThis(inlinableDecl.owner.asClass, ThisType.raw(cls.typeRef))
.subst(inlinableDecl.owner.typeParams, traitTargs)
val privateWithin = inlinableDecl.privateWithin // TODO what should `privateWithin` be?
newSymbol(cls, inlinableDecl.name, flags, info, privateWithin, cls.span)

private def inlinedSymbolClassDef(cls: ClassSymbol, inlinableDecl: ClassSymbol, traitTargs: List[Type])(using Context): ClassSymbol =
def infoFn(cls1: ClassSymbol) =
inlinableDecl.info.asInstanceOf[ClassInfo].derivedClassInfo(
prefix = cls.typeRef,
declaredParents = defn.ObjectType :: cls.thisType.select(inlinableDecl) :: Nil,
decls = newScope,
// selfInfo = ,
)

val newCls = newClassSymbol(
owner = cls,
name = inlinableDecl.name.toTypeName,
flags = inlinableDecl.flags | Synthetic,
infoFn = infoFn,
privateWithin = NoSymbol,
coord = cls.coord
)

newConstructor(
newCls,
flags = EmptyFlags,
paramNames = Nil,
paramTypes = Nil,
privateWithin = NoSymbol,
coord = newCls.coord
).entered

def inlinedDefs(cls: ClassSymbol)(using Context): List[Tree] =
def makeValOrDef(inlinedDecl: Symbol): Tree =
def makeSuperSelect(using Context) =
ctx.compilationUnit.needsInlining = true
val inlinableDecl = inlinedDecl.allOverriddenSymbols.find { sym =>
!sym.is(Deferred) && sym.owner.isInlineTrait
}.getOrElse(inlinedDecl.nextOverriddenSymbol)
val parent = mixinParentTypeOf(cls, inlinableDecl.owner.asClass)
Super(This(cls), parent.typeSymbol.name.asTypeName).select(inlinableDecl)

def rhs(using Context)(argss: List[List[Tree]]) =
if inlinedDecl.is(Deferred) then EmptyTree
else if inlinedDecl.is(Mutable) && inlinedDecl.name.isSetterName then Literal(Constant(()))
else makeSuperSelect.appliedToArgss(argss)
for
decl <- inlinableDecl.info.decls.toList
if decl.isTerm && !decl.isConstructor && !decl.is(ParamAccessor)
do
inlinedSymbolValOrDef(newCls, decl, traitTargs).entered

if inlinedDecl.is(Method) then DefDef(inlinedDecl.asTerm, rhs(using ctx.withOwner(inlinedDecl))).withSpan(cls.span)
else ValDef(inlinedDecl.asTerm, rhs(using ctx.withOwner(inlinedDecl))(Nil)).withSpan(cls.span)
newCls
end inlinedSymbolClassDef

def inlinedDefs(cls: ClassSymbol)(using Context): List[Tree] =
atPhase(ctx.phase.next) { cls.info.decls.toList }
.filter(sym => sym.is(Synthetic) && sym.nextOverriddenSymbol.maybeOwner.isInlineTrait)
.map(makeValOrDef)
.map { sym =>
if sym.isClass then inlinedClassDefs(cls, sym.asClass)
else inlinedValOrDefDefs(cls, sym)
}

private def inlinedValOrDefDefs(cls: ClassSymbol, inlinedDecl: Symbol)(using Context): Tree =
val inlinableDecl = inlinedDecl.allOverriddenSymbols.find { sym =>
!sym.is(Deferred) && sym.owner.isInlineTrait
}.getOrElse(inlinedDecl.nextOverriddenSymbol)
val parent = mixinParentTypeOf(inlinedDecl.owner.asClass, inlinableDecl.owner.asClass).typeSymbol.name.asTypeName
valOrDefDefInlineOverride(cls, inlinedDecl, parent, inlinableDecl)

private def inlinedClassDefs(cls: ClassSymbol, inlinedDecl: ClassSymbol)(using Context): Tree =
val parent = inlinedDecl.info.parents.last.typeSymbol
val members = parent.info.decls.toList.filterNot(_.is(ParamAccessor)).zip(inlinedDecl.info.decls.toList).collect {
case (overridden, decl) if decl.isTerm && !decl.isConstructor && !decl.is(ParamAccessor) =>
assert(overridden.name == decl.name, (overridden, decl)) // TODO find better wy to recover `overridden` from `decl`
val parent = mixinParentTypeOf(decl.owner.asClass, overridden.owner.asClass).typeSymbol.name.asTypeName
valOrDefDefInlineOverride(cls, decl, parent, overridden)
}
ClassDef(
inlinedDecl.asClass,
DefDef(inlinedDecl.primaryConstructor.asTerm),
body = members,
superArgs = List.empty[Tree]
).withSpan(inlinedDecl.span)

private def valOrDefDefInlineOverride(cls: ClassSymbol, decl: Symbol, parent: TypeName, overridden: Symbol)(using Context): Tree =
def rhs(argss: List[List[Tree]])(using Context) =
if decl.is(Deferred) then EmptyTree
else if decl.is(Mutable) && decl.name.isSetterName then Literal(Constant(()))
else
ctx.compilationUnit.needsInlining = true
Super(This(ctx.owner.owner.asClass), parent).select(overridden).appliedToArgss(argss)

end inlinedDefs
if decl.is(Method) then DefDef(decl.asTerm, rhs(_)(using ctx.withOwner(decl))).withSpan(cls.span)
else ValDef(decl.asTerm, rhs(Nil)(using ctx.withOwner(decl))).withSpan(cls.span)

private def mixinParentTypeOf(cls: ClassSymbol, baseClass: ClassSymbol)(using Context): Type =
cls.info.parents.findLast(parent => parent.typeSymbol.derivesFrom(baseClass)).get

/** Register inline members RHS in `@bodyAnnotation`s */
def registerInlineTraitInfo(cls: ClassSymbol, stats: List[Tree])(using Context): Unit =
def registerInlineTraitInfo(stats: List[Tree])(using Context): Unit =
for stat <- stats do
stat match
case stat: ValOrDefDef if !stat.symbol.is(Inline) && !stat.symbol.is(Deferred) =>
// TODO? val rhsToInline = PrepareInlineable.wrapRHS(stat, stat.tpt, stat.rhs)
PrepareInlineable.registerInlineInfo(stat.symbol, stat.rhs/*TODO? rhsToInline*/)
case TypeDef(_, rhs: Template) =>
registerInlineTraitInfo(rhs.body)
case _ =>

/** Checks if members are supported in inline traits */
Expand All @@ -105,7 +161,7 @@ object InlineTraits:
else if sym.is(Private) then report.error(em"Implementation restriction: private ${sym.kindString} cannot be defined in inline traits", stat.srcPos)
else () // Ok
case stat: TypeDef =>
if sym.isClass then report.error(em"Implementation restriction: ${sym.kindString} cannot be defined in inline traits", stat.srcPos)
if sym.isClass && !sym.is(Trait) then report.error(em"Implementation restriction: ${sym.kindString} cannot be defined in inline traits", stat.srcPos)
else () // OK
case _: Import =>
report.error(em"Implementation restriction: import cannot be defined in inline traits", stat.srcPos)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/inlines/Inlines.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object Inlines:
/** Can a call to method `meth` be inlined? */
def isInlineable(meth: Symbol)(using Context): Boolean =
def isSuperCallInInlineTraitGeneratedMethod =
meth.maybeOwner.isInlineTrait
meth.ownersIterator.exists(_.isInlineTrait)
&& ctx.owner.is(Synthetic)
&& ctx.owner.owner.isClass
&& ctx.owner.overriddenSymbol(meth.owner.asClass) == meth
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3924,7 +3924,7 @@ object Parsers {
}
}

/** TmplDef ::= ([‘case’] ‘class’ | [inline’] ‘trait’) ClassDef
/** TmplDef ::= ([‘case’] ‘class’ | [inline’] ‘trait’) ClassDef
* | [‘case’] ‘object’ ObjectDef
* | ‘enum’ EnumDef
* | ‘given’ GivenDef
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DeferInlineTraits extends MiniPhase with SymTransformer:
&& !sym.is(ParamAccessor)
&& !sym.is(Private)
&& !sym.isLocalDummy
&& sym.owner.isInlineTrait
&& sym.ownersIterator.exists(_.isInlineTrait)


object DeferInlineTraits:
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/ExplicitOuter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class ExplicitOuter extends MiniPhase with InfoTransformer { thisPhase =>

/** Add outer accessors if a class always needs an outer pointer */
override def transformInfo(tp: Type, sym: Symbol)(using Context): Type = tp match {
case tp @ ClassInfo(_, cls, _, decls, _) if needsOuterAlways(cls) =>
case tp @ ClassInfo(_, cls, _, decls, _) if needsOuterAlways(cls) && !cls.maybeOwner.isInlineTrait =>
val newDecls = decls.cloneScope
newOuterAccessors(cls).foreach(newDecls.enter)
tp.derivedClassInfo(decls = newDecls)
Expand Down Expand Up @@ -77,7 +77,8 @@ class ExplicitOuter extends MiniPhase with InfoTransformer { thisPhase =>
ensureOuterAccessors(cls)

val clsHasOuter = hasOuter(cls)
if (clsHasOuter || cls.mixins.exists(needsOuterIfReferenced)) {
if cls.maybeOwner.isInlineTrait then impl
else if (clsHasOuter || cls.mixins.exists(needsOuterIfReferenced)) {
val newDefs = new mutable.ListBuffer[Tree]

if (clsHasOuter)
Expand Down
33 changes: 24 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/TraitInlining.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import dotty.tools.dotc.inlines.InlineTraits.*
import dotty.tools.dotc.ast.TreeMapWithImplicits
import dotty.tools.dotc.core.NameOps.*
import dotty.tools.dotc.core.Decorators.*
import dotty.tools.dotc.core.DenotTransformers.InfoTransformer
import dotty.tools.dotc.core.DenotTransformers.DenotTransformer
import dotty.tools.dotc.core.Denotations.SingleDenotation
import dotty.tools.dotc.core.SymDenotations.SymDenotation
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.staging.StagingLevel
import dotty.tools.dotc.core.Constants.*
Expand All @@ -22,7 +24,7 @@ import scala.collection.mutable.ListBuffer
import javax.xml.transform.Templates

/** TODO */
class TraitInlining extends MacroTransform, InfoTransformer {
class TraitInlining extends MacroTransform, DenotTransformer {
self =>

import tpd.*
Expand Down Expand Up @@ -56,15 +58,28 @@ class TraitInlining extends MacroTransform, InfoTransformer {
super.transform(tree)
}

def transformInfo(tp: Type, sym: Symbol)(using Context): Type = {
tp match
case tp @ ClassInfo(_, cls, _, decls, _) if needsTraitInlining(sym.asClass) =>
val newDecls = decls.cloneScope
inlinedMemberSymbols(sym.asClass).foreach(newDecls.enter)
tp.derivedClassInfo(decls = newDecls)
def transform(ref: SingleDenotation)(using Context): SingleDenotation = {
val sym = ref.symbol
ref match {
case ref: SymDenotation if sym.isClass && !sym.is(Module) && sym.maybeOwner.isInlineTrait =>
val newName =
if sym.is(Module) then (sym.name.toString + "inline$trait$").toTypeName // TODO use NameKinds
else (sym.name.toString + "$inline$trait").toTypeName // TODO use NameKinds
ref.copySymDenotation(name = newName)
case ref: SymDenotation =>
ref.info match
case tp @ ClassInfo(_, cls, _, decls, _) if needsTraitInlining(sym.asClass) =>
val newDecls = decls.cloneScope
inlinedMemberSymbols(sym.asClass).foreach(newDecls.enter)
val newInfo = tp.derivedClassInfo(decls = newDecls)
ref.copySymDenotation(info = newInfo).copyCaches(ref, ctx.phase.next)
case _ =>
ref
case _ =>
tp
ref
}
}

}

object TraitInlining:
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2816,7 +2816,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if !ctx.isAfterTyper then
if cls.isInlineTrait then
InlineTraits.checkValidInlineTraitMember(body1)
InlineTraits.registerInlineTraitInfo(cls, body1)
InlineTraits.registerInlineTraitInfo(body1)
else
InlineTraits.adaptNoInit(cls, parents1)
// body1 = InlineTraits.inlinedMembers(cls, parents1) ::: body1
Expand Down
13 changes: 13 additions & 0 deletions tests/disabled/pos/inline-trait-3-trait-with-params.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
inline trait A[T](a: T):
def f: T = a
def f(x: T): T = x
def f[U <: T](x: U, y: T): T = x
end A

class B extends A[Int](3):
/*
<generated> override def f: Int = ???
<generated> override def f(x: Int): Int = ???
<generated> override def f[U <: Int](x: U, y: Int): Int = ???
*/
end B
20 changes: 20 additions & 0 deletions tests/disabled/pos/inline-trait-4-inner-class.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
inline trait Options[+T]:
sealed trait Option:
def get: T
def isEmpty: Boolean

class Some(x: T) extends Option:
def get: T = x
def isEmpty: Boolean = false

object None extends Option:
def get: T = throw new NoSuchElementException("None.get")
def isEmpty: Boolean = true
end Options

object IntOptions extends Options[Int]
import IntOptions._

val o1: Option = Some(1) // specialized
val o2: Option = None
val x1: Int = o1.get // no unboxing
10 changes: 10 additions & 0 deletions tests/disabled/pos/inline-trait-body-class-abstract.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
inline trait A:
class InnerA:
def foo(): Int
def bar = foo() + 1

class B extends A:
class InnerB extends InnerA:
def foo(): Int = -23

def f = InnerB().bar
5 changes: 5 additions & 0 deletions tests/disabled/pos/inline-trait-body-class-case.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
inline trait A:
case class Inner(val x: Int)

class B extends A:
def f = Inner(17).x
6 changes: 6 additions & 0 deletions tests/disabled/pos/inline-trait-body-class-enum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
inline trait A:
enum Inner:
case A, B, C

class B extends A:
def f = Inner.B
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
inline trait A:
class Inner extends Trait[Int]:
val x = 1

inline trait Trait[T]:
def f(x: T): T = x

class B extends A:
val inner = Inner()
def x = inner.x
def f = inner.f(x)
6 changes: 6 additions & 0 deletions tests/disabled/pos/inline-trait-body-class-generic.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
inline trait A:
class Inner[T <: Int]:
val x: T = 1

class B extends A:
def f = Inner[Int]().x
6 changes: 6 additions & 0 deletions tests/disabled/pos/inline-trait-body-class-object.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
inline trait A[T]:
object Inner:
val x: T = ???

class B extends A[Int]:
def i: Int = Inner.x
5 changes: 5 additions & 0 deletions tests/disabled/pos/inline-trait-body-class-params.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
inline trait A:
class Inner(val x: Int)

class B extends A:
def f = Inner(17).x
Loading