Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ class Definitions {
@tu lazy val CompiletimeTesting_ErrorKind: Symbol = ctx.requiredModule("scala.compiletime.testing.ErrorKind")
@tu lazy val CompiletimeTesting_ErrorKind_Parser: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Parser")
@tu lazy val CompiletimeTesting_ErrorKind_Typer: Symbol = CompiletimeTesting_ErrorKind.requiredMethod("Typer")
@tu lazy val CompiletimeOpsPackageObject: Symbol = ctx.requiredModule("scala.compiletime.ops.package")
@tu lazy val CompiletimeOpsPackageObjectInt: Symbol = ctx.requiredModule("scala.compiletime.ops.package.int")
@tu lazy val CompiletimeOpsPackageObjectString: Symbol = ctx.requiredModule("scala.compiletime.ops.package.string")
@tu lazy val CompiletimeOpsPackageObjectBoolean: Symbol = ctx.requiredModule("scala.compiletime.ops.package.boolean")

/** The `scalaShadowing` package is used to safely modify classes and
* objects in scala so that they can be used from dotty. They will
Expand Down Expand Up @@ -898,6 +902,26 @@ class Definitions {
final def isCompiletime_S(sym: Symbol)(implicit ctx: Context): Boolean =
sym.name == tpnme.S && sym.owner == CompiletimePackageObject.moduleClass

private val compiletimePackageTypes: Set[Name] = Set(tpnme.Equals, tpnme.NotEquals)
private val compiletimePackageIntTypes: Set[Name] = Set(
tpnme.Plus, tpnme.Minus, tpnme.Times, tpnme.Div, tpnme.Mod,
tpnme.Lt, tpnme.Gt, tpnme.Ge, tpnme.Le,
tpnme.Abs, tpnme.Negate, tpnme.Min, tpnme.Max, tpnme.ToString,
)
private val compiletimePackageBooleanTypes: Set[Name] = Set(tpnme.Not, tpnme.Xor, tpnme.And, tpnme.Or)
private val compiletimePackageStringTypes: Set[Name] = Set(tpnme.Plus)

final def isCompiletimeAppliedType(sym: Symbol)(implicit ctx: Context): Boolean = {
def isOpsPackageObjectAppliedType: Boolean =
sym.owner == CompiletimeOpsPackageObject.moduleClass && compiletimePackageTypes.contains(sym.name) ||
sym.owner == CompiletimeOpsPackageObjectInt.moduleClass && compiletimePackageIntTypes.contains(sym.name) ||
sym.owner == CompiletimeOpsPackageObjectBoolean.moduleClass && compiletimePackageBooleanTypes.contains(sym.name) ||
sym.owner == CompiletimeOpsPackageObjectString.moduleClass && compiletimePackageStringTypes.contains(sym.name)

sym.isType && (isCompiletime_S(sym) || isOpsPackageObjectAppliedType)
}


// ----- Symbol sets ---------------------------------------------------

@tu lazy val AbstractFunctionType: Array[TypeRef] = mkArityArray("scala.runtime.AbstractFunction", MaxImplementedFunctionArity, 0)
Expand Down
23 changes: 22 additions & 1 deletion compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,34 @@ object StdNames {
final val Product: N = "Product"
final val PartialFunction: N = "PartialFunction"
final val PrefixType: N = "PrefixType"
final val S: N = "S"
final val Serializable: N = "Serializable"
final val Singleton: N = "Singleton"
final val Throwable: N = "Throwable"
final val IOOBException: N = "IndexOutOfBoundsException"
final val FunctionXXL: N = "FunctionXXL"

final val Abs: N = "Abs"
final val And: N = "&&"
final val Div: N = "/"
final val Equals: N = "=="
final val Ge: N = ">="
final val Gt: N = ">"
final val Le: N = "<="
final val Lt: N = "<"
final val Max: N = "Max"
final val Min: N = "Min"
final val Minus: N = "-"
final val Mod: N = "%"
final val Negate: N = "Negate"
final val Not: N = "!"
final val NotEquals: N = "!="
final val Or: N = "||"
final val Plus: N = "+"
final val S: N = "S"
final val Times: N = "*"
final val ToString: N = "ToString"
final val Xor: N = "^"

final val ClassfileAnnotation: N = "ClassfileAnnotation"
final val ClassManifest: N = "ClassManifest"
final val Enum: N = "Enum"
Expand Down
8 changes: 6 additions & 2 deletions compiler/src/dotty/tools/dotc/core/TypeApplications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,16 @@ class TypeApplications(val self: Type) extends AnyVal {
// just eta-reduction (ignoring variance annotations).
// See i2201*.scala for examples where more aggressive
// reduction would break type inference.
dealiased.paramRefs == dealiasedArgs
dealiased.paramRefs == dealiasedArgs ||
defn.isCompiletimeAppliedType(tyconBody.typeSymbol)
case _ => false
}
}
if ((dealiased eq stripped) || followAlias)
try dealiased.instantiate(args)
try {
val instantiated = dealiased.instantiate(args)
if (followAlias) instantiated.normalized else instantiated
}
catch { case ex: IndexOutOfBoundsException => AppliedType(self, args) }
else AppliedType(self, args)
}
Expand Down
18 changes: 15 additions & 3 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
compareLower(bounds(param2), tyconIsTypeRef = false)
case tycon2: TypeRef =>
isMatchingApply(tp1) ||
defn.isCompiletime_S(tycon2.symbol) && compareS(tp2, tp1, fromBelow = true) || {
defn.isCompiletimeAppliedType(tycon2.symbol) && compareCompiletimeAppliedType(tp2, tp1, fromBelow = true) || {
tycon2.info match {
case info2: TypeBounds =>
compareLower(info2, tyconIsTypeRef = true)
Expand Down Expand Up @@ -1005,7 +1005,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
case tycon1: TypeRef =>
val sym = tycon1.symbol
!sym.isClass && {
defn.isCompiletime_S(sym) && compareS(tp1, tp2, fromBelow = false) ||
defn.isCompiletimeAppliedType(sym) && compareCompiletimeAppliedType(tp1, tp2, fromBelow = false) ||
recur(tp1.superType, tp2) ||
tryLiftedToThis1
}
Expand All @@ -1015,7 +1015,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
false
}

/** Compare `tp` of form `S[arg]` with `other`, via ">:>` if fromBelow is true, "<:<" otherwise.
/** Compare `tp` of form `S[arg]` with `other`, via ">:>" if fromBelow is true, "<:<" otherwise.
* If `arg` is a Nat constant `n`, proceed with comparing `n + 1` and `other`.
* Otherwise, if `other` is a Nat constant `n`, proceed with comparing `arg` and `n - 1`.
*/
Expand All @@ -1037,6 +1037,18 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
case _ => false
}

