Skip to content

Commit 2bd0c96

Browse files
committed
Handle dependent refinements in class parents
1 parent 0d4dcd0 commit 2bd0c96

File tree

12 files changed

+119
-53
lines changed

12 files changed

+119
-53
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,9 @@ object desugar {
902902
}
903903
if mods.isAllOf(Given | Inline | Transparent) then
904904
report.error("inline given instances cannot be trasparent", cdef)
905-
val classMods = if mods.is(Given) then mods &~ (Inline | Transparent) | Synthetic else mods
905+
var classMods = if mods.is(Given) then mods &~ (Inline | Transparent) | Synthetic else mods
906+
if vparamAccessors.exists(_.mods.is(Tracked)) then
907+
classMods |= Dependent
906908
cpy.TypeDef(cdef: TypeDef)(
907909
name = className,
908910
rhs = cpy.Template(impl)(constr, parents1, clsDerived, self1,

compiler/src/dotty/tools/dotc/core/Flags.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ object Flags {
377377
/** Symbol cannot be found as a member during typer */
378378
val (Invisible @ _, _, _) = newFlags(45, "<invisible>")
379379

380-
val (Tracked @ _, _, _) = newFlags(46, "tracked")
380+
val (Tracked @ _, _, Dependent @ _) = newFlags(46, "tracked", "dependent")
381381

382382
// ------------ Flags following this one are not pickled ----------------------------------
383383

@@ -470,7 +470,7 @@ object Flags {
470470
Scala2SpecialFlags, MutableOrOpen, Opaque, Touched, JavaStatic,
471471
OuterOrCovariant, LabelOrContravariant, CaseAccessor,
472472
Extension, NonMember, Implicit, Given, Permanent, Synthetic, Exported,
473-
SuperParamAliasOrScala2x, Inline, Macro, ConstructorProxy, Invisible)
473+
SuperParamAliasOrScala2x, Inline, Macro, ConstructorProxy, Invisible, Tracked)
474474

475475
/** Flags that are not (re)set when completing the denotation, or, if symbol is
476476
* a top-level class or object, when completing the denotation once the class

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Contexts.ctx
1010
import dotty.tools.dotc.reporting.trace
1111
import config.Feature.migrateTo3
1212
import config.Printers.*
13+
import transform.TypeUtils.stripRefinement
1314

1415
trait PatternTypeConstrainer { self: TypeComparer =>
1516

@@ -88,11 +89,6 @@ trait PatternTypeConstrainer { self: TypeComparer =>
8889
}
8990
}
9091

91-
def stripRefinement(tp: Type): Type = tp match {
92-
case tp: RefinedOrRecType => stripRefinement(tp.parent)
93-
case tp => tp
94-
}
95-
9692
def tryConstrainSimplePatternType(pat: Type, scrut: Type) = {
9793
val patCls = pat.classSymbol
9894
val scrCls = scrut.classSymbol

compiler/src/dotty/tools/dotc/core/Scopes.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import Denotations.*
1717
import printing.Texts.*
1818
import printing.Printer
1919
import SymDenotations.NoDenotation
20+
import util.common.alwaysFalse
2021

2122
import collection.mutable
2223
import scala.compiletime.uninitialized
@@ -94,15 +95,13 @@ object Scopes {
9495
def foreach[U](f: Symbol => U)(using Context): Unit = toList.foreach(f)
9596

9697
/** Selects all Symbols of this Scope which satisfy a predicate. */
97-
def filter(p: Symbol => Boolean)(using Context): List[Symbol] = {
98+
def filter(p: Symbol => Boolean, stopAt: Symbol => Boolean = alwaysFalse)(using Context): List[Symbol] = {
9899
ensureComplete()
99100
var syms: List[Symbol] = Nil
100101
var e = lastEntry
101-
while ((e != null) && e.owner == this) {
102-
val sym = e.sym
103-
if (p(sym)) syms = sym :: syms
102+
while e != null && e.owner == this && !stopAt(e.sym) do
103+
if p(e.sym) then syms = e.sym :: syms
104104
e = e.prev
105-
}
106105
syms
107106
}
108107

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2380,7 +2380,7 @@ object SymDenotations {
23802380
* Both getters and setters are returned in this list.
23812381
*/
23822382
def paramAccessors(using Context): List[Symbol] =
2383-
unforcedDecls.filter(_.is(ParamAccessor))
2383+
unforcedDecls.filter(_.is(ParamAccessor))//, stopAt = sym => sym.is(Method, butNot = ParamAccessor))
23842384

23852385
/** The term parameter getters of this class. */
23862386
def paramGetters(using Context): List[Symbol] =

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import config.Printers.typr
1818
import config.Feature
1919
import util.SrcPos
2020
import reporting.*
21+
import transform.TypeUtils.stripRefinement
2122
import NameKinds.WildcardParamName
2223

2324
object PostTyper {
@@ -411,8 +412,12 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
411412
// Constructor parameters are in scope when typing a parent.
412413
// While they can safely appear in a parent tree, to preserve
413414
// soundness we need to ensure they don't appear in a parent
414-
// type (#16270).
415-
val illegalRefs = parent.tpe.namedPartsWith(p => p.symbol.is(ParamAccessor) && (p.symbol.owner eq sym))
415+
// type (#16270). We can strip any refinement of a parent type since
416+
// these refinements are split off from the parent type constructor
417+
// application `parent` in Namer and don't show up as parent types
418+
// of the class.
419+
val illegalRefs = parent.tpe.stripRefinement.namedPartsWith:
420+
p => p.symbol.is(ParamAccessor) && (p.symbol.owner eq sym)
416421
if illegalRefs.nonEmpty then
417422
report.error(
418423
em"The type of a class parent cannot refer to constructor parameters, but ${parent.tpe} refers to ${illegalRefs.map(_.name.show).mkString(",")}", parent.srcPos)

compiler/src/dotty/tools/dotc/transform/TypeUtils.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ import TypeErasure.ErasedValueType
77
import Types.*, Contexts.*, Symbols.*, Flags.*, Decorators.*
88
import Names.Name
99

10-
object TypeUtils {
10+
object TypeUtils:
1111
/** A decorator that provides methods on types
1212
* that are needed in the transformer pipeline.
1313
*/
14-
extension (self: Type) {
14+
extension (self: Type)
1515

1616
def isErasedValueType(using Context): Boolean =
1717
self.isInstanceOf[ErasedValueType]
@@ -104,5 +104,11 @@ object TypeUtils {
104104
case _ =>
105105
val cls = self.underlyingClassRef(refinementOK = false).typeSymbol
106106
cls.isTransparentClass && (!traitOnly || cls.is(Trait))
107-
}
108-
}
107+
108+
/** Strip all outer refinements off this type */
109+
def stripRefinement: Type = self match
110+
case self: RefinedOrRecType => self.parent.stripRefinement
111+
case seld => self
112+
113+
end TypeUtils
114+

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ object Checking {
198198
* and that the instance conforms to the self type of the created class.
199199
*/
200200
def checkInstantiable(tp: Type, srcTp: Type, pos: SrcPos)(using Context): Unit =
201-
tp.underlyingClassRef(refinementOK = false) match
201+
tp.underlyingClassRef(refinementOK = true) match
202202
case tref: TypeRef =>
203203
val cls = tref.symbol
204204
if (cls.isOneOf(AbstractOrTrait)) {

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 58 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,8 +1500,13 @@ class Namer { typer: Typer =>
15001500
core match
15011501
case Select(New(tpt), nme.CONSTRUCTOR) =>
15021502
val targs1 = targs map (typedAheadType(_))
1503-
val ptype = typedAheadType(tpt).tpe appliedTo targs1.tpes
1504-
if (ptype.typeParams.isEmpty) ptype
1503+
val ptype = typedAheadType(tpt).tpe.appliedTo(targs1.tpes)
1504+
if ptype.typeParams.isEmpty
1505+
//&& !ptype.dealias.typeSymbol.primaryConstructor.info.finalResultType.isInstanceOf[RefinedType]
1506+
&& !ptype.dealias.typeSymbol.is(Dependent)
1507+
|| ctx.erasedTypes
1508+
then
1509+
ptype
15051510
else
15061511
if (denot.is(ModuleClass) && denot.sourceModule.isOneOf(GivenOrImplicit))
15071512
missingType(denot.symbol, "parent ")(using creationContext)
@@ -1539,7 +1544,7 @@ class Namer { typer: Typer =>
15391544
if (cls.isRefinementClass) ptype
15401545
else {
15411546
val pt = checkClassType(ptype, parent.srcPos,
1542-
traitReq = parent ne parents.head, stablePrefixReq = true)
1547+
traitReq = parent ne parents.head, stablePrefixReq = true, refinementOK = true)
15431548
if (pt.derivesFrom(cls)) {
15441549
val addendum = parent match {
15451550
case Select(qual: Super, _) if Feature.migrateTo3 =>
@@ -1605,14 +1610,52 @@ class Namer { typer: Typer =>
16051610
completeConstructor(denot)
16061611
denot.info = tempInfo.nn
16071612

1608-
val parentTypes = defn.adjustForTuple(cls, cls.typeParams,
1609-
defn.adjustForBoxedUnit(cls,
1610-
addUsingTraits(
1611-
ensureFirstIsClass(cls, parents.map(checkedParentType(_)))
1612-
)
1613-
)
1614-
)
1615-
typr.println(i"completing $denot, parents = $parents%, %, parentTypes = $parentTypes%, %")
1613+
/** The refinements coming from all parent class constructor applications */
1614+
val parentRefinements = mutable.LinkedHashMap[Name, Type]()
1615+
1616+
/** Split refinements off parent type and add them to `parentRefinements` */
1617+
def separateRefinements(tp: Type): Type = tp match
1618+
case RefinedType(tp1, rname, rinfo) =>
1619+
try separateRefinements(tp1)
1620+
finally
1621+
parentRefinements(rname) = parentRefinements.get(rname) match
1622+
case Some(tp) => tp & rinfo
1623+
case None => rinfo
1624+
case tp => tp
1625+
1626+
/** Add all parent refinements to the result type of the `info` of
1627+
* the class constructor. Parent refinements refer to parameter accessors
1628+
* in the current class. These have to be mapped to the paramRefs of the
1629+
* constructor info.
1630+
* @param info The (remaining part) of the constructor info
1631+
* @param nameToParamRef The map from parameter names to paramRefs of
1632+
* previously encountered parts of `info`.
1633+
*/
1634+
def integrateParentRefinements(info: Type, nameToParamRef: Map[Name, Type]): Type = info match
1635+
case info: MethodOrPoly =>
1636+
info.derivedLambdaType(resType =
1637+
integrateParentRefinements(info.resType,
1638+
nameToParamRef ++ info.paramNames.zip(info.paramRefs)))
1639+
case _ =>
1640+
val mapParams = new TypeMap:
1641+
def apply(t: Type) = t match
1642+
case t: TermRef if t.symbol.is(ParamAccessor) && t.symbol.owner == cls =>
1643+
nameToParamRef(t.name)
1644+
case _ =>
1645+
mapOver(t)
1646+
parentRefinements.foldLeft(info): (info, refinement) =>
1647+
val (rname, rinfo) = refinement
1648+
RefinedType(info, rname, mapParams(rinfo))
1649+
1650+
val parentTypes =
1651+
defn.adjustForTuple(cls, cls.typeParams,
1652+
defn.adjustForBoxedUnit(cls,
1653+
addUsingTraits(
1654+
ensureFirstIsClass(cls, parents.map(checkedParentType(_)))
1655+
))).map(separateRefinements)
1656+
1657+
typr.println(i"completing $denot, parents = $parents%, %, stripped parent types = $parentTypes%, %")
1658+
typr.println(i"constr type = ${cls.primaryConstructor.infoOrCompleter}, refinements = ${parentRefinements.toList}")
16161659

16171660
if (impl.derived.nonEmpty) {
16181661
val (derivingClass, derivePos) = original.removeAttachment(desugar.DerivingCompanion) match {
@@ -1627,6 +1670,10 @@ class Namer { typer: Typer =>
16271670
denot.info = tempInfo.nn.finalized(parentTypes)
16281671
tempInfo = null // The temporary info can now be garbage-collected
16291672

1673+
if parentRefinements.nonEmpty then
1674+
val constr = cls.primaryConstructor
1675+
constr.info = integrateParentRefinements(constr.info, Map())
1676+
cls.setFlag(Dependent)
16301677
Checking.checkWellFormed(cls)
16311678
if (isDerivedValueClass(cls)) cls.setFlag(Final)
16321679
cls.info = avoidPrivateLeaks(cls)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4360,7 +4360,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
43604360
cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName)
43614361
case _ =>
43624362
errorTree(tree, em"cannot convert from $tree to an instance creation expression")
4363-
val tycon = tree.tpe.widen.finalResultType.underlyingClassRef(refinementOK = false)
4363+
val tycon = tree.tpe.widen.finalResultType.underlyingClassRef(refinementOK = true)
43644364
typed(
43654365
untpd.Select(
43664366
untpd.New(untpd.TypedSplice(tpt.withType(tycon))),

tests/neg/i3964.scala

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,3 @@ object Test1:
1010
trait Foo { val x: Animal }
1111
val foo: Foo { val x: Cat } = new Foo { val x = new Cat } // error, but should work
1212

13-
object Test2:
14-
abstract class Bar(tracked val x: Animal)
15-
val b = new Bar(new Cat)
16-
val bar: Bar { val x: Cat } = new Bar(new Cat) // ok
17-
18-
trait Foo(tracked val x: Animal)
19-
val foo: Foo { val x: Cat } = new Foo(new Cat) // ok
20-
21-
object Test3:
22-
trait Vec(tracked val size: Int)
23-
class Vec8 extends Vec(8)
24-
25-
abstract class Lst(tracked val size: Int)
26-
class Lst8 extends Lst(8)
27-
28-
val v8a: Vec { val size: 8 } = new Vec8 // error, but should work
29-
val v8b: Vec { val size: 8 } = new Vec(8) // ok
30-
31-
val l8a: Lst { val size: 8 } = new Lst8 // error, but should work
32-
val l8b: Lst { val size: 8 } = new Lst(8) // ok

tests/pos/i3964.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
trait Animal
2+
class Dog extends Animal
3+
class Cat extends Animal
4+
5+
object Test2:
6+
class Bar(tracked val x: Animal)
7+
val b = new Bar(new Cat)
8+
val bar: Bar { val x: Cat } = new Bar(new Cat) // ok
9+
10+
trait Foo(tracked val x: Animal)
11+
val foo: Foo { val x: Cat } = new Foo(new Cat) {} // ok
12+
13+
object Test3:
14+
trait Vec(tracked val size: Int)
15+
class Vec8 extends Vec(8)
16+
17+
abstract class Lst(tracked val size: Int)
18+
class Lst8 extends Lst(8)
19+
20+
val v8a: Vec { val size: 8 } = new Vec8
21+
val v8b: Vec { val size: 8 } = new Vec(8) {}
22+
23+
val l8a: Lst { val size: 8 } = new Lst8
24+
val l8b: Lst { val size: 8 } = new Lst(8) {}
25+
26+
class VecN(tracked val n: Int) extends Vec(n)
27+
class Vec9 extends VecN(9)
28+
val v9a = VecN(9)
29+
val _: Vec { val size: 9 } = v9a
30+
val v9b = Vec9()
31+
val _: Vec { val size: 9 } = v9b

0 commit comments

Comments
 (0)