Skip to content

Commit 2da90a0

Browse files
committed
Addressed comments
1 parent 11e6975 commit 2da90a0

File tree

3 files changed

+78
-40
lines changed

3 files changed

+78
-40
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,13 @@ class Definitions {
585585

586586
lazy val StringContextType: TypeRef = ctx.requiredClassRef("scala.StringContext")
587587
def StringContextClass(implicit ctx: Context) = StringContextType.symbol.asClass
588+
lazy val StringContextSR = StringContextClass.requiredMethodRef(nme.s)
589+
def StringContextS(implicit ctx: Context) = StringContextSR.symbol
590+
lazy val StringContextRawR = StringContextClass.requiredMethodRef(nme.raw_)
591+
def StringContextRaw(implicit ctx: Context) = StringContextRawR.symbol
588592
def StringContextModule(implicit ctx: Context) = StringContextClass.companionModule
593+
lazy val StringContextModule_applyR = StringContextModule.requiredMethodRef(nme.apply)
594+
def StringContextModule_apply(implicit ctx: Context) = StringContextModule_applyR.symbol
589595

590596
lazy val PartialFunctionType: TypeRef = ctx.requiredClassRef("scala.PartialFunction")
591597
def PartialFunctionClass(implicit ctx: Context) = PartialFunctionType.symbol.asClass

compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,52 +9,84 @@ import dotty.tools.dotc.core.Symbols._
99
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
1010

1111
/**
12-
* Created by wojtekswiderski on 2018-01-24.
12+
* MiniPhase to transform s and raw string interpolators from using StringContext to string
13+
* concatenation. Since string concatenation uses the Java String builder, we get a performance
14+
* improvement in terms of these two interpolators.
15+
*
16+
* More info here:
17+
* https://medium.com/@dkomanov/scala-string-interpolation-performance-21dc85e83afd
1318
*/
1419
class StringInterpolatorOpt extends MiniPhase {
1520
import tpd._
1621

1722
override def phaseName: String = "stringInterpolatorOpt"
1823

24+
/** Matches a list of constant literals */
25+
private object Literals {
26+
def unapply(tree: SeqLiteral)(implicit ctx: Context): Option[List[Literal]] = {
27+
tree.elems match {
28+
case literals if literals.forall(_.isInstanceOf[Literal]) =>
29+
Some(literals.map(_.asInstanceOf[Literal]))
30+
case _ => None
31+
}
32+
}
33+
}
34+
35+
/** Matches an s or raw string interpolator */
36+
private object SOrRawInterpolator {
37+
def unapply(tree: Tree)(implicit ctx: Context): Option[(List[Literal], List[Tree])] = {
38+
if (tree.symbol.eq(defn.StringContextRaw) || tree.symbol.eq(defn.StringContextS)) {
39+
tree match {
40+
case Apply(Select(Apply(strContextApply, List(Literals(strs))), _),
41+
List(SeqLiteral(elems, _)))
42+
if strContextApply.symbol.eq(defn.StringContextModule_apply) &&
43+
elems.length == strs.length - 1 =>
44+
Some(strs, elems)
45+
case _ => None
46+
}
47+
} else None
48+
}
49+
}
50+
51+
/**
52+
* Match trees that resemble s and raw string interpolations. In the case of the s
53+
* interpolator, escapes the string constants. Exposes the string constants as well as
54+
* the variable references.
55+
*/
1956
private object StringContextIntrinsic {
20-
def unapply(tree: Apply)(implicit ctx: Context): Option[(List[Tree], List[Tree])] = {
57+
def unapply(tree: Apply)(implicit ctx: Context): Option[(List[Literal], List[Tree])] = {
2158
tree match {
22-
case Apply(Select(Apply(Select(ident, nme.apply), List(SeqLiteral(strs, _))), fn),
23-
List(SeqLiteral(elems, _))) =>
24-
if (ident.symbol.eq(defn.StringContextModule) && strs.forall(_.isInstanceOf[Literal])
25-
&& elems.length == strs.length - 1) {
26-
if (fn == nme.raw_) Some(strs, elems)
27-
else if (fn == nme.s) {
28-
try {
29-
val escapedStrs = strs.mapConserve { str =>
30-
val strValue = str.asInstanceOf[Literal].const.stringValue
31-
val escapedValue = StringContext.processEscapes(strValue)
32-
cpy.Literal(str)(Constant(escapedValue))
33-
}
34-
Some(escapedStrs, elems)
35-
} catch {
36-
case _: StringContext.InvalidEscapeException => None
59+
case SOrRawInterpolator(strs, elems) =>
60+
if (tree.symbol == defn.StringContextRaw) Some(strs, elems)
61+
else { // tree.symbol == defn.StringContextS
62+
try {
63+
val escapedStrs = strs.map { str =>
64+
val escapedValue = StringContext.processEscapes(str.const.stringValue)
65+
cpy.Literal(str)(Constant(escapedValue))
3766
}
38-
} else None
39-
} else None
67+
Some(escapedStrs, elems)
68+
} catch {
69+
case _: StringContext.InvalidEscapeException => None
70+
}
71+
}
4072
case _ => None
4173
}
4274
}
4375
}
4476

4577
override def transformApply(tree: Apply)(implicit ctx: Context): Tree = {
4678
tree match {
47-
case StringContextIntrinsic(strs: List[Tree], elems: List[Tree]) =>
79+
case StringContextIntrinsic(strs: List[Literal], elems: List[Tree]) =>
4880
val stri = strs.iterator
4981
val elemi = elems.iterator
50-
var result = stri.next
82+
var result: Tree = stri.next
5183
def concat(tree: Tree): Unit = {
5284
result = result.select(defn.String_+).appliedTo(tree)
5385
}
5486
while (elemi.hasNext) {
5587
concat(elemi.next)
5688
val str = stri.next
57-
if (!str.asInstanceOf[Literal].const.stringValue.isEmpty) concat(str)
89+
if (!str.const.stringValue.isEmpty) concat(str)
5890
}
5991
result
6092
case _ => tree

compiler/test/dotty/tools/backend/jvm/StringInterpolatorOptTest.scala

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ class StringInterpolatorOptTest extends DottyBytecodeTest {
99
@Test def testRawInterpolator = {
1010
val source =
1111
"""
12-
|class Foo {
13-
| val one = 1
14-
| val two = "two"
15-
| val three = 3.0
16-
|
17-
| def meth1: String = raw"$one plus $two$three\n"
18-
| def meth2: String = "" + one + " plus " + two + three + "\\n"
19-
|}
20-
""".stripMargin
12+
|class Foo {
13+
| val one = 1
14+
| val two = "two"
15+
| val three = 3.0
16+
|
17+
| def meth1: String = raw"$one plus $two$three\n"
18+
| def meth2: String = "" + one + " plus " + two + three + "\\n"
19+
|}
20+
""".stripMargin
2121

2222
checkBCode(source) { dir =>
2323
val clsIn = dir.lookupName("Foo.class", directory = false).input
@@ -37,15 +37,15 @@ class StringInterpolatorOptTest extends DottyBytecodeTest {
3737
@Test def testSInterpolator = {
3838
val source =
3939
"""
40-
|class Foo {
41-
| val one = 1
42-
| val two = "two"
43-
| val three = 3.0
44-
|
45-
| def meth1: String = s"$one plus $two$three\n"
46-
| def meth2: String = "" + one + " plus " + two + three + "\n"
47-
|}
48-
""".stripMargin
40+
|class Foo {
41+
| val one = 1
42+
| val two = "two"
43+
| val three = 3.0
44+
|
45+
| def meth1: String = s"$one plus $two$three\n"
46+
| def meth2: String = "" + one + " plus " + two + three + "\n"
47+
|}
48+
""".stripMargin
4949

5050
checkBCode(source) { dir =>
5151
val clsIn = dir.lookupName("Foo.class", directory = false).input

0 commit comments

Comments
 (0)