Skip to content

Commit 45984f5

Browse files
dkomanovjohnynek
authored andcommitted
Support JMH for scala 2.12 #295 (#465)
* Upgrade versions of JMH 1.17.4 -> 1.20, asm 5.0.4 -> 6.1.1 * Add reflection-based generator for JMH (make it default)
1 parent 41ac5be commit 45984f5

File tree

2 files changed

+86
-32
lines changed

2 files changed

+86
-32
lines changed

jmh/jmh.bzl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,35 @@ load("//scala:scala.bzl", "scala_binary", "scala_library")
33
def jmh_repositories():
44
native.maven_jar(
55
name = "io_bazel_rules_scala_org_openjdk_jmh_jmh_core",
6-
artifact = "org.openjdk.jmh:jmh-core:1.17.4",
7-
sha1 = "126d989b196070a8b3653b5389e602a48fe6bb2f",
6+
artifact = "org.openjdk.jmh:jmh-core:1.20",
7+
sha1 = "5f9f9839bda2332e9acd06ce31ad94afa7d6d447",
88
)
99
native.bind(
1010
name = 'io_bazel_rules_scala/dependency/jmh/jmh_core',
1111
actual = '@io_bazel_rules_scala_org_openjdk_jmh_jmh_core//jar',
1212
)
1313
native.maven_jar(
1414
name = "io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm",
15-
artifact = "org.openjdk.jmh:jmh-generator-asm:1.17.4",
16-
sha1 = "c85c3d8cfa194872b260e89770d41e2084ce2cb6",
15+
artifact = "org.openjdk.jmh:jmh-generator-asm:1.20",
16+
sha1 = "3c43040e08ae68905657a375e669f11a7352f9db",
1717
)
1818
native.bind(
1919
name = 'io_bazel_rules_scala/dependency/jmh/jmh_generator_asm',
2020
actual = '@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_asm//jar',
2121
)
2222
native.maven_jar(
2323
name = "io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection",
24-
artifact = "org.openjdk.jmh:jmh-generator-reflection:1.17.4",
25-
sha1 = "f75a7823c9fcf03feed6d74aa44ea61fc70a8439",
24+
artifact = "org.openjdk.jmh:jmh-generator-reflection:1.20",
25+
sha1 = "f2154437b42426a48d5dac0b3df59002f86aed26",
2626
)
2727
native.bind(
2828
name = 'io_bazel_rules_scala/dependency/jmh/jmh_generator_reflection',
2929
actual = '@io_bazel_rules_scala_org_openjdk_jmh_jmh_generator_reflection//jar',
3030
)
3131
native.maven_jar(
3232
name = "io_bazel_rules_scala_org_ows2_asm_asm",
33-
artifact = "org.ow2.asm:asm:5.0.4",
34-
sha1 = "0da08b8cce7bbf903602a25a3a163ae252435795",
33+
artifact = "org.ow2.asm:asm:6.1.1",
34+
sha1 = "264754515362d92acd39e8d40395f6b8dee7bc08",
3535
)
3636
native.bind(
3737
name = "io_bazel_rules_scala/dependency/jmh/org_ows2_asm_asm",
@@ -78,14 +78,15 @@ def _scala_generate_benchmark(ctx):
7878
outputs = [ctx.outputs.src_jar, ctx.outputs.resource_jar],
7979
inputs = [class_jar] + classpath,
8080
executable = ctx.executable._generator,
81-
arguments = [f.path for f in [class_jar, ctx.outputs.src_jar, ctx.outputs.resource_jar] + classpath],
81+
arguments = [ctx.attr.generator_type] + [f.path for f in [class_jar, ctx.outputs.src_jar, ctx.outputs.resource_jar] + classpath],
8282
progress_message = "Generating benchmark code for %s" % ctx.label,
8383
)
8484

8585
scala_generate_benchmark = rule(
8686
implementation = _scala_generate_benchmark,
8787
attrs = {
8888
"src": attr.label(allow_single_file=True, mandatory=True),
89+
"generator_type": attr.string(default='reflection', mandatory=False),
8990
"_generator": attr.label(executable=True, cfg="host", default=Label("//src/scala/io/bazel/rules_scala/jmh_support:benchmark_generator"))
9091
},
9192
outputs = {
@@ -98,6 +99,7 @@ def scala_benchmark_jmh(**kw):
9899
name = kw["name"]
99100
deps = kw.get("deps", [])
100101
srcs = kw["srcs"]
102+
generator_type = kw.get("generator_type", "reflection")
101103
lib = "%s_generator" % name
102104
scalacopts = kw.get("scalacopts", [])
103105
main_class = kw.get("main_class", "org.openjdk.jmh.Main")
@@ -115,7 +117,7 @@ def scala_benchmark_jmh(**kw):
115117
)
116118

117119
codegen = name + "_codegen"
118-
scala_generate_benchmark(name=codegen, src=lib)
120+
scala_generate_benchmark(name=codegen, src=lib, generator_type=generator_type)
119121
compiled_lib = name + "_compiled_benchmark_lib"
120122
scala_library(
121123
name = compiled_lib,

src/scala/io/bazel/rules_scala/jmh_support/BenchmarkGenerator.scala

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ import java.net.URLClassLoader
44

55
import scala.annotation.tailrec
66
import scala.collection.JavaConverters._
7-
8-
import org.openjdk.jmh.generators.core.{ BenchmarkGenerator => JMHGenerator, FileSystemDestination }
7+
import org.openjdk.jmh.generators.core.{FileSystemDestination, GeneratorSource, BenchmarkGenerator => JMHGenerator}
98
import org.openjdk.jmh.generators.asm.ASMGeneratorSource
10-
import org.openjdk.jmh.runner.{ Runner, RunnerException }
11-
import org.openjdk.jmh.runner.options.{ Options, OptionsBuilder }
12-
9+
import org.openjdk.jmh.generators.reflection.RFGeneratorSource
10+
import org.openjdk.jmh.runner.{Runner, RunnerException}
11+
import org.openjdk.jmh.runner.options.{Options, OptionsBuilder}
1312
import java.net.URI
13+
1414
import scala.collection.JavaConverters._
15-
import java.nio.file.{Files, FileSystems, Path}
15+
import java.nio.file.{FileSystems, Files, Path, Paths}
1616

1717
import io.bazel.rulesscala.jar.JarCreator
1818

@@ -27,7 +27,14 @@ import io.bazel.rulesscala.jar.JarCreator
2727
*/
2828
object BenchmarkGenerator {
2929

30-
case class BenchmarkGeneratorArgs(
30+
private sealed trait GeneratorType
31+
32+
private case object ReflectionGenerator extends GeneratorType
33+
34+
private case object AsmGenerator extends GeneratorType
35+
36+
private case class BenchmarkGeneratorArgs(
37+
generatorType: GeneratorType,
3138
inputJar: Path,
3239
resultSourceJar: Path,
3340
resultResourceJar: Path,
@@ -37,6 +44,7 @@ object BenchmarkGenerator {
3744
def main(argv: Array[String]): Unit = {
3845
val args = parseArgs(argv)
3946
generateJmhBenchmark(
47+
args.generatorType,
4048
args.resultSourceJar,
4149
args.resultResourceJar,
4250
args.inputJar,
@@ -47,17 +55,18 @@ object BenchmarkGenerator {
4755
private def parseArgs(argv: Array[String]): BenchmarkGeneratorArgs = {
4856
if (argv.length < 3) {
4957
System.err.println(
50-
"Usage: BenchmarkGenerator INPUT_JAR RESULT_JAR RESOURCE_JAR [CLASSPATH_ELEMENT] [CLASSPATH_ELEMENT...]"
58+
"Usage: BenchmarkGenerator GENERATOR_TYPE INPUT_JAR RESULT_JAR RESOURCE_JAR [CLASSPATH_ELEMENT] [CLASSPATH_ELEMENT...]"
5159
)
5260
System.exit(1)
5361
}
5462
val fs = FileSystems.getDefault
5563

5664
BenchmarkGeneratorArgs(
57-
fs.getPath(argv(0)),
65+
parseGeneratorType(argv(0)),
5866
fs.getPath(argv(1)),
5967
fs.getPath(argv(2)),
60-
argv.slice(3, argv.length).map { s => fs.getPath(s) }.toList
68+
fs.getPath(argv(3)),
69+
argv.slice(4, argv.length).map { s => fs.getPath(s) }.toList
6170
)
6271
}
6372

@@ -88,13 +97,13 @@ object BenchmarkGenerator {
8897
}
8998

9099
// Courtesy of Doug Tangren (https://groups.google.com/forum/#!topic/simple-build-tool/CYeLHcJjHyA)
91-
private def withClassLoader[A](cp: Seq[Path])(f: => A): A = {
100+
private def withClassLoader[A](cp: Seq[Path])(f: ClassLoader => A): A = {
92101
val originalLoader = Thread.currentThread.getContextClassLoader
93102
val jmhLoader = classOf[JMHGenerator].getClassLoader
94103
val classLoader = new URLClassLoader(cp.map(_.toUri.toURL).toArray, jmhLoader)
95104
try {
96105
Thread.currentThread.setContextClassLoader(classLoader)
97-
f
106+
f(classLoader)
98107
} finally {
99108
Thread.currentThread.setContextClassLoader(originalLoader)
100109
}
@@ -119,6 +128,7 @@ object BenchmarkGenerator {
119128
}
120129

121130
private def generateJmhBenchmark(
131+
generatorType: GeneratorType,
122132
sourceJarOut: Path,
123133
resourceJarOut: Path,
124134
benchmarkJarPath: Path,
@@ -131,17 +141,26 @@ object BenchmarkGenerator {
131141
tmpResourceDir.toFile.mkdir()
132142
tmpSourceDir.toFile.mkdir()
133143

134-
withClassLoader(benchmarkJarPath :: classpath) {
135-
val source = new ASMGeneratorSource
136-
val destination = new FileSystemDestination(tmpResourceDir.toFile, tmpSourceDir.toFile)
137-
val generator = new JMHGenerator
138-
139-
collectClassesFromJar(benchmarkJarPath).foreach { path =>
140-
// this would fail due to https://github.com/bazelbuild/rules_scala/issues/295
141-
// let's throw a useful message instead
142-
sys.error("jmh in rules_scala doesn't work with Java 8 bytecode: https://github.com/bazelbuild/rules_scala/issues/295")
143-
source.processClass(Files.newInputStream(path))
144+
withClassLoader(benchmarkJarPath :: classpath) { isolatedClassLoader =>
145+
146+
val source: GeneratorSource = generatorType match {
147+
case AsmGenerator =>
148+
val generatorSource = new ASMGeneratorSource
149+
generatorSource.processClasses(collectClassesFromJar(benchmarkJarPath).map(_.toFile).asJavaCollection)
150+
generatorSource
151+
152+
case ReflectionGenerator =>
153+
val generatorSource = new RFGeneratorSource
154+
generatorSource.processClasses(
155+
collectClassesFromJar(benchmarkJarPath)
156+
.flatMap(classByPath(_, isolatedClassLoader))
157+
.asJavaCollection
158+
)
159+
generatorSource
144160
}
161+
162+
val generator = new JMHGenerator
163+
val destination = new FileSystemDestination(tmpResourceDir.toFile, tmpSourceDir.toFile)
145164
generator.generate(source, destination)
146165
generator.complete(source, destination)
147166
if (destination.hasErrors) {
@@ -156,6 +175,39 @@ object BenchmarkGenerator {
156175
}
157176
}
158177

178+
private def classByPath(path: Path, cl: ClassLoader): Option[Class[_]] = {
179+
val separator = path.getFileSystem.getSeparator
180+
var s = path.toString
181+
.stripPrefix(separator)
182+
.stripSuffix(".class")
183+
.replace(separator, ".")
184+
185+
var index = -1
186+
do {
187+
s = s.substring(index + 1)
188+
try {
189+
return Some(Class.forName(s, false, cl))
190+
} catch {
191+
case _: ClassNotFoundException =>
192+
// ignore and try next one
193+
index = s.indexOf('.')
194+
}
195+
} while (index != -1)
196+
197+
log(s"Failed to find class for path $path")
198+
None
199+
}
200+
201+
private def parseGeneratorType(s: String): GeneratorType = {
202+
if ("asm".equalsIgnoreCase(s)) {
203+
AsmGenerator
204+
} else if ("reflection".equalsIgnoreCase(s)) {
205+
ReflectionGenerator
206+
} else {
207+
throw new IllegalArgumentException(s"unknown generator_type: $s")
208+
}
209+
}
210+
159211
private def log(str: String): Unit = {
160212
System.err.println(s"JMH benchmark generation: $str")
161213
}

0 commit comments

Comments
 (0)