Skip to content

Add a Var abstraction in QuoteUtils #7035

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ class BootstrappedOnlyCompilationTests extends ParallelTesting {
implicit val testGroup: TestGroup = TestGroup("runWithCompiler")
aggregateTests(
compileFilesInDir("tests/run-with-compiler", withCompilerOptions),
compileDir("tests/run-with-compiler-custom-args/tasty-interpreter", withCompilerOptions),
compileFile("tests/run-with-compiler-custom-args/staged-streams_1.scala", withCompilerOptions without "-Yno-deep-subtypes")
compileDir("tests/run-with-compiler-custom-args/tasty-interpreter", withCompilerOptions)
).checkRuns()
}

Expand Down
48 changes: 48 additions & 0 deletions library/src-bootstrapped/scala/quoted/util/Var.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package scala.quoted
package util

/** An abstraction for variable definition to use in a quoted program.
* It decouples the operations of get and update, if needed to be spliced separately.
*/
sealed trait Var[T] {

// Retrieves the value of the variable
def get given QuoteContext: Expr[T]

// Update the variable with the expression of a value (`e` corresponds to the RHS of variable assignment `x = e`)
def update(e: Expr[T]) given QuoteContext: Expr[Unit]
}

object Var {
/** Create a variable initialized with `init` and used in `body`.
* `body` recieves a `Var[T]` argument which exposes `get` and `update`.
*
* `var`('(7)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@biboudis, you forgot to adapt the comment. It should be Var(.... Also need to update the quote and splice syntax

* x => '{
* while(0 < ~x)
* ~x.update('(~x - 1))
* ~x.get
* }
* }
*
* will create the equivalent of
*
* '{
* var x = 7
* while (0 < x)
* x = x - 1
* x
* }
*/
def apply[T: Type, U: Type](init: Expr[T])(body: Var[T] => Expr[U]) given QuoteContext: Expr[U] = '{
var x = $init
${
body(
new Var[T] {
def get given QuoteContext: Expr[T] = 'x
def update(e: Expr[T]) given QuoteContext: Expr[Unit] = '{ x = $e }
}
)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,31 +1,13 @@
import scala.quoted._
import scala.quoted.util._
import given scala.quoted.autolift._

/**
* Port of the strymonas library as described in O. Kiselyov et al., Stream fusion, to completeness (POPL 2017)
*/

object Test {

// TODO: remove as it exists in Quoted Lib
sealed trait Var[T] {
def get given QuoteContext: Expr[T]
def update(x: Expr[T]) given QuoteContext: Expr[Unit]
}

object Var {
def apply[T: Type, U: Type](init: Expr[T])(body: Var[T] => Expr[U]) given QuoteContext: Expr[U] = '{
var x = $init
${
body(
new Var[T] {
def get given QuoteContext: Expr[T] = 'x
def update(e: Expr[T]) given QuoteContext: Expr[Unit] = '{ x = $e }
}
)
}
}
}
type E[T] = given QuoteContext => Expr[T]

/*** Producer represents a linear production of values with a loop structure.
*
Expand Down Expand Up @@ -61,27 +43,29 @@ object Test {
* @param k the continuation that is invoked after the new state is defined in the body of `init`
* @return expr value of unit per the CPS-encoding
*/
def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit]
def init(k: St => Expr[Unit]): E[Unit]

/** Step method that defines the transformation of data.
*
* @param st the state needed for this iteration step
* @param k the continuation that accepts each element and proceeds with the step-wise processing
* @return expr value of unit per the CPS-encoding
*/
def step(st: St, k: (A => Expr[Unit])) given QuoteContext: Expr[Unit]
def step(st: St, k: (A => Expr[Unit])): E[Unit]

/** The condition that checks for termination
*
* @param st the state needed for this iteration check
* @return the expression for a boolean
*/
def hasNext(st: St) given QuoteContext: Expr[Boolean]
def hasNext(st: St): E[Boolean]
}

trait Cardinality
case object AtMost1 extends Cardinality
case object Many extends Cardinality
enum Cardinality {
case AtMost1
case Many
}
import Cardinality._

trait StagedStream[A]
case class Linear[A](producer: Producer[A]) extends StagedStream[A]
Expand All @@ -98,19 +82,17 @@ object Test {
* @tparam W the type of the accumulator
* @return
*/
def fold[W: Type](z: Expr[W], f: ((Expr[W], Expr[A]) => Expr[W])) given QuoteContext: Expr[W] = {
Var(z) { s: Var[W] => '{
${
foldRaw[Expr[A]]((a: Expr[A]) => '{
${ s.update(f(s.get, a)) }
}, stream)
}
def fold[W: Type](z: Expr[W], f: ((Expr[W], Expr[A]) => Expr[W])): E[W] = {
Var(z) { s =>
'{
${ foldRaw[Expr[A]]((a: Expr[A]) => s.update(f(s.get, a)), stream) }

${ s.get }
}
}
}

private def foldRaw[A](consumer: A => Expr[Unit], stream: StagedStream[A]) given QuoteContext: Expr[Unit] = {
private def foldRaw[A](consumer: A => Expr[Unit], stream: StagedStream[A]): E[Unit] = {
stream match {
case Linear(producer) => {
producer.card match {
Expand Down Expand Up @@ -166,15 +148,15 @@ object Test {
type St = producer.St
val card = producer.card

def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
def init(k: St => Expr[Unit]): E[Unit] = {
producer.init(k)
}

def step(st: St, k: (B => Expr[Unit])) given QuoteContext: Expr[Unit] = {
def step(st: St, k: (B => Expr[Unit])): E[Unit] = {
producer.step(st, el => f(el)(k))
}

def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
def hasNext(st: St): E[Boolean] = {
producer.hasNext(st)
}
}
Expand Down Expand Up @@ -229,13 +211,13 @@ object Test {
type St = Expr[A]
val card = AtMost1

def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] =
def init(k: St => Expr[Unit]): E[Unit] =
k(a)

def step(st: St, k: (Expr[A] => Expr[Unit])) given QuoteContext: Expr[Unit] =
def step(st: St, k: (Expr[A] => Expr[Unit])): E[Unit] =
k(st)

def hasNext(st: St) given QuoteContext: Expr[Boolean] =
def hasNext(st: St): E[Boolean] =
pred(st)
}

Expand All @@ -259,13 +241,13 @@ object Test {
type St = producer.St
val card = producer.card

def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] =
def init(k: St => Expr[Unit]): E[Unit] =
producer.init(k)

def step(st: St, k: (A => Expr[Unit])) given QuoteContext: Expr[Unit] =
def step(st: St, k: (A => Expr[Unit])): E[Unit] =
producer.step(st, el => k(el))

def hasNext(st: St) given QuoteContext: Expr[Boolean] =
def hasNext(st: St): E[Boolean] =
f(producer.hasNext(st))
}
case AtMost1 => producer
Expand All @@ -292,22 +274,20 @@ object Test {
type St = (Var[Int], producer.St)
val card = producer.card

def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
def init(k: St => Expr[Unit]): E[Unit] = {
producer.init(st => {
Var(n) { counter =>
k(counter, st)
}
})
}

def step(st: St, k: (((Var[Int], A)) => Expr[Unit])) given QuoteContext: Expr[Unit] = {
def step(st: St, k: (((Var[Int], A)) => Expr[Unit])): E[Unit] = {
val (counter, currentState) = st
producer.step(currentState, el => '{
${k((counter, el))}
})
producer.step(currentState, el => k((counter, el)))
}

def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
def hasNext(st: St): E[Boolean] = {
val (counter, currentState) = st
producer.card match {
case Many => '{ ${counter.get} > 0 && ${producer.hasNext(currentState)} }
Expand Down Expand Up @@ -365,7 +345,7 @@ object Test {
pushLinear[A = Expr[A], C = B](producer1, producer2, nestf2)

case (Nested(producer1, nestf1), Linear(producer2)) =>
mapRaw[(B, Expr[A]), (Expr[A], B)]((t => k => '{ ${k((t._2, t._1))} }), pushLinear[A = B, C = Expr[A]](producer2, producer1, nestf1))
mapRaw[(B, Expr[A]), (Expr[A], B)]((t => k => k((t._2, t._1))), pushLinear[A = B, C = Expr[A]](producer2, producer1, nestf1))

case (Nested(producer1, nestf1), Nested(producer2, nestf2)) =>
zipRaw[A, B](Linear(makeLinear(stream1)), stream2)
Expand Down Expand Up @@ -441,7 +421,7 @@ object Test {
* @param k the continuation that consumes a variable.
* @return the quote of the orchestrated code that will be executed as
*/
def makeAdvanceFunction[A](nadv: Var[Unit => Unit], k: A => Expr[Unit], stream: StagedStream[A]) given QuoteContext: Expr[Unit] = {
def makeAdvanceFunction[A](nadv: Var[Unit => Unit], k: A => Expr[Unit], stream: StagedStream[A]): E[Unit] = {
stream match {
case Linear(producer) =>
producer.card match {
Expand Down Expand Up @@ -482,7 +462,7 @@ object Test {
type St = (Var[Boolean], Var[A], Var[Unit => Unit])
val card: Cardinality = Many

def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
def init(k: St => Expr[Unit]): E[Unit] = {
producer.init(st =>
Var('{ (_: Unit) => ()}){ nadv => {
Var('{ true }) { hasNext => {
Expand All @@ -506,7 +486,7 @@ object Test {
}})
}

def step(st: St, k: Expr[A] => Expr[Unit]) given QuoteContext: Expr[Unit] = {
def step(st: St, k: Expr[A] => Expr[Unit]): E[Unit] = {
val (flag, current, nadv) = st
'{
var el = ${current.get}
Expand All @@ -517,7 +497,7 @@ object Test {

}

def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
def hasNext(st: St): E[Boolean] = {
val (flag, _, _) = st
flag.get
}
Expand All @@ -532,19 +512,19 @@ object Test {
type St = (Var[Boolean], producer.St, nestedProducer.St)
val card: Cardinality = Many

def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
producer.init(s1 => '{ ${nestedProducer.init(s2 =>
def init(k: St => Expr[Unit]): E[Unit] = {
producer.init(s1 => nestedProducer.init(s2 =>
Var(producer.hasNext(s1)) { flag =>
k((flag, s1, s2))
})}})
}))
}

def step(st: St, k: ((Var[Boolean], producer.St, B)) => Expr[Unit]) given QuoteContext: Expr[Unit] = {
def step(st: St, k: ((Var[Boolean], producer.St, B)) => Expr[Unit]): E[Unit] = {
val (flag, s1, s2) = st
nestedProducer.step(s2, b => '{ ${k((flag, s1, b))} })
nestedProducer.step(s2, b => k((flag, s1, b)))
}

def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
def hasNext(st: St): E[Boolean] = {
val (flag, s1, s2) = st
'{ ${flag.get} && ${nestedProducer.hasNext(s2)} }
}
Expand All @@ -554,7 +534,7 @@ object Test {
val (flag, s1, b) = t

mapRaw[C, (A, C)]((c => k => '{
${producer.step(s1, a => '{ ${k((a, c))} })}
${producer.step(s1, a => k((a, c)))}
${flag.update(producer.hasNext(s1))}
}), addTerminationCondition((b_flag: Expr[Boolean]) => '{ ${flag.get} && $b_flag }, nestedf(b)))
})
Expand All @@ -567,16 +547,16 @@ object Test {
type St = (producer1.St, producer2.St)
val card: Cardinality = Many

def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
def init(k: St => Expr[Unit]): E[Unit] = {
producer1.init(s1 => producer2.init(s2 => k((s1, s2)) ))
}

def step(st: St, k: ((A, B)) => Expr[Unit]) given QuoteContext: Expr[Unit] = {
def step(st: St, k: ((A, B)) => Expr[Unit]): E[Unit] = {
val (s1, s2) = st
producer1.step(s1, el1 => producer2.step(s2, el2 => k((el1, el2)) ))
}

def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
def hasNext(st: St): E[Boolean] = {
val (s1, s2) = st
'{ ${producer1.hasNext(s1)} && ${producer2.hasNext(s2)} }
}
Expand All @@ -597,15 +577,15 @@ object Test {

val card = Many

def init(k: St => Expr[Unit]) given QuoteContext: Expr[Unit] = {
def init(k: St => Expr[Unit]): E[Unit] = {
Var('{($arr).length}) { n =>
Var(0){ i =>
k((i, n, arr))
}
}
}

def step(st: St, k: (Expr[A] => Expr[Unit])) given QuoteContext: Expr[Unit] = {
def step(st: St, k: (Expr[A] => Expr[Unit])): E[Unit] = {
val (i, _, arr) = st
'{
val el = ($arr).apply(${i.get})
Expand All @@ -614,7 +594,7 @@ object Test {
}
}

def hasNext(st: St) given QuoteContext: Expr[Boolean] = {
def hasNext(st: St): E[Boolean] = {
val (i, n, _) = st
'{
(${i.get} < ${n.get})
Expand Down