/** Compare `tp` of form `tycon[...args]`, where `tycon` is a scala.compiletime type,
* with `other` via ">:>" if fromBelow is true, "<:<" otherwise.
* Delegates to compareS if `tycon` is scala.compiletime.S. Otherwise, constant folds if possible.
*/
def compareCompiletimeAppliedType(tp: AppliedType, other: Type, fromBelow: Boolean): Boolean = {
if (defn.isCompiletime_S(tp.tycon.typeSymbol)) compareS(tp, other, fromBelow)
else {
val folded = tp.tryCompiletimeConstantFold
if (fromBelow) recur(other, folded) else recur(folded, other)
}
}

/** Like tp1 <:< tp2, but returns false immediately if we know that
* the case was covered previously during subtyping.
*/
Expand Down
98 changes: 88 additions & 10 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3595,19 +3595,97 @@ object Types {
case _ =>
NoType
}
if (defn.isCompiletime_S(tycon.symbol) && args.length == 1)
trace(i"normalize S $this", typr, show = true) {
args.head.normalized match {
case ConstantType(Constant(n: Int)) if n >= 0 && n < Int.MaxValue =>
ConstantType(Constant(n + 1))
case none => tryMatchAlias
}
}
else tryMatchAlias

tryCompiletimeConstantFold.orElse(tryMatchAlias)

case _ =>
NoType
}

