Skip to content

Commit 57972be

Browse files
authored
Refactor write_launcher (#449)
* Refactor write_launcher * move a comment to the appropriate place
1 parent 0be293f commit 57972be

File tree

1 file changed

+42
-29
lines changed

1 file changed

+42
-29
lines changed

scala/scala.bzl

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -425,19 +425,20 @@ def _path_is_absolute(path):
425425

426426
return False
427427

428-
def _write_launcher(ctx, rjars, main_class, jvm_flags, args="", wrapper_preamble=""):
429-
runfiles_root = "${TEST_SRCDIR}/%s" % ctx.workspace_name
430-
# RUNPATH is defined here:
431-
# https://github.com/bazelbuild/bazel/blob/0.4.5/src/main/java/com/google/devtools/build/lib/bazel/rules/java/java_stub_template.txt#L227
432-
classpath = ":".join(["${RUNPATH}%s" % (j.short_path) for j in rjars])
433-
jvm_flags = " ".join([ctx.expand_location(f, ctx.attr.data) for f in jvm_flags])
428+
def _runfiles_root(ctx):
429+
return "${TEST_SRCDIR}/%s" % ctx.workspace_name
430+
431+
def _write_java_wrapper(ctx, args="", wrapper_preamble=""):
432+
"""This creates a wrapper that sets up the correct path
433+
to stand in for the java command."""
434+
434435
java_path = str(ctx.attr._java_runtime[java_common.JavaRuntimeInfo].java_executable_runfiles_path)
435436
if _path_is_absolute(java_path):
436437
javabin = java_path
437438
else:
439+
runfiles_root = _runfiles_root(ctx)
438440
javabin = "%s/%s" % (runfiles_root, java_path)
439441

440-
template = ctx.attr._java_stub_template.files.to_list()[0]
441442

442443
exec_str = ""
443444
if wrapper_preamble == "":
@@ -457,14 +458,21 @@ def _write_launcher(ctx, rjars, main_class, jvm_flags, args="", wrapper_preamble
457458
args=args,
458459
),
459460
)
461+
return wrapper
460462

463+
def _write_executable(ctx, rjars, main_class, jvm_flags, wrapper):
464+
template = ctx.attr._java_stub_template.files.to_list()[0]
465+
# RUNPATH is defined here:
466+
# https://github.com/bazelbuild/bazel/blob/0.4.5/src/main/java/com/google/devtools/build/lib/bazel/rules/java/java_stub_template.txt#L227
467+
classpath = ":".join(["${RUNPATH}%s" % (j.short_path) for j in rjars])
468+
jvm_flags = " ".join([ctx.expand_location(f, ctx.attr.data) for f in jvm_flags])
461469
ctx.template_action(
462470
template = template,
463471
output = ctx.outputs.executable,
464472
substitutions = {
465473
"%classpath%": classpath,
466474
"%java_start_class%": main_class,
467-
"%javabin%": "JAVABIN=%s/%s" % (runfiles_root, wrapper.short_path),
475+
"%javabin%": "JAVABIN=%s/%s" % (_runfiles_root(ctx), wrapper.short_path),
468476
"%jvm_flags%": jvm_flags,
469477
"%needs_runfiles%": "",
470478
"%runfiles_manifest_only%": "",
@@ -739,17 +747,16 @@ def _scala_macro_library_impl(ctx):
739747
return _lib(ctx, False) # don't build the ijar for macros
740748

741749
# Common code shared by all scala binary implementations.
742-
def _scala_binary_common(ctx, cjars, rjars, transitive_compile_time_jars, jars2labels, implicit_junit_deps_needed_for_java_compilation = []):
750+
def _scala_binary_common(ctx, cjars, rjars, transitive_compile_time_jars, jars2labels, java_wrapper, implicit_junit_deps_needed_for_java_compilation = []):
743751
write_manifest(ctx)
744752
outputs = _compile_or_empty(ctx, cjars, [], False, transitive_compile_time_jars, jars2labels, implicit_junit_deps_needed_for_java_compilation) # no need to build an ijar for an executable
745753
rjars += outputs.full_jars
746754

747-
_build_deployable(ctx, list(rjars))
748-
749-
java_wrapper = ctx.new_file(ctx.label.name + "_wrapper.sh")
755+
rjars_list = list(rjars)
756+
_build_deployable(ctx, rjars_list)
750757

751758
runfiles = ctx.runfiles(
752-
files = list(rjars) + [ctx.outputs.executable, java_wrapper] + ctx.files._java_runtime,
759+
files = rjars_list + [ctx.outputs.executable, java_wrapper] + ctx.files._java_runtime,
753760
collect_data = True)
754761

755762
rule_outputs = struct(
@@ -781,12 +788,14 @@ def _scala_binary_impl(ctx):
781788
jars = _collect_jars_from_common_ctx(ctx)
782789
(cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars)
783790

784-
out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels)
785-
_write_launcher(
791+
wrapper = _write_java_wrapper(ctx, "", "")
792+
out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, wrapper)
793+
_write_executable(
786794
ctx = ctx,
787795
rjars = out.transitive_rjars,
788796
main_class = ctx.attr.main_class,
789797
jvm_flags = ctx.attr.jvm_flags,
798+
wrapper = wrapper
790799
)
791800
return out
792801

@@ -795,15 +804,8 @@ def _scala_repl_impl(ctx):
795804
jars = _collect_jars_from_common_ctx(ctx, extra_runtime_deps = [ctx.attr._scalacompiler])
796805
(cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars)
797806

798-
out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels)
799807
args = " ".join(ctx.attr.scalacopts)
800-
_write_launcher(
801-
ctx = ctx,
802-
rjars = out.transitive_rjars,
803-
main_class = "scala.tools.nsc.MainGenericRunner",
804-
jvm_flags = ["-Dscala.usejavacp=true"] + ctx.attr.jvm_flags,
805-
args = args,
806-
wrapper_preamble = """
808+
wrapper = _write_java_wrapper(ctx, args, wrapper_preamble = """
807809
# save stty like in bin/scala
808810
saved_stty=$(stty -g 2>/dev/null)
809811
if [[ ! $? ]]; then
@@ -816,7 +818,15 @@ function finish() {
816818
fi
817819
}
818820
trap finish EXIT
819-
""",
821+
""")
822+
823+
out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, wrapper)
824+
_write_executable(
825+
ctx = ctx,
826+
rjars = out.transitive_rjars,
827+
main_class = "scala.tools.nsc.MainGenericRunner",
828+
jvm_flags = ["-Dscala.usejavacp=true"] + ctx.attr.jvm_flags,
829+
wrapper = wrapper
820830
)
821831

