@@ -4,15 +4,15 @@ import java.net.URLClassLoader
44
55import scala .annotation .tailrec
66import scala .collection .JavaConverters ._
7-
8- import org .openjdk .jmh .generators .core .{ BenchmarkGenerator => JMHGenerator , FileSystemDestination }
7+ import org .openjdk .jmh .generators .core .{FileSystemDestination , GeneratorSource , BenchmarkGenerator => JMHGenerator }
98import org .openjdk .jmh .generators .asm .ASMGeneratorSource
10- import org .openjdk .jmh .runner .{ Runner , RunnerException }
11- import org .openjdk .jmh .runner .options .{ Options , OptionsBuilder }
12-
9+ import org .openjdk .jmh .generators . reflection . RFGeneratorSource
10+ import org .openjdk .jmh .runner .{ Runner , RunnerException }
11+ import org . openjdk . jmh . runner . options .{ Options , OptionsBuilder }
1312import java .net .URI
13+
1414import scala .collection .JavaConverters ._
15- import java .nio .file .{Files , FileSystems , Path }
15+ import java .nio .file .{FileSystems , Files , Path , Paths }
1616
1717import io .bazel .rulesscala .jar .JarCreator
1818
@@ -27,7 +27,14 @@ import io.bazel.rulesscala.jar.JarCreator
2727 */
2828object BenchmarkGenerator {
2929
30- case class BenchmarkGeneratorArgs (
30+ private sealed trait GeneratorType
31+
32+ private case object ReflectionGenerator extends GeneratorType
33+
34+ private case object AsmGenerator extends GeneratorType
35+
36+ private case class BenchmarkGeneratorArgs (
37+ generatorType : GeneratorType ,
3138 inputJar : Path ,
3239 resultSourceJar : Path ,
3340 resultResourceJar : Path ,
@@ -37,6 +44,7 @@ object BenchmarkGenerator {
3744 def main (argv : Array [String ]): Unit = {
3845 val args = parseArgs(argv)
3946 generateJmhBenchmark(
47+ args.generatorType,
4048 args.resultSourceJar,
4149 args.resultResourceJar,
4250 args.inputJar,
@@ -47,17 +55,18 @@ object BenchmarkGenerator {
4755 private def parseArgs (argv : Array [String ]): BenchmarkGeneratorArgs = {
4856 if (argv.length < 3 ) {
4957 System .err.println(
50- " Usage: BenchmarkGenerator INPUT_JAR RESULT_JAR RESOURCE_JAR [CLASSPATH_ELEMENT] [CLASSPATH_ELEMENT...]"
58+ " Usage: BenchmarkGenerator GENERATOR_TYPE INPUT_JAR RESULT_JAR RESOURCE_JAR [CLASSPATH_ELEMENT] [CLASSPATH_ELEMENT...]"
5159 )
5260 System .exit(1 )
5361 }
5462 val fs = FileSystems .getDefault
5563
5664 BenchmarkGeneratorArgs (
57- fs.getPath (argv(0 )),
65+ parseGeneratorType (argv(0 )),
5866 fs.getPath(argv(1 )),
5967 fs.getPath(argv(2 )),
60- argv.slice(3 , argv.length).map { s => fs.getPath(s) }.toList
68+ fs.getPath(argv(3 )),
69+ argv.slice(4 , argv.length).map { s => fs.getPath(s) }.toList
6170 )
6271 }
6372
@@ -88,13 +97,13 @@ object BenchmarkGenerator {
8897 }
8998
9099 // Courtesy of Doug Tangren (https://groups.google.com/forum/#!topic/simple-build-tool/CYeLHcJjHyA)
91- private def withClassLoader [A ](cp : Seq [Path ])(f : => A ): A = {
100+ private def withClassLoader [A ](cp : Seq [Path ])(f : ClassLoader => A ): A = {
92101 val originalLoader = Thread .currentThread.getContextClassLoader
93102 val jmhLoader = classOf [JMHGenerator ].getClassLoader
94103 val classLoader = new URLClassLoader (cp.map(_.toUri.toURL).toArray, jmhLoader)
95104 try {
96105 Thread .currentThread.setContextClassLoader(classLoader)
97- f
106+ f(classLoader)
98107 } finally {
99108 Thread .currentThread.setContextClassLoader(originalLoader)
100109 }
@@ -119,6 +128,7 @@ object BenchmarkGenerator {
119128 }
120129
121130 private def generateJmhBenchmark (
131+ generatorType : GeneratorType ,
122132 sourceJarOut : Path ,
123133 resourceJarOut : Path ,
124134 benchmarkJarPath : Path ,
@@ -131,17 +141,26 @@ object BenchmarkGenerator {
131141 tmpResourceDir.toFile.mkdir()
132142 tmpSourceDir.toFile.mkdir()
133143
134- withClassLoader(benchmarkJarPath :: classpath) {
135- val source = new ASMGeneratorSource
136- val destination = new FileSystemDestination (tmpResourceDir.toFile, tmpSourceDir.toFile)
137- val generator = new JMHGenerator
138-
139- collectClassesFromJar(benchmarkJarPath).foreach { path =>
140- // this would fail due to https://github.com/bazelbuild/rules_scala/issues/295
141- // let's throw a useful message instead
142- sys.error(" jmh in rules_scala doesn't work with Java 8 bytecode: https://github.com/bazelbuild/rules_scala/issues/295" )
143- source.processClass(Files .newInputStream(path))
144+ withClassLoader(benchmarkJarPath :: classpath) { isolatedClassLoader =>
145+
146+ val source : GeneratorSource = generatorType match {
147+ case AsmGenerator =>
148+ val generatorSource = new ASMGeneratorSource
149+ generatorSource.processClasses(collectClassesFromJar(benchmarkJarPath).map(_.toFile).asJavaCollection)
150+ generatorSource
151+
152+ case ReflectionGenerator =>
153+ val generatorSource = new RFGeneratorSource
154+ generatorSource.processClasses(
155+ collectClassesFromJar(benchmarkJarPath)
156+ .flatMap(classByPath(_, isolatedClassLoader))
157+ .asJavaCollection
158+ )
159+ generatorSource
144160 }
161+
162+ val generator = new JMHGenerator
163+ val destination = new FileSystemDestination (tmpResourceDir.toFile, tmpSourceDir.toFile)
145164 generator.generate(source, destination)
146165 generator.complete(source, destination)
147166 if (destination.hasErrors) {
@@ -156,6 +175,39 @@ object BenchmarkGenerator {
156175 }
157176 }
158177
178+ private def classByPath (path : Path , cl : ClassLoader ): Option [Class [_]] = {
179+ val separator = path.getFileSystem.getSeparator
180+ var s = path.toString
181+ .stripPrefix(separator)
182+ .stripSuffix(" .class" )
183+ .replace(separator, " ." )
184+
185+ var index = - 1
186+ do {
187+ s = s.substring(index + 1 )
188+ try {
189+ return Some (Class .forName(s, false , cl))
190+ } catch {
191+ case _ : ClassNotFoundException =>
192+ // ignore and try next one
193+ index = s.indexOf('.' )
194+ }
195+ } while (index != - 1 )
196+
197+ log(s " Failed to find class for path $path" )
198+ None
199+ }
200+
201+ private def parseGeneratorType (s : String ): GeneratorType = {
202+ if (" asm" .equalsIgnoreCase(s)) {
203+ AsmGenerator
204+ } else if (" reflection" .equalsIgnoreCase(s)) {
205+ ReflectionGenerator
206+ } else {
207+ throw new IllegalArgumentException (s " unknown generator_type: $s" )
208+ }
209+ }
210+
159211 private def log (str : String ): Unit = {
160212 System .err.println(s " JMH benchmark generation: $str" )
161213 }
0 commit comments