@@ -425,19 +425,20 @@ def _path_is_absolute(path):
425425
426426 return False
427427
428- def _write_launcher (ctx , rjars , main_class , jvm_flags , args = "" , wrapper_preamble = "" ):
429- runfiles_root = "${TEST_SRCDIR}/%s" % ctx .workspace_name
430- # RUNPATH is defined here:
431- # https://github.com/bazelbuild/bazel/blob/0.4.5/src/main/java/com/google/devtools/build/lib/bazel/rules/java/java_stub_template.txt#L227
432- classpath = ":" .join (["${RUNPATH}%s" % (j .short_path ) for j in rjars ])
433- jvm_flags = " " .join ([ctx .expand_location (f , ctx .attr .data ) for f in jvm_flags ])
428+ def _runfiles_root (ctx ):
429+ return "${TEST_SRCDIR}/%s" % ctx .workspace_name
430+
431+ def _write_java_wrapper (ctx , args = "" , wrapper_preamble = "" ):
432+ """This creates a wrapper that sets up the correct path
433+ to stand in for the java command."""
434+
434435 java_path = str (ctx .attr ._java_runtime [java_common .JavaRuntimeInfo ].java_executable_runfiles_path )
435436 if _path_is_absolute (java_path ):
436437 javabin = java_path
437438 else :
439+ runfiles_root = _runfiles_root (ctx )
438440 javabin = "%s/%s" % (runfiles_root , java_path )
439441
440- template = ctx .attr ._java_stub_template .files .to_list ()[0 ]
441442
442443 exec_str = ""
443444 if wrapper_preamble == "" :
@@ -457,14 +458,21 @@ def _write_launcher(ctx, rjars, main_class, jvm_flags, args="", wrapper_preamble
457458 args = args ,
458459 ),
459460 )
461+ return wrapper
460462
463+ def _write_executable (ctx , rjars , main_class , jvm_flags , wrapper ):
464+ template = ctx .attr ._java_stub_template .files .to_list ()[0 ]
465+ # RUNPATH is defined here:
466+ # https://github.com/bazelbuild/bazel/blob/0.4.5/src/main/java/com/google/devtools/build/lib/bazel/rules/java/java_stub_template.txt#L227
467+ classpath = ":" .join (["${RUNPATH}%s" % (j .short_path ) for j in rjars ])
468+ jvm_flags = " " .join ([ctx .expand_location (f , ctx .attr .data ) for f in jvm_flags ])
461469 ctx .template_action (
462470 template = template ,
463471 output = ctx .outputs .executable ,
464472 substitutions = {
465473 "%classpath%" : classpath ,
466474 "%java_start_class%" : main_class ,
467- "%javabin%" : "JAVABIN=%s/%s" % (runfiles_root , wrapper .short_path ),
475+ "%javabin%" : "JAVABIN=%s/%s" % (_runfiles_root ( ctx ) , wrapper .short_path ),
468476 "%jvm_flags%" : jvm_flags ,
469477 "%needs_runfiles%" : "" ,
470478 "%runfiles_manifest_only%" : "" ,
@@ -739,17 +747,16 @@ def _scala_macro_library_impl(ctx):
739747 return _lib (ctx , False ) # don't build the ijar for macros
740748
741749# Common code shared by all scala binary implementations.
742- def _scala_binary_common (ctx , cjars , rjars , transitive_compile_time_jars , jars2labels , implicit_junit_deps_needed_for_java_compilation = []):
750+ def _scala_binary_common (ctx , cjars , rjars , transitive_compile_time_jars , jars2labels , java_wrapper , implicit_junit_deps_needed_for_java_compilation = []):
743751 write_manifest (ctx )
744752 outputs = _compile_or_empty (ctx , cjars , [], False , transitive_compile_time_jars , jars2labels , implicit_junit_deps_needed_for_java_compilation ) # no need to build an ijar for an executable
745753 rjars += outputs .full_jars
746754
747- _build_deployable (ctx , list (rjars ))
748-
749- java_wrapper = ctx .new_file (ctx .label .name + "_wrapper.sh" )
755+ rjars_list = list (rjars )
756+ _build_deployable (ctx , rjars_list )
750757
751758 runfiles = ctx .runfiles (
752- files = list ( rjars ) + [ctx .outputs .executable , java_wrapper ] + ctx .files ._java_runtime ,
759+ files = rjars_list + [ctx .outputs .executable , java_wrapper ] + ctx .files ._java_runtime ,
753760 collect_data = True )
754761
755762 rule_outputs = struct (
@@ -781,12 +788,14 @@ def _scala_binary_impl(ctx):
781788 jars = _collect_jars_from_common_ctx (ctx )
782789 (cjars , transitive_rjars ) = (jars .compile_jars , jars .transitive_runtime_jars )
783790
784- out = _scala_binary_common (ctx , cjars , transitive_rjars , jars .transitive_compile_jars , jars .jars2labels )
785- _write_launcher (
791+ wrapper = _write_java_wrapper (ctx , "" , "" )
792+ out = _scala_binary_common (ctx , cjars , transitive_rjars , jars .transitive_compile_jars , jars .jars2labels , wrapper )
793+ _write_executable (
786794 ctx = ctx ,
787795 rjars = out .transitive_rjars ,
788796 main_class = ctx .attr .main_class ,
789797 jvm_flags = ctx .attr .jvm_flags ,
798+ wrapper = wrapper
790799 )
791800 return out
792801
@@ -795,15 +804,8 @@ def _scala_repl_impl(ctx):
795804 jars = _collect_jars_from_common_ctx (ctx , extra_runtime_deps = [ctx .attr ._scalacompiler ])
796805 (cjars , transitive_rjars ) = (jars .compile_jars , jars .transitive_runtime_jars )
797806
798- out = _scala_binary_common (ctx , cjars , transitive_rjars , jars .transitive_compile_jars , jars .jars2labels )
799807 args = " " .join (ctx .attr .scalacopts )
800- _write_launcher (
801- ctx = ctx ,
802- rjars = out .transitive_rjars ,
803- main_class = "scala.tools.nsc.MainGenericRunner" ,
804- jvm_flags = ["-Dscala.usejavacp=true" ] + ctx .attr .jvm_flags ,
805- args = args ,
806- wrapper_preamble = """
808+ wrapper = _write_java_wrapper (ctx , args , wrapper_preamble = """
807809# save stty like in bin/scala
808810saved_stty=$(stty -g 2>/dev/null)
809811if [[ ! $? ]]; then
@@ -816,7 +818,15 @@ function finish() {
816818 fi
817819}
818820trap finish EXIT
819- """ ,
821+ """ )
822+
823+ out = _scala_binary_common (ctx , cjars , transitive_rjars , jars .transitive_compile_jars , jars .jars2labels , wrapper )
824+ _write_executable (
825+ ctx = ctx ,
826+ rjars = out .transitive_rjars ,
827+ main_class = "scala.tools.nsc.MainGenericRunner" ,
828+ jvm_flags = ["-Dscala.usejavacp=true" ] + ctx .attr .jvm_flags ,
829+ wrapper = wrapper
820830 )
821831
822832 return out
@@ -857,14 +867,15 @@ def _scala_test_impl(ctx):
857867 _scala_test_flags (ctx ),
858868 "-C io.bazel.rules.scala.JUnitXmlReporter " ,
859869 ])
860- out = _scala_binary_common (ctx , cjars , transitive_rjars , transitive_compile_jars , jars_to_labels )
861870 # main_class almost has to be "org.scalatest.tools.Runner" due to args....
862- _write_launcher (
871+ wrapper = _write_java_wrapper (ctx , args , "" )
872+ out = _scala_binary_common (ctx , cjars , transitive_rjars , transitive_compile_jars , jars_to_labels , wrapper )
873+ _write_executable (
863874 ctx = ctx ,
864875 rjars = out .transitive_rjars ,
865876 main_class = ctx .attr .main_class ,
866877 jvm_flags = ctx .attr .jvm_flags ,
867- args = args ,
878+ wrapper = wrapper
868879 )
869880 return out
870881
@@ -890,14 +901,16 @@ def _scala_junit_test_impl(ctx):
890901 (cjars , transitive_rjars ) = (jars .compile_jars , jars .transitive_runtime_jars )
891902 implicit_junit_deps_needed_for_java_compilation = [ctx .attr ._junit , ctx .attr ._hamcrest ]
892903
893- out = _scala_binary_common (ctx , cjars , transitive_rjars , jars .transitive_compile_jars , jars .jars2labels , implicit_junit_deps_needed_for_java_compilation )
904+ wrapper = _write_java_wrapper (ctx , "" , "" )
905+ out = _scala_binary_common (ctx , cjars , transitive_rjars , jars .transitive_compile_jars , jars .jars2labels , wrapper , implicit_junit_deps_needed_for_java_compilation )
894906 test_suite = _gen_test_suite_flags_based_on_prefixes_and_suffixes (ctx , out .scala .outputs .jars )
895907 launcherJvmFlags = ["-ea" , test_suite .archiveFlag , test_suite .prefixesFlag , test_suite .suffixesFlag , test_suite .printFlag , test_suite .testSuiteFlag ]
896- _write_launcher (
908+ _write_executable (
897909 ctx = ctx ,
898910 rjars = out .transitive_rjars ,
899911 main_class = "com.google.testing.junit.runner.BazelTestRunner" ,
900912 jvm_flags = launcherJvmFlags + ctx .attr .jvm_flags ,
913+ wrapper = wrapper
901914 )
902915
903916 return out
0 commit comments