Skip to content

Commit

Permalink
Merge pull request scala#15663 from som-snytt/issue/13885-repl-parser…
Browse files Browse the repository at this point in the history
…-more-phaselike
  • Loading branch information
dwijnand authored Jul 26, 2022
2 parents 19eff87 + 5f4653d commit 62ca3fc
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 145 deletions.
8 changes: 0 additions & 8 deletions compiler/src/dotty/tools/repl/ReplCompilationUnit.scala

This file was deleted.

258 changes: 145 additions & 113 deletions compiler/src/dotty/tools/repl/ReplCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ import dotty.tools.dotc.transform.PostTyper
import dotty.tools.dotc.typer.ImportInfo.{withRootImports, RootRef}
import dotty.tools.dotc.typer.TyperPhase
import dotty.tools.dotc.util.Spans._
import dotty.tools.dotc.util.{ParsedComment, SourceFile}
import dotty.tools.dotc.util.{ParsedComment, Property, SourceFile}
import dotty.tools.dotc.{CompilationUnit, Compiler, Run}
import dotty.tools.repl.results._

import scala.collection.mutable
import scala.util.chaining.given

/** This subclass of `Compiler` is adapted for use in the REPL.
*
Expand All @@ -29,12 +30,14 @@ import scala.collection.mutable
* - provides utility to query the type of an expression
* - provides utility to query the documentation of an expression
*/
class ReplCompiler extends Compiler {
class ReplCompiler extends Compiler:

override protected def frontendPhases: List[List[Phase]] = List(
List(new TyperPhase(addRootImports = false)),
List(new CollectTopLevelImports),
List(new PostTyper),
List(Parser()),
List(ReplPhase()),
List(TyperPhase(addRootImports = false)),
List(CollectTopLevelImports()),
List(PostTyper()),
)

def newRun(initCtx: Context, state: State): Run =
Expand All @@ -46,7 +49,7 @@ class ReplCompiler extends Compiler {

def importPreviousRun(id: Int)(using Context) = {
// we first import the wrapper object id
val path = nme.EMPTY_PACKAGE ++ "." ++ objectNames(id)
val path = nme.EMPTY_PACKAGE ++ "." ++ ReplCompiler.objectNames(id)
val ctx0 = ctx.fresh
.setNewScope
.withRootImports(RootRef(() => requiredModuleRef(path)) :: Nil)
Expand All @@ -67,117 +70,29 @@ class ReplCompiler extends Compiler {
}
run.suppressions.initSuspendedMessages(state.context.run)
run
end newRun

private val objectNames = mutable.Map.empty[Int, TermName]
private def packaged(stats: List[untpd.Tree])(using Context): untpd.PackageDef =
import untpd.*
PackageDef(Ident(nme.EMPTY_PACKAGE), stats)

private case class Definitions(stats: List[untpd.Tree], state: State)

private def definitions(trees: List[untpd.Tree], state: State): Definitions = inContext(state.context) {
import untpd._

// If trees is of the form `{ def1; def2; def3 }` then `List(def1, def2, def3)`
val flattened = trees match {
case List(Block(stats, expr)) =>
if (expr eq EmptyTree) stats // happens when expr is not an expression
else stats :+ expr
case _ =>
trees
}

var valIdx = state.valIndex
val defs = new mutable.ListBuffer[Tree]

/** If the user inputs a definition whose name is of the form REPL_RES_PREFIX and a number,
* such as `val res9 = 1`, we bump `valIdx` to avoid name clashes. lampepfl/dotty#3536 */
def maybeBumpValIdx(tree: Tree): Unit = tree match
case apply: Apply => for a <- apply.args do maybeBumpValIdx(a)
case tuple: Tuple => for t <- tuple.trees do maybeBumpValIdx(t)
case patDef: PatDef => for p <- patDef.pats do maybeBumpValIdx(p)
case tree: NameTree => tree.name.show.stripPrefix(str.REPL_RES_PREFIX).toIntOption match
case Some(n) if n >= valIdx => valIdx = n + 1
case _ =>
case _ =>

flattened.foreach {
case expr @ Assign(id: Ident, _) =>
// special case simple reassignment (e.g. x = 3)
// in order to print the new value in the REPL
val assignName = (id.name ++ str.REPL_ASSIGN_SUFFIX).toTermName
val assign = ValDef(assignName, TypeTree(), id).withSpan(expr.span)
defs += expr += assign
case expr if expr.isTerm =>
val resName = (str.REPL_RES_PREFIX + valIdx).toTermName
valIdx += 1
val vd = ValDef(resName, TypeTree(), expr).withSpan(expr.span)
defs += vd
case other =>
maybeBumpValIdx(other)
defs += other
}

Definitions(
defs.toList,
state.copy(
objectIndex = state.objectIndex + 1,
valIndex = valIdx
)
)
}

/** Wrap trees in an object and add imports from the previous compilations
*
* The resulting structure is something like:
*
* ```
* package <none> {
* object rs$line$nextId {
* import rs$line${i <- 0 until nextId}._
*
* <trees>
* }
* }
* ```
*/
private def wrapped(defs: Definitions, objectTermName: TermName, span: Span): untpd.PackageDef =
inContext(defs.state.context) {
import untpd._

val tmpl = Template(emptyConstructor, Nil, Nil, EmptyValDef, defs.stats)
val module = ModuleDef(objectTermName, tmpl)
.withSpan(span)
final def compile(parsed: Parsed)(using state: State): Result[(CompilationUnit, State)] =
assert(!parsed.trees.isEmpty)

PackageDef(Ident(nme.EMPTY_PACKAGE), List(module))
given Context = state.context
val unit = ReplCompilationUnit(ctx.source).tap { unit =>
unit.untpdTree = packaged(parsed.trees)
unit.untpdTree.putAttachment(ReplCompiler.ReplState, state)
}

private def createUnit(defs: Definitions, span: Span)(using Context): CompilationUnit = {
val objectName = ctx.source.file.toString
assert(objectName.startsWith(str.REPL_SESSION_LINE))
assert(objectName.endsWith(defs.state.objectIndex.toString))
val objectTermName = ctx.source.file.toString.toTermName
objectNames.update(defs.state.objectIndex, objectTermName)

val unit = new ReplCompilationUnit(ctx.source)
unit.untpdTree = wrapped(defs, objectTermName, span)
unit
}

private def runCompilationUnit(unit: CompilationUnit, state: State): Result[(CompilationUnit, State)] = {
val ctx = state.context
ctx.run.nn.compileUnits(unit :: Nil)
ctx.run.nn.printSummary() // this outputs "2 errors found" like normal - but we might decide that's needlessly noisy for the REPL

if (!ctx.reporter.hasErrors) (unit, state).result
else ctx.reporter.removeBufferedMessages(using ctx).errors
}
ctx.run.nn.printSummary() // "2 errors found"

final def compile(parsed: Parsed)(implicit state: State): Result[(CompilationUnit, State)] = {
assert(!parsed.trees.isEmpty)
val defs = definitions(parsed.trees, state)
val unit = createUnit(defs, Span(0, parsed.trees.last.span.end))(using state.context)
runCompilationUnit(unit, defs.state)
}
val newState = unit.tpdTree.getAttachment(ReplCompiler.ReplState).get
if !ctx.reporter.hasErrors then (unit, newState).result
else ctx.reporter.removeBufferedMessages.errors
end compile

final def typeOf(expr: String)(implicit state: State): Result[String] =
final def typeOf(expr: String)(using state: State): Result[String] =
typeCheck(expr).map { tree =>
given Context = state.context
tree.rhs match {
Expand All @@ -190,7 +105,7 @@ class ReplCompiler extends Compiler {
}
}

def docOf(expr: String)(implicit state: State): Result[String] = inContext(state.context) {
def docOf(expr: String)(using state: State): Result[String] = inContext(state.context) {

/** Extract the "selected" symbol from `tree`.
*
Expand Down Expand Up @@ -237,7 +152,7 @@ class ReplCompiler extends Compiler {
}
}

final def typeCheck(expr: String, errorsAllowed: Boolean = false)(implicit state: State): Result[tpd.ValDef] = {
final def typeCheck(expr: String, errorsAllowed: Boolean = false)(using state: State): Result[tpd.ValDef] = {

def wrapped(expr: String, sourceFile: SourceFile, state: State)(using Context): Result[untpd.PackageDef] = {
def wrap(trees: List[untpd.Tree]): untpd.PackageDef = {
Expand Down Expand Up @@ -300,4 +215,121 @@ class ReplCompiler extends Compiler {
}
}
}
}
object ReplCompiler:
val ReplState: Property.StickyKey[State] = Property.StickyKey()
val objectNames = mutable.Map.empty[Int, TermName]
end ReplCompiler

class ReplCompilationUnit(source: SourceFile) extends CompilationUnit(source):
override def isSuspendable: Boolean = false

/** A placeholder phase that receives parse trees..
*
* It is called "parser" for the convenience of collective muscle memory.
*
* This enables -Vprint:parser.
*/
class Parser extends Phase:
def phaseName: String = "parser"
def run(using Context): Unit = ()
end Parser

/** A phase that assembles wrapped parse trees from user input.
*
* Ths `ReplState` attachment indicates Repl wrapping is required.
*
* This enables -Vprint:repl so that users can see how their code snippet was wrapped.
*/
class ReplPhase extends Phase:
def phaseName: String = "repl"

def run(using Context): Unit =
ctx.compilationUnit.untpdTree match
case pkg @ PackageDef(_, stats) =>
pkg.getAttachment(ReplCompiler.ReplState).foreach {
case given State =>
val defs = definitions(stats)
val res = wrapped(defs, Span(0, stats.last.span.end))
res.putAttachment(ReplCompiler.ReplState, defs.state)
ctx.compilationUnit.untpdTree = res
}
case _ =>
end run

private case class Definitions(stats: List[untpd.Tree], state: State)

private def definitions(trees: List[untpd.Tree])(using Context, State): Definitions =
import untpd.*

// If trees is of the form `{ def1; def2; def3 }` then `List(def1, def2, def3)`
val flattened = trees match {
case List(Block(stats, expr)) =>
if (expr eq EmptyTree) stats // happens when expr is not an expression
else stats :+ expr
case _ =>
trees
}

val state = summon[State]
var valIdx = state.valIndex
val defs = mutable.ListBuffer.empty[Tree]

/** If the user inputs a definition whose name is of the form REPL_RES_PREFIX and a number,
* such as `val res9 = 1`, we bump `valIdx` to avoid name clashes. lampepfl/dotty#3536 */
def maybeBumpValIdx(tree: Tree): Unit = tree match
case apply: Apply => for a <- apply.args do maybeBumpValIdx(a)
case tuple: Tuple => for t <- tuple.trees do maybeBumpValIdx(t)
case patDef: PatDef => for p <- patDef.pats do maybeBumpValIdx(p)
case tree: NameTree => tree.name.show.stripPrefix(str.REPL_RES_PREFIX).toIntOption match
case Some(n) if n >= valIdx => valIdx = n + 1
case _ =>
case _ =>

flattened.foreach {
case expr @ Assign(id: Ident, _) =>
// special case simple reassignment (e.g. x = 3)
// in order to print the new value in the REPL
val assignName = (id.name ++ str.REPL_ASSIGN_SUFFIX).toTermName
val assign = ValDef(assignName, TypeTree(), id).withSpan(expr.span)
defs += expr += assign
case expr if expr.isTerm =>
val resName = (str.REPL_RES_PREFIX + valIdx).toTermName
valIdx += 1
val vd = ValDef(resName, TypeTree(), expr).withSpan(expr.span)
defs += vd
case other =>
maybeBumpValIdx(other)
defs += other
}

Definitions(defs.toList, state.copy(objectIndex = state.objectIndex + 1, valIndex = valIdx))
end definitions

/** Wrap trees in an object and add imports from the previous compilations.
*
* The resulting structure is something like:
*
* ```
* package <none> {
* object rs$line$nextId {
* import rs$line${i <- 0 until nextId}.*
*
* <trees>
* }
* }
* ```
*/
private def wrapped(defs: Definitions, span: Span)(using Context): untpd.PackageDef =
import untpd.*

val objectName = ctx.source.file.toString
assert(objectName.startsWith(str.REPL_SESSION_LINE))
assert(objectName.endsWith(defs.state.objectIndex.toString))
val objectTermName = objectName.toTermName
ReplCompiler.objectNames.update(defs.state.objectIndex, objectTermName)

val tmpl = Template(emptyConstructor, Nil, Nil, EmptyValDef, defs.stats)
val module = ModuleDef(objectTermName, tmpl).withSpan(span)

PackageDef(Ident(nme.EMPTY_PACKAGE), List(module))
end ReplPhase
Loading

0 comments on commit 62ca3fc

Please sign in to comment.