Skip to content

Commit

Permalink
Proxy proto compiler and proto opts over proto_lang_toolchain rule.
Browse files Browse the repository at this point in the history
This is needed for the proto_common.generate_code function.

This change doesn't modify any of proto_lang_toolchain public attributes.

PiperOrigin-RevId: 437713619
  • Loading branch information
comius authored and copybara-github committed Mar 28, 2022
1 parent a229bed commit 66e29cd
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ public ConfiguredTarget create(RuleContext ruleContext)
// targets. While this has the potential to use more memory than using a NestedSet,
// there are only a few `proto_lang_toolchain` targets in every build, so the impact
// on memory consumption should be neglectable.
providedProtoSources.build().toList()))
providedProtoSources.build().toList(),
ruleContext.getPrerequisite(":proto_compiler", FilesToRunProvider.class),
ruleContext.getConfiguration().getFragment(ProtoConfiguration.class).protocOpts()))
.setFilesToBuild(NestedSetBuilder.<Artifact>emptySet(STABLE_ORDER))
.addProvider(RunfilesProvider.simple(Runfiles.EMPTY))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import com.google.devtools.build.lib.collect.nestedset.NestedSet;
import com.google.devtools.build.lib.collect.nestedset.NestedSetBuilder;
import javax.annotation.Nullable;
import net.starlark.java.annot.StarlarkMethod;
import net.starlark.java.eval.StarlarkList;