822832
return out
@@ -857,14 +867,15 @@ def _scala_test_impl(ctx):
857867
_scala_test_flags(ctx),
858868
"-C io.bazel.rules.scala.JUnitXmlReporter ",
859869
])
860-
out = _scala_binary_common(ctx, cjars, transitive_rjars, transitive_compile_jars, jars_to_labels)
861870
# main_class almost has to be "org.scalatest.tools.Runner" due to args....
862-
_write_launcher(
871+
wrapper = _write_java_wrapper(ctx, args, "")
872+
out = _scala_binary_common(ctx, cjars, transitive_rjars, transitive_compile_jars, jars_to_labels, wrapper)
873+
_write_executable(
863874
ctx = ctx,
864875
rjars = out.transitive_rjars,
865876
main_class = ctx.attr.main_class,
866877
jvm_flags = ctx.attr.jvm_flags,
867-
args = args,
878+
wrapper = wrapper
868879
)
869880
return out
870881

@@ -890,14 +901,16 @@ def _scala_junit_test_impl(ctx):
890901
(cjars, transitive_rjars) = (jars.compile_jars, jars.transitive_runtime_jars)
891902
implicit_junit_deps_needed_for_java_compilation = [ctx.attr._junit, ctx.attr._hamcrest]
892903

893-
out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, implicit_junit_deps_needed_for_java_compilation)
904+
wrapper = _write_java_wrapper(ctx, "", "")
905+
out = _scala_binary_common(ctx, cjars, transitive_rjars, jars.transitive_compile_jars, jars.jars2labels, wrapper, implicit_junit_deps_needed_for_java_compilation)
894906
test_suite = _gen_test_suite_flags_based_on_prefixes_and_suffixes(ctx, out.scala.outputs.jars)
895907
launcherJvmFlags = ["-ea", test_suite.archiveFlag, test_suite.prefixesFlag, test_suite.suffixesFlag, test_suite.printFlag, test_suite.testSuiteFlag]
896-
_write_launcher(
908+
_write_executable(
897909
ctx = ctx,
898910
rjars = out.transitive_rjars,
899911
main_class = "com.google.testing.junit.runner.BazelTestRunner",
900912
jvm_flags = launcherJvmFlags + ctx.attr.jvm_flags,
913+
wrapper = wrapper
901914
)
902915

903916
return out

0 commit comments

Comments
 (0)