Skip to content

Commit

Permalink
Make most tests work
Browse files Browse the repository at this point in the history
  • Loading branch information
denhie committed Jul 17, 2023
1 parent 8e31679 commit 1b1d5b6
Show file tree
Hide file tree
Showing 26 changed files with 227 additions and 179 deletions.
8 changes: 4 additions & 4 deletions effekt/jvm/src/main/scala/effekt/Server.scala
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ trait LSPServer extends kiama.util.Server[Tree, EffektConfig, EffektError] with
effects <- fun.symbol.annotatedEffects
} yield (result, effects)
if ann.map {
case (List(y), eff1) => needsUpdate((y, eff1), (tpe, eff))
case _ => ??? // TODO MRV
(y, eff1) => needsUpdate((y, eff1), (tpe, eff))
}.getOrElse(true)
res <- CodeAction("Update return type with inferred effects", fun.ret, s": $tpe / $eff")
} yield res
Expand All @@ -233,10 +232,11 @@ trait LSPServer extends kiama.util.Server[Tree, EffektConfig, EffektError] with
}
} yield res

def needsUpdate(annotated: (ValueType, Effects), inferred: (ValueType, Effects))(using Context): Boolean = {
def needsUpdate(annotated: (List[ValueType], Effects), inferred: (List[ValueType], Effects))(using Context): Boolean = {
val (tpe1, effs1) = annotated
val (tpe2, effs2) = inferred
tpe1 != tpe2 || effs1 != effs2

tpe1.size != tpe2.size || tpe1.zip(tpe2).forall { case (t1, t2) => t1 != t2 } || effs1 != effs2
}

case class CaptureInfo(location: Location, captureText: String)
Expand Down
58 changes: 23 additions & 35 deletions effekt/shared/src/main/scala/effekt/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,8 @@ object Typer extends Phase[NameResolved, Typechecked] {
handlers foreach Context.withFocus { h =>
given Captures = continuationCaptHandled

ret match {
case List(ret) => {
val Result(_, usedEffects) = checkImplementation(h.impl, Some((ret, continuationCapt)))
handlerEffs = handlerEffs ++ usedEffects
}
case _ => ??? // TODO MRV
}
val Result(_, usedEffects) = checkImplementation(h.impl, Some((ret, continuationCapt)))
handlerEffs = handlerEffs ++ usedEffects

}

Expand Down Expand Up @@ -307,7 +302,8 @@ object Typer extends Phase[NameResolved, Typechecked] {
// Clauses could in general be empty if there are no constructors
// In that case the scrutinee couldn't have been constructed and
// we can unify with everything.
Result(Context.join(tpes: _*), resEff) // TODO: wie oben

Result(Context.join(tpes: _*), resEff)

case source.Hole(stmt) =>
val Result(tpe, effs) = checkStmt(stmt, None)
Expand All @@ -320,7 +316,7 @@ object Typer extends Phase[NameResolved, Typechecked] {
/**
* The [[continuationDetails]] are only provided, if a continuation is captured (that is for implementations as part of effect handlers).
*/
def checkImplementation(impl: source.Implementation, continuationDetails: Option[(ValueType, CaptUnificationVar)])(using Context, Captures): Result[InterfaceType] = Context.focusing(impl) {
def checkImplementation(impl: source.Implementation, continuationDetails: Option[(List[ValueType], CaptUnificationVar)])(using Context, Captures): Result[InterfaceType] = Context.focusing(impl) {
case source.Implementation(sig, clauses) =>

var handlerEffects: ConcreteEffects = Pure
Expand Down Expand Up @@ -415,15 +411,15 @@ object Typer extends Phase[NameResolved, Typechecked] {
// resume { e }
val resumeType = FunctionType(Nil, cparams, Nil, Nil, tpe, otherEffs)
val resumeCapt = CaptureParam(Name.local("resumeBlock"))
FunctionType(Nil, List(resumeCapt), Nil, List(resumeType), List(ret), Effects.Pure)
FunctionType(Nil, List(resumeCapt), Nil, List(resumeType), ret, Effects.Pure)
} else {
// resume(v)
FunctionType(Nil, Nil, tpe, Nil, List(ret), Effects.Pure)
FunctionType(Nil, Nil, tpe, Nil, ret, Effects.Pure)
}

Context.bind(Context.symbolOf(resume).asBlockSymbol, resumeType, continuationCapt)

body checkAgainst List(ret)
body checkAgainst ret
}

handlerEffects = handlerEffects ++ effs
Expand Down Expand Up @@ -534,7 +530,7 @@ object Typer extends Phase[NameResolved, Typechecked] {
}

bindings
} match { case res => Context.annotateInferredType(pattern, sc); res }
} match { case res => Context.annotateInferredType(pattern, List(sc)); res } // TODO MRV: List(sc)

//</editor-fold>

Expand All @@ -554,20 +550,14 @@ object Typer extends Phase[NameResolved, Typechecked] {
Result(r, eff1 ++ eff2)

// TODO MRV: Verbieten? (return (1, return(2,3))
case source.Return(e) => expected match {
case None => {
val check = e map { checkExpr(_, None) }
val tpes = check flatMap { _.tpe }
val effs = check map { _.effects } reduce { _ ++ _ }
Result(tpes, effs)
case source.Return(e) =>
val checked = expected match {
case None => e map { checkExpr(_, None) }
case Some(tpe) => (e zip tpe) map { case (e, tpe) => checkExpr(e, Some(List(tpe))) }
}
case Some(tpe) => {
val checked = (e zip tpe) map { case (e, tpe) => checkExpr(e, Some(List(tpe))) }
val tpes = checked flatMap { _.tpe }
val effs = checked map { _.effects } reduce { _ ++ _ }
Result(tpes, effs)
}
}
val tpes = checked flatMap { _.tpe }
val effs = checked map { _.effects } reduce { _ ++ _ }
Result(tpes, effs)

case source.BlockStmt(stmts) => Context in { checkStmt(stmts, expected) }
}
Expand Down Expand Up @@ -743,10 +733,7 @@ object Typer extends Phase[NameResolved, Typechecked] {
case Some(t) => binding checkAgainst List(t)
case None => checkStmt(binding, None)
}
val stTpe = tpeBind match {
case List(tpeBind) => TState(tpeBind)
case _ => ??? // TODO MRV
}
val stTpe = TState(tpeBind)

// to allocate into the region, it needs to be live...
usingCapture(stCapt)
Expand Down Expand Up @@ -1197,13 +1184,13 @@ object Typer extends Phase[NameResolved, Typechecked] {
def matchPattern(scrutinee: ValueType, patternTpe: ValueType, pattern: source.MatchPattern)(using Context): Unit =
Context.requireSubtype(scrutinee, patternTpe, ErrorContext.PatternMatch(pattern))

// TODO MRV
// TODO MRV: remove?
def matchExpected(got: ValueType, expected: ValueType)(using Context): Unit =
Context.requireSubtype(got, expected,
ErrorContext.Expected(Context.unification(got), Context.unification(expected), Context.focus))

def matchExpected(got: List[ValueType], expected: List[ValueType])(using Context): Unit = {
if (got.length != expected.length) Context.error("Expected " + expected.length + " arguments, but got " + got.length)
if (got.length != expected.length) Context.error("Expected " + expected.length + " results, but got " + got.length)

got zip expected foreach { (g, e) => Context.requireSubtype(g, e,
ErrorContext.Expected(Context.unification(g), Context.unification(e), Context.focus)) }
Expand Down Expand Up @@ -1497,8 +1484,8 @@ trait TyperOps extends ContextOps { self: Context =>

//<editor-fold desc="(5) Inferred Information for LSP">

private[typer] def annotateInferredType(t: Tree, e: ValueType) =
annotations.update(Annotations.InferredValueType, t, e)
/*private[typer] def annotateInferredType(t: Tree, e: ValueType) =
annotations.update(Annotations.InferredValueType, t, e)*/ // TODO MRV: remove?

private[typer] def annotateInferredType(t: Tree, e: List[ValueType]) =
annotations.update(Annotations.InferredValueTypeList, t, e)
Expand Down Expand Up @@ -1552,7 +1539,8 @@ trait TyperOps extends ContextOps { self: Context =>

// Update and write out all inferred types and captures for LSP support
// This info is currently also used by Transformer!
annotations.updateAndCommit(Annotations.InferredValueType) { case (t, tpe) => subst.substitute(tpe) }
//annotations.updateAndCommit(Annotations.InferredValueType) { case (t, tpe) => subst.substitute(tpe) } // TODO MRV: remove?
annotations.updateAndCommit(Annotations.InferredValueTypeList) { case (t, tpe) => tpe.map(subst.substitute) }
annotations.updateAndCommit(Annotations.InferredBlockType) { case (t, tpe) => subst.substitute(tpe) }
annotations.updateAndCommit(Annotations.InferredEffect) { case (t, effs) => subst.substitute(effs) }

Expand Down
16 changes: 8 additions & 8 deletions effekt/shared/src/main/scala/effekt/context/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ object Annotations {
*
* Important for finding the types of temporary variables introduced by transformation
* Can also be used by LSP server to display type information for type-checked trees
*/
*
val InferredValueType = Annotation[source.Tree, symbols.ValueType](
"InferredValueType",
"the inferred type of"
)
)*/

/**
* The type as inferred by typer at a given position in the tree
Expand Down Expand Up @@ -358,10 +358,10 @@ trait AnnotationsDB { self: Context =>
def typeArguments(c: source.CallLike): List[symbols.ValueType] =
annotation(Annotations.TypeArguments, c)

def inferredTypeOption(t: source.Tree): Option[ValueType] =
annotationOption(Annotations.InferredValueType, t)
def inferredTypeOption(t: source.Tree): Option[List[ValueType]] =
annotationOption(Annotations.InferredValueTypeList, t)

def inferredTypeOf(t: source.Tree): ValueType =
def inferredTypeOf(t: source.Tree): List[ValueType] =
inferredTypeOption(t).getOrElse {
panic(s"Internal Error: Missing type of source expression: '${t}'")
}
Expand All @@ -382,15 +382,15 @@ trait AnnotationsDB { self: Context =>
panic(s"Internal Error: Missing effect of source expression: '${t}'")
}

def inferredTypeAndEffectOption(t: source.Tree): Option[(ValueType, Effects)] =
def inferredTypeAndEffectOption(t: source.Tree): Option[(List[ValueType], Effects)] =
for {
tpe <- inferredTypeOption(t)
eff <- inferredEffectOption(t)
} yield (tpe, eff)

def inferredTypeAndEffectOf(t: source.Tree): (ValueType, Effects) =
def inferredTypeAndEffectOf(t: source.Tree): (List[ValueType], Effects) =
inferredTypeAndEffectOption(t).getOrElse {
panic(s"Internal Error: Missing type of source expression: '${t}'")
panic(s"Internal Error: Missing type and effect of source expression: '${t}'")
}

def inferredCapture(t: source.Tree): symbols.CaptureSet =
Expand Down
18 changes: 9 additions & 9 deletions effekt/shared/src/main/scala/effekt/core/Analyses.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def collectAliases(statement: Stmt): Map[Id, Id] =
definitions.map(collectAliases).fold(Map[Id, Id]())(_ ++ _) ++ collectAliases(body)

case Return(p) =>
collectAliases(p)
p.map(collectAliases).fold(Map[Id, Id]())(_ ++ _)

case Val(_, binding, body) =>
collectAliases(binding) ++ collectAliases(body)
Expand Down Expand Up @@ -202,7 +202,7 @@ def collectFunctionDefinitions(statement: Stmt): Map[Id, Block] =
collectFunctionDefinitions(body)

case Return(p) =>
collectFunctionDefinitions(p)
p.map(collectFunctionDefinitions).fold(Map[Id, Block]())(_ ++ _)

case Val(_, binding, body) =>
collectFunctionDefinitions(binding) ++ collectFunctionDefinitions(body)
Expand Down Expand Up @@ -342,7 +342,7 @@ def countFunctionOccurencesWorker(statement: Stmt)(using count: mutable.Map[Id,
countFunctionOccurencesWorker(body)

case Return(p) =>
countFunctionOccurencesWorker(p)
p.foreach(countFunctionOccurencesWorker)

case Val(_, binding, body) =>
countFunctionOccurencesWorker(binding)
Expand Down Expand Up @@ -456,8 +456,8 @@ def findRecursiveFunctions(statement: Stmt): Set[Id] =
case Scope(definitions, body) =>
definitions.map(findRecursiveFunctions).fold(Set[Id]())(_ ++ _) ++ findRecursiveFunctions(body)

case Return(expr) =>
findRecursiveFunctions(expr)
case Return(exprs) =>
exprs.map(findRecursiveFunctions).fold(Set[Id]())(_ ++ _)

case Val(_, binding, body) =>
findRecursiveFunctions(binding) ++ findRecursiveFunctions(body)
Expand Down Expand Up @@ -589,8 +589,8 @@ def findStaticArgumentsWorker(statement: Stmt)(using params: StaticParamsUnfinis
definitions.foreach(findStaticArgumentsWorker)
findStaticArgumentsWorker(body)

case Return(expr) =>
findStaticArgumentsWorker(expr)
case Return(exprs) =>
exprs.foreach(findStaticArgumentsWorker)

case Val(_, binding, body) =>
findStaticArgumentsWorker(binding)
Expand Down Expand Up @@ -719,8 +719,8 @@ def size(statement: Stmt): Int =
case Scope(definitions, body) =>
1 + definitions.map(size).sum + size(body)

case Return(expr) =>
1 + size(expr)
case Return(exprs) =>
1 + exprs.map(size).sum

case Val(_, binding, body) =>
1 + size(binding) + size(body)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ object DirectStyleMutableState extends Phase[CoreTransformed, CoreTransformed] {

case Get(x, tpe) =>
val id = Id("tmp")
Let(id, Get(x, tpe), Stmt.Return(Pure.ValueVar(id, tpe.result)))
Let(id, Get(x, tpe), Stmt.Return(List(Pure.ValueVar(id, tpe.result match {
case List(tpe) => tpe
case _ => ??? // TODO MRV
}))))

case Put(x, tpe, v) =>
val id = Id("tmp")
Let(id, Put(x, tpe, v), Stmt.Return(Pure.ValueVar(id, Type.TUnit)))
Let(id, Put(x, tpe, v), Stmt.Return(List(Pure.ValueVar(id, Type.TUnit))))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ object MakeStackSafe extends Phase[CoreTransformed, CoreTransformed] {
}

// [[ s ]] = val tmp = return (); s
def thunk(s: Stmt): Stmt = Stmt.Val(TmpValue(), Stmt.Return(Literal((), core.Type.TUnit)), s)
def thunk(s: Stmt): Stmt = Stmt.Val(TmpValue(), Stmt.Return(List(Literal((), core.Type.TUnit))), s)
}
22 changes: 11 additions & 11 deletions effekt/shared/src/main/scala/effekt/core/Optimizations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def dealiasing(statement: Stmt)(using aliases: Map[Id, Id]): Stmt =
}.map(dealiasing), dealiasing(body))

case Return(p) =>
Return(dealiasing(p))
Return(p.map(dealiasing))

case Val(id, binding, body) =>
Val(id, dealiasing(binding), dealiasing(body))
Expand Down Expand Up @@ -214,7 +214,7 @@ def removeUnusedFunctionsWorker(statement: Stmt)(using count: Map[Id, Int], recu
else Scope(defs, removeUnusedFunctionsWorker(body))

case Return(p) =>
Return(removeUnusedFunctionsWorker(p))
Return(p.map(removeUnusedFunctionsWorker))

case Val(id, binding, body) =>
Val(id, removeUnusedFunctionsWorker(binding), removeUnusedFunctionsWorker(body))
Expand Down Expand Up @@ -331,7 +331,7 @@ def staticArgumentTransformationWorker(statement: Stmt)(using recursiveFunctions
Scope(definitions.map(staticArgumentTransformationWorker), staticArgumentTransformationWorker(body))

case Return(expr) =>
Return(staticArgumentTransformationWorker(expr))
Return(expr.map(staticArgumentTransformationWorker))

case Val(id, binding, body) =>
Val(id, staticArgumentTransformationWorker(binding), staticArgumentTransformationWorker(body))
Expand Down Expand Up @@ -437,8 +437,8 @@ def replaceCalls(statement: Stmt)(using newName: Id, params: StaticParams): Stmt
case Scope(definitions, body) =>
Scope(definitions.map(replaceCalls), replaceCalls(body))

case Return(expr) =>
Return(replaceCalls(expr))
case Return(exprs) =>
Return(exprs.map(replaceCalls))

case Val(id, binding, body) =>
Val(id, replaceCalls(binding), replaceCalls(body))
Expand Down Expand Up @@ -589,8 +589,8 @@ def inliningWorker(statement: Stmt)(using inlines: Map[Id, Block]): Stmt =
case d@Definition.Def(id, _) => inliningWorker(d)(using inlines - id)
case l@Definition.Let(id,_) => inliningWorker(l)(using inlines - id)}, inliningWorker(body))

case Return(expr) =>
Return(inliningWorker(expr))
case Return(exprs) =>
Return(exprs.map(inliningWorker))

case Val(id, binding, body) =>
Val(id, inliningWorker(binding), inliningWorker(body))
Expand Down Expand Up @@ -726,8 +726,8 @@ def constantPropagation(statement: Stmt)(using constants: Map[Id, Literal]): Stm
val newConstants = extraConstants ++ constants
Scope(defsNoConstants.map(constantPropagation(_)(using newConstants)), constantPropagation(body)(using newConstants))

case Return(expr) =>
Return(constantPropagation(expr))
case Return(exprs) =>
Return(exprs.map(constantPropagation))

case Val(id, binding, body) =>
Val(id, constantPropagation(binding), constantPropagation(body))
Expand Down Expand Up @@ -831,8 +831,8 @@ def betaReduction(statement: Stmt): Stmt =
case Scope(definitions, body) =>
Scope(definitions.map(betaReduction), betaReduction(body))

case Return(expr) =>
Return(betaReduction(expr))
case Return(exprs) =>
Return(exprs.map(betaReduction))

case Val(id, binding, body) =>
Val(id, betaReduction(binding), betaReduction(body))
Expand Down
Loading

0 comments on commit 1b1d5b6

Please sign in to comment.