// Note: AutoValue v1.4-rc1 has AutoValue.CopyAnnotations which makes it work with Starlark. No need
// to un-AutoValue this class to expose it to Starlark.
Expand All @@ -32,11 +34,25 @@
*/
@AutoValue
public abstract class ProtoLangToolchainProvider implements TransitiveInfoProvider {
@StarlarkMethod(
name = "out_replacement_format_flag",
doc = "Format string used when passing output to the plugin used by proto compiler.",
structField = true)
public abstract String outReplacementFormatFlag();

@StarlarkMethod(
name = "plugin_format_flag",
doc = "Format string used when passing plugin to proto compiler.",
structField = true,
allowReturnNones = true)
@Nullable
public abstract String pluginFormatFlag();

@StarlarkMethod(
name = "plugin",
doc = "Proto compiler plugin.",
structField = true,
allowReturnNones = true)
@Nullable
public abstract FilesToRunProvider pluginExecutable();

Expand All @@ -49,6 +65,19 @@ public abstract class ProtoLangToolchainProvider implements TransitiveInfoProvid
*/
public abstract ImmutableList<ProtoSource> providedProtoSources();

@StarlarkMethod(name = "proto_compiler", doc = "Proto compiler.", structField = true)
public abstract FilesToRunProvider protoc();

@StarlarkMethod(
name = "protoc_opts",
doc = "Options to pass to proto compiler.",
structField = true)
public StarlarkList<String> protocOptsForStarlark() {
return StarlarkList.immutableCopyOf(protocOpts());
}

public abstract ImmutableList<String> protocOpts();

/**
* This makes the blacklisted_protos member available in the provider. It can be removed after
* users are migrated and a sufficient time for Bazel rules to migrate has elapsed.
Expand All @@ -67,7 +96,9 @@ public static ProtoLangToolchainProvider create(
String pluginFormatFlag,
FilesToRunProvider pluginExecutable,
TransitiveInfoCollection runtime,
ImmutableList<ProtoSource> providedProtoSources) {
ImmutableList<ProtoSource> providedProtoSources,
FilesToRunProvider protoc,
ImmutableList<String> protocOpts) {
NestedSetBuilder<Artifact> blacklistedProtos = NestedSetBuilder.stableOrder();
for (ProtoSource protoSource : providedProtoSources) {
blacklistedProtos.add(protoSource.getOriginalSourceFile());
Expand All @@ -78,6 +109,8 @@ public static ProtoLangToolchainProvider create(
pluginExecutable,
runtime,
providedProtoSources,
protoc,
protocOpts,
blacklistedProtos.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,25 @@
import com.google.devtools.build.lib.analysis.RuleDefinition;
import com.google.devtools.build.lib.analysis.RuleDefinitionEnvironment;
import com.google.devtools.build.lib.analysis.config.ExecutionTransitionFactory;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.packages.Attribute;
import com.google.devtools.build.lib.packages.RuleClass;
import com.google.devtools.build.lib.packages.StarlarkProviderIdentifier;
import com.google.devtools.build.lib.packages.Type;

/** Implements {code proto_lang_toolchain}. */
public class ProtoLangToolchainRule implements RuleDefinition {
private static final Label DEFAULT_PROTO_COMPILER =
Label.parseAbsoluteUnchecked(ProtoConstants.DEFAULT_PROTOC_LABEL);
private static final Attribute.LabelLateBoundDefault<?> PROTO_COMPILER =
Attribute.LabelLateBoundDefault.fromTargetConfiguration(
ProtoConfiguration.class,
DEFAULT_PROTO_COMPILER,
(rule, attributes, protoConfig) ->
protoConfig.protoCompiler() != null
? protoConfig.protoCompiler()
: DEFAULT_PROTO_COMPILER);

@Override
public RuleClass build(RuleClass.Builder builder, RuleDefinitionEnvironment environment) {
return builder
Expand Down Expand Up @@ -78,6 +91,11 @@ public RuleClass build(RuleClass.Builder builder, RuleDefinitionEnvironment envi
attr("blacklisted_protos", LABEL_LIST)
.allowedFileTypes()
.mandatoryProviders(StarlarkProviderIdentifier.forKey(ProtoInfo.PROVIDER.getKey())))
.add(
attr(":proto_compiler", LABEL)
.cfg(ExecutionTransitionFactory.create())
.exec()
.value(PROTO_COMPILER))
.requiresConfigurationFragments(ProtoConfiguration.class)
.advertiseProvider(ProtoLangToolchainProvider.class)
.removeAttribute("data")
Expand Down
2 changes: 0 additions & 2 deletions src/test/java/com/google/devtools/build/lib/rules/proto/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@ java_test(
name = "ProtoCompileActionBuilderTest",
srcs = ["ProtoCompileActionBuilderTest.java"],
deps = [
"//src/main/java/com/google/devtools/build/lib/actions",
"//src/main/java/com/google/devtools/build/lib/actions:artifacts",
"//src/main/java/com/google/devtools/build/lib/actions:localhost_capacity",
"//src/main/java/com/google/devtools/build/lib/analysis:actions/custom_command_line",
"//src/main/java/com/google/devtools/build/lib/analysis:analysis_cluster",
"//src/main/java/com/google/devtools/build/lib/analysis:transitive_info_collection",
"//src/main/java/com/google/devtools/build/lib/cmdline",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,18 @@ public void commandLine_basic() throws Exception {
/* pluginFormatFlag= */ null,
/* pluginExecutable= */ null,
/* runtime= */ mock(TransitiveInfoCollection.class),
/* providedProtoSources= */ ImmutableList.of());
/* providedProtoSources= */ ImmutableList.of(),
/* protoc= */ FilesToRunProvider.EMPTY,
/* protocOpts= */ ImmutableList.of());
ProtoLangToolchainProvider toolchainWithPlugin =
ProtoLangToolchainProvider.create(
"--PLUGIN_pluginName_out=param3,param4:%s",
/* pluginFormatFlag= */ "--plugin=protoc-gen-PLUGIN_pluginName=%s",
plugin,
/* runtime= */ mock(TransitiveInfoCollection.class),
/* providedProtoSources= */ ImmutableList.of());
/* providedProtoSources= */ ImmutableList.of(),
/* protoc= */ FilesToRunProvider.EMPTY,
/* protocOpts= */ ImmutableList.of());
useConfiguration("--strict_proto_deps=OFF");

RuleContext ruleContext =
Expand Down Expand Up @@ -200,7 +204,9 @@ public void commandLine_strictDeps() throws Exception {
/* pluginFormatFlag= */ null,
/* pluginExecutable= */ null,
/* runtime= */ mock(TransitiveInfoCollection.class),
/* providedProtoSources= */ ImmutableList.of());
/* providedProtoSources= */ ImmutableList.of(),
/* protoc= */ FilesToRunProvider.EMPTY,
/* protocOpts= */ ImmutableList.of());

RuleContext ruleContext =
getRuleContext(getConfiguredTarget("//foo:bar"), collectingAnalysisEnvironment);
Expand Down Expand Up @@ -236,7 +242,9 @@ public void commandLine_exports() throws Exception {
/* pluginFormatFlag= */ null,
/* pluginExecutable= */ null,
/* runtime= */ mock(TransitiveInfoCollection.class),
/* providedProtoSources= */ ImmutableList.of());
/* providedProtoSources= */ ImmutableList.of(),
/* protoc= */ FilesToRunProvider.EMPTY,
/* protocOpts= */ ImmutableList.of());
useConfiguration("--strict_proto_deps=OFF");

RuleContext ruleContext =
Expand Down Expand Up @@ -307,7 +315,9 @@ public void outReplacementAreLazilyEvaluated() throws Exception {
/* pluginFormatFlag= */ null,
/* pluginExecutable= */ null,
/* runtime= */ mock(TransitiveInfoCollection.class),
/* providedProtoSources= */ ImmutableList.of());
/* providedProtoSources= */ ImmutableList.of(),
/* protoc= */ FilesToRunProvider.EMPTY,
/* protocOpts= */ ImmutableList.of());

RuleContext ruleContext =
getRuleContext(getConfiguredTarget("//foo:bar"), collectingAnalysisEnvironment);
Expand Down Expand Up @@ -341,14 +351,18 @@ public void exceptionIfSameName() throws Exception {
/* pluginFormatFlag= */ null,
/* pluginExecutable= */ null,
/* runtime= */ mock(TransitiveInfoCollection.class),
/* providedProtoSources= */ ImmutableList.of());
/* providedProtoSources= */ ImmutableList.of(),
/* protoc= */ FilesToRunProvider.EMPTY,
/* protocOpts= */ ImmutableList.of());
ProtoLangToolchainProvider toolchain2 =
ProtoLangToolchainProvider.create(
"dontcare=%s",
/* pluginFormatFlag= */ null,
/* pluginExecutable= */ null,
/* runtime= */ mock(TransitiveInfoCollection.class),
/* providedProtoSources= */ ImmutableList.of());
/* providedProtoSources= */ ImmutableList.of(),
/* protoc= */ FilesToRunProvider.EMPTY,
/* protocOpts= */ ImmutableList.of());

RuleContext ruleContext =
getRuleContext(getConfiguredTarget("//foo:bar"), collectingAnalysisEnvironment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class ProtoLangToolchainTest extends BuildViewTestCase {
public void setUp() throws Exception {
MockProtoSupport.setupWorkspace(scratch);
MockProtoSupport.setup(mockToolsConfig);
useConfiguration("--protocopt=--myflag");
invalidatePackages();
}

Expand All @@ -54,6 +55,11 @@ private void validateProtoLangToolchain(ProtoLangToolchainProvider toolchain) th
"third_party/x/metadata.proto",
"third_party/x/descriptor.proto",
"third_party/x/any.proto");

assertThat(toolchain.protocOpts()).containsExactly("--myflag");
Label protoc = Label.parseAbsoluteUnchecked(ProtoConstants.DEFAULT_PROTOC_LABEL);
assertThat(toolchain.protoc().getExecutable().prettyPrint())
.isEqualTo(protoc.toPathFragment().getPathString());
}

@Test
Expand Down

0 comments on commit 66e29cd

Please sign in to comment.