def tryCompiletimeConstantFold(implicit ctx: Context): Type = tycon match {
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
def constValue(tp: Type): Option[Any] = tp match {
case ConstantType(Constant(n)) => Some(n)
case _ => None
}

def boolValue(tp: Type): Option[Boolean] = tp match {
case ConstantType(Constant(n: Boolean)) => Some(n)
case _ => None
}

def intValue(tp: Type): Option[Int] = tp match {
case ConstantType(Constant(n: Int)) => Some(n)
case _ => None
}

def stringValue(tp: Type): Option[String] = tp match {
case ConstantType(Constant(n: String)) => Some(n)
case _ => None
}

def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)

def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
extractor(args.head.normalized).map(a => ConstantType(Constant(op(a))))

def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
for {
a <- extractor(args.head.normalized)
b <- extractor(args.tail.head.normalized)
} yield ConstantType(Constant(op(a, b)))

trace(i"compiletime constant fold $this", typr, show = true) {
val name = tycon.symbol.name
val owner = tycon.symbol.owner
val nArgs = args.length
val constantType =
if (owner == defn.CompiletimePackageObject.moduleClass) name match {
case tpnme.S if nArgs == 1 => constantFold1(natValue, _ + 1)
case _ => None
} else if (owner == defn.CompiletimeOpsPackageObject.moduleClass) name match {
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
case _ => None
} else if (owner == defn.CompiletimeOpsPackageObjectInt.moduleClass) name match {
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => -x)
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
case tpnme.Times if nArgs == 2 => constantFold2(intValue, _ * _)
case tpnme.Div if nArgs == 2 => constantFold2(intValue, {
case (_, 0) => throw new TypeError("Division by 0")
case (a, b) => a / b
})
case tpnme.Mod if nArgs == 2 => constantFold2(intValue, {
case (_, 0) => throw new TypeError("Modulo by 0")
case (a, b) => a % b
})
case tpnme.Lt if nArgs == 2 => constantFold2(intValue, _ < _)
case tpnme.Gt if nArgs == 2 => constantFold2(intValue, _ > _)
case tpnme.Ge if nArgs == 2 => constantFold2(intValue, _ >= _)
case tpnme.Le if nArgs == 2 => constantFold2(intValue, _ <= _)
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
case _ => None
} else if (owner == defn.CompiletimeOpsPackageObjectString.moduleClass) name match {
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
case _ => None
} else if (owner == defn.CompiletimeOpsPackageObjectBoolean.moduleClass) name match {
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => !x)
case tpnme.And if nArgs == 2 => constantFold2(boolValue, _ && _)
case tpnme.Or if nArgs == 2 => constantFold2(boolValue, _ || _)
case tpnme.Xor if nArgs == 2 => constantFold2(boolValue, _ ^ _)
case _ => None
} else None

constantType.getOrElse(NoType)
}

case _ => NoType
}

def lowerBound(implicit ctx: Context): Type = tycon.stripTypeVar match {
case tycon: TypeRef =>
tycon.info match {
Expand Down Expand Up @@ -3974,7 +4052,7 @@ object Types {
myReduced =
trace(i"reduce match type $this $hashCode", typr, show = true) {
try
typeComparer.matchCases(scrutinee, cases)(trackingCtx)
typeComparer.matchCases(scrutinee.normalized, cases)(trackingCtx)
catch {
case ex: Throwable =>
handleRecursive("reduce type ", i"$scrutinee match ...", ex)
Expand Down
38 changes: 38 additions & 0 deletions library/src/scala/compiletime/ops/package.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package scala.compiletime

import scala.annotation.infix

package object ops {
@infix type ==[X <: AnyVal, Y <: AnyVal] <: Boolean
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe == and != should be split as well. Either everything is split according to the supertype or non of it is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved them into scala.compiletime.ops.any in 8c117c1.

The alternative was to duplicate them into each subpackage for each supported type, and then duplicate the constant folding code... which felt like a lot of duplication. Seeing that == and != are defined on Any, I think this solution makes the most sense. It also emphasizes that equality is between to Any values, and that 1 == "1" will return false.

@infix type !=[X <: AnyVal, Y <: AnyVal] <: Boolean

object string {
@infix type +[X <: String, Y <: String] <: String
}

object int {
@infix type +[X <: Int, Y <: Int] <: Int
@infix type -[X <: Int, Y <: Int] <: Int
@infix type *[X <: Int, Y <: Int] <: Int
@infix type /[X <: Int, Y <: Int] <: Int
@infix type %[X <: Int, Y <: Int] <: Int

@infix type <[X <: Int, Y <: Int] <: Boolean
@infix type >[X <: Int, Y <: Int] <: Boolean
@infix type >=[X <: Int, Y <: Int] <: Boolean
@infix type <=[X <: Int, Y <: Int] <: Boolean

type Abs[X <: Int] <: Int
type Negate[X <: Int] <: Int
type Min[X <: Int, Y <: Int] <: Int
type Max[X <: Int, Y <: Int] <: Int
type ToString[X <: Int] <: String
}

object boolean {
type ![X <: Boolean] <: Boolean
@infix type ^[X <: Boolean, Y <: Boolean] <: Boolean
@infix type &&[X <: Boolean, Y <: Boolean] <: Boolean
@infix type ||[X <: Boolean, Y <: Boolean] <: Boolean
}
}
23 changes: 23 additions & 0 deletions tests/neg/singleton-ops-boolean.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import scala.compiletime.ops.boolean._

object Test {
val t0: ![true] = false
val t1: ![false] = true
val t2: ![true] = true // error
val t3: ![false] = false // error

val t4: true && true = true
val t5: true && false = false
val t6: false && true = true // error
val t7: false && false = true // error

val t8: true ^ true = false
val t9: false ^ true = true
val t10: false ^ false = true // error
val t11: true ^ false = false // error

val t12: true || true = true
val t13: true || false = true
val t14 false || true = false // error
val t15: false || false = true // error
}
75 changes: 75 additions & 0 deletions tests/neg/singleton-ops-int.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import scala.compiletime.ops.int._

object Test {
summon[2 + 3 =:= 6 - 1]
summon[1763 =:= 41 * 43]
summon[2 + 2 =:= 3] // error
summon[29 * 31 =:= 900] // error
summon[Int <:< Int + 1] // error
summon[1 + Int <:< Int]

val t0: 2 + 3 = 5
val t1: 2 + 2 = 5 // error
val t2: -1 + 1 = 0
val t3: -5 + -5 = -11 // error

val t4: 10 * 20 = 200
val t5: 30 * 10 = 400 // error
val t6: -10 * 2 = -20
val t7: -2 * -2 = 4

val t8: 10 / 2 = 5
val t9: 11 / -2 = -5 // Integer division
val t10: 2 / 4 = 2 // error
val t11: -1 / 0 = 1 // error

val t12: 10 % 3 = 1
val t13: 12 % 2 = 1 // error
val t14: 1 % -3 = 1
val t15: -3 % 0 = 0 // error

val t16: 1 < 0 = false
val t17: 0 < 1 = true
val t18: 10 < 5 = true // error
val t19: 5 < 10 = false // error

val t20: 1 <= 0 = false
val t21: 1 <= 1 = true
val t22: 10 <= 5 = true // error
val t23: 5 <= 10 = false // error

val t24: 1 > 0 = true
val t25: 0 > 1 = false
val t26: 10 > 5 = false // error
val t27: 5 > 10 = true // error

val t28: 1 >= 1 = true
val t29: 0 >= 1 = false
val t30: 10 >= 5 = false // error
val t31: 5 >= 10 = true // error

val t32: Abs[0] = 0
val t33: Abs[-1] = 1
val t34: Abs[-1] = -1 // error
val t35: Abs[1] = -1 // error

val t36: Negate[-10] = 10
val t37: Negate[10] = -10
val t38: Negate[1] = 1 // error
val t39: Negate[-1] = -1 // error

val t40: Max[-1, 10] = 10
val t41: Max[4, 2] = 4
val t42: Max[2, 2] = 1 // error
val t43: Max[-1, -1] = 0 // error

val t44: Min[-1, 10] = -1
val t45: Min[4, 2] = 2
val t46: Min[2, 2] = 1 // error
val t47: Min[-1, -1] = 0 // error

val t48: ToString[213] = "213"
val t49: ToString[-1] = "-1"
val t50: ToString[0] = "-0" // error
val t51: ToString[200] = "100" // error
}
Loading