Skip to content

Commit

Permalink
Implement proto_common.create_proto_info in Starlark.
Browse files Browse the repository at this point in the history
There is some confusion in naming of the fields in ProtoInfo, for example:
```
  ImmutableList<ProtoSource> directSources; // in Starlark .proto_sources (added)
  ImmutableList<Artifact> directProtoSources; // in Starlark .sources (was there before)
```

We need to keep old names in Starlark and create some new. The naming opportunity for the new Starlark names clashed with what was already used natively.

PiperOrigin-RevId: 410502067
  • Loading branch information
comius authored and copybara-github committed Nov 17, 2021
1 parent 2a44dbd commit a6c8818
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,14 @@

package com.google.devtools.build.lib.rules.proto;

import com.google.devtools.build.lib.analysis.starlark.StarlarkRuleContext;
import com.google.devtools.build.lib.cmdline.Label;
import com.google.devtools.build.lib.packages.BazelModuleContext;
import com.google.devtools.build.lib.actions.Artifact;
import com.google.devtools.build.lib.collect.nestedset.Depset;
import com.google.devtools.build.lib.starlarkbuildapi.proto.ProtoCommonApi;
import javax.annotation.Nullable;
import com.google.devtools.build.lib.vfs.PathFragment;
import net.starlark.java.annot.Param;
import net.starlark.java.annot.StarlarkMethod;
import net.starlark.java.eval.EvalException;
import net.starlark.java.eval.Module;
import net.starlark.java.eval.Starlark;
import net.starlark.java.eval.StarlarkList;
import net.starlark.java.eval.StarlarkThread;

/** Protocol buffers support for Starlark. */
Expand All @@ -33,26 +31,66 @@ public class BazelProtoCommon implements ProtoCommonApi {
protected BazelProtoCommon() {}

@StarlarkMethod(
name = "create_proto_info",
name = "ProtoSource",
documented = false,
parameters = {@Param(name = "ctx", doc = "The rule context")},
useStarlarkThread = true,
allowReturnNones = true)
@Nullable
public ProtoInfo createProtoInfo(StarlarkRuleContext ruleContext, StarlarkThread thread)
parameters = {
@Param(name = "source_file", doc = "The proto file."),
@Param(name = "original_source_file", doc = "Original proto file."),
@Param(name = "proto_path", doc = "Path to proto file."),
},
useStarlarkThread = true)
public ProtoSource protoSource(
Artifact sourceFile, Artifact originalSourceFile, String sourceRoot, StarlarkThread thread)
throws EvalException {
Label label =
((BazelModuleContext) Module.ofInnermostEnclosingStarlarkFunction(thread).getClientData())
.label();
if (!label.getPackageIdentifier().getRepository().toString().equals("@_builtins")) {
throw Starlark.errorf("Rule in '%s' cannot use private API", label.getPackageName());
}
ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
return new ProtoSource(sourceFile, originalSourceFile, PathFragment.create(sourceRoot));
}

return ProtoCommon.createProtoInfo(
ruleContext.getRuleContext(),
ruleContext
.getRuleContext()
.getFragment(ProtoConfiguration.class)
.generatedProtosInVirtualImports());
@StarlarkMethod(
name = "ProtoInfo",
documented = false,
parameters = {
@Param(name = "direct_sources", doc = "Direct sources."),
@Param(name = "proto_path", doc = "Proto path."),
@Param(name = "transitive_sources", doc = "Transitive sources."),
@Param(name = "transitive_proto_sources", doc = "Transitive proto sources."),
@Param(name = "transitive_proto_path", doc = "Transitive proto path."),
@Param(name = "check_deps_sources", doc = "Check deps sources."),
@Param(name = "direct_descriptor_set", doc = "Direct descriptor set."),
@Param(name = "transitive_descriptor_set", doc = "Transitive descriptor sets."),
@Param(name = "exported_sources", doc = "Exported sources"),
@Param(name = "strict_importable_sources", doc = "Strict importable sources."),
@Param(name = "public_import_protos", doc = "Public import protos."),
},
useStarlarkThread = true)
@SuppressWarnings("unchecked")
public ProtoInfo protoInfo(
StarlarkList<? extends ProtoSource> directSources,
String directProtoSourceRoot,
Depset transitiveProtoSources,
Depset transitiveSources,
Depset transitiveProtoSourceRoots,
Depset strictImportableProtoSourcesForDependents,
Artifact directDescriptorSet,
Depset transitiveDescriptorSets,
Depset exportedSources,
Depset strictImportableSources,
Depset publicImportSources,
StarlarkThread thread)
throws EvalException {
ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
return new ProtoInfo(
((StarlarkList<ProtoSource>) directSources).getImmutableList(),
PathFragment.create(directProtoSourceRoot),
Depset.cast(transitiveSources, ProtoSource.class, "transitive_sources"),
Depset.cast(transitiveProtoSources, Artifact.class, "transitive_proto_sources"),
Depset.cast(transitiveProtoSourceRoots, String.class, "transitive_proto_path"),
Depset.cast(
strictImportableProtoSourcesForDependents, Artifact.class, "check_deps_sources"),
directDescriptorSet,
Depset.cast(transitiveDescriptorSets, Artifact.class, "transitive_descriptor_set"),
Depset.cast(exportedSources, ProtoSource.class, "exported_sources"),
Depset.cast(strictImportableSources, ProtoSource.class, "strict_importable_sources"),
Depset.cast(publicImportSources, ProtoSource.class, "public_import_protos"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,16 @@ public boolean strictPublicImports() {
return options.experimentalJavaProtoAddAllowedPublicImports;
}

@StarlarkMethod(
name = "generated_protos_in_virtual_imports",
useStarlarkThread = true,
documented = false)
public boolean generatedProtosInVirtualImportsForStarlark(StarlarkThread thread)
throws EvalException {
ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
return generatedProtosInVirtualImports();
}

public boolean generatedProtosInVirtualImports() {
return options.generatedProtosInVirtualImports;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ public NestedSet<Artifact> getTransitiveDescriptorSets() {
return transitiveDescriptorSets;
}

@Override
public Depset getExportedSourcesForStarlark(StarlarkThread thread) throws EvalException {
ProtoCommon.checkPrivateStarlarkificationAllowlist(thread);
return Depset.of(ProtoSource.TYPE, getExportedSources());
}

/**
* Returns a set of {@code .proto} sources that may be imported by {@code proto_library} targets
* directly depending on this {@code ProtoInfo}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,9 @@ public PathFragment getImportPath() {
public String toString() {
return "ProtoSource('" + getImportPath() + "')";
}

@Override
public boolean isImmutable() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,7 @@ interface ProtoInfoProviderApi extends ProviderApi {

@StarlarkMethod(name = "transitive_proto_sources", documented = false, useStarlarkThread = true)
Depset getTransitiveSourcesForStarlark(StarlarkThread thread) throws EvalException;

@StarlarkMethod(name = "exported_sources", documented = false, useStarlarkThread = true)
Depset getExportedSourcesForStarlark(StarlarkThread thread) throws EvalException;
}
118 changes: 118 additions & 0 deletions src/main/starlark/builtins_bzl/common/proto/proto_common.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,125 @@ Definition of proto_common module.
"""

load(":common/proto/proto_semantics.bzl", "semantics")
load(":common/paths.bzl", "paths")

ProtoInfo = _builtins.toplevel.ProtoInfo
native_proto_common = _builtins.toplevel.proto_common

def _join(*path):
return "/".join([p for p in path if p != ""])

def _create_proto_info(ctx):
srcs = ctx.files.srcs
deps = [dep[ProtoInfo] for dep in ctx.attr.deps]
exports = [dep[ProtoInfo] for dep in ctx.attr.exports]

import_prefix = ctx.attr.import_prefix if hasattr(ctx.attr, "import_prefix") else ""
if not paths.is_normalized(import_prefix):
fail("should be normalized (without uplevel references or '.' path segments)", attr = "import_prefix")

strip_import_prefix = ctx.attr.strip_import_prefix
if not paths.is_normalized(strip_import_prefix):
fail("should be normalized (without uplevel references or '.' path segments)", attr = "strip_import_prefix")
if strip_import_prefix.startswith("/"):
strip_import_prefix = strip_import_prefix[1:]
elif strip_import_prefix != "DO_NOT_STRIP": # Relative to current package
strip_import_prefix = _join(ctx.label.package, strip_import_prefix)
else:
strip_import_prefix = ""

has_generated_sources = False
if ctx.fragments.proto.generated_protos_in_virtual_imports():
has_generated_sources = any([not src.is_source for src in srcs])

direct_sources = []
if import_prefix != "" or strip_import_prefix != "" or has_generated_sources:
# Use virtual source roots
if paths.is_absolute(import_prefix):
fail("should be a relative path", attr = "import_prefix")

virtual_imports = _join("_virtual_imports", ctx.label.name)
if ctx.label.workspace_name == "" or ctx.label.workspace_root.startswith(".."): # siblingRepositoryLayout
proto_path = _join(ctx.genfiles_dir.path, ctx.label.package, virtual_imports)
else:
proto_path = _join(ctx.genfiles_dir.path, ctx.label.workspace_root, ctx.label.package, virtual_imports)

for src in srcs:
if ctx.label.workspace_name == "":
repository_relative_path = src.short_path
else:
repository_relative_path = paths.relativize(src.short_path, "../" + ctx.label.workspace_name)

if not repository_relative_path.startswith(strip_import_prefix):
fail(".proto file '%s' is not under the specified strip prefix '%s'" %
(src.short_path, strip_import_prefix))
import_path = repository_relative_path[len(strip_import_prefix):]

virtual_src = ctx.actions.declare_file(_join(virtual_imports, import_prefix, import_path))
ctx.actions.symlink(
output = virtual_src,
target_file = src,
progress_message = "Symlinking virtual .proto sources for %{label}",
)
direct_sources.append(native_proto_common.ProtoSource(virtual_src, src, proto_path))

else:
# No virtual source roots
proto_path = "."
for src in srcs:
direct_sources.append(native_proto_common.ProtoSource(src, src, ctx.label.workspace_root + src.root.path))

# Construct ProtoInfo
transitive_proto_sources = depset(
direct = direct_sources,
transitive = [dep.transitive_proto_sources() for dep in deps],
order = "preorder",
)
transitive_sources = depset(
direct = [src.source_file() for src in direct_sources],
transitive = [dep.transitive_sources for dep in deps],
order = "preorder",
)
transitive_proto_path = depset(
direct = [proto_path],
transitive = [dep.transitive_proto_path for dep in deps],
)
if direct_sources:
check_deps_sources = depset(direct = [src.source_file() for src in direct_sources])
else:
check_deps_sources = depset(transitive = [dep.check_deps_sources for dep in deps])

direct_descriptor_set = ctx.actions.declare_file(ctx.label.name + "-descriptor-set.proto.bin")
transitive_descriptor_sets = depset(
direct = [direct_descriptor_set],
transitive = [dep.transitive_descriptor_sets for dep in deps],
)

# Layering checks.
if direct_sources:
exported_sources = depset(direct = direct_sources)
strict_importable_sources = depset(
direct = direct_sources,
transitive = [dep.exported_sources() for dep in deps],
)
else:
exported_sources = depset(transitive = [dep.exported_sources() for dep in deps])
strict_importable_sources = depset()
public_import_protos = depset(transitive = [export.exported_sources() for export in exports])

return native_proto_common.ProtoInfo(
direct_sources,
proto_path,
transitive_sources,
transitive_proto_sources,
transitive_proto_path,
check_deps_sources,
direct_descriptor_set,
transitive_descriptor_sets,
exported_sources,
strict_importable_sources,
public_import_protos,
)

def _write_descriptor_set(ctx, proto_info):
output = proto_info.direct_descriptor_set
Expand Down Expand Up @@ -83,5 +200,6 @@ def _ExpandImportArgsFn(proto_source):
return "-I%s=%s" % (proto_source.import_path(), proto_source.source_file().path)

proto_common = struct(
create_proto_info = _create_proto_info,
write_descriptor_set = _write_descriptor_set,
)
5 changes: 2 additions & 3 deletions src/main/starlark/builtins_bzl/common/proto/proto_library.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ load(":common/proto/proto_semantics.bzl", "semantics")
load(":common/proto/proto_common.bzl", "proto_common")

ProtoInfo = _builtins.toplevel.ProtoInfo
native_proto_common = _builtins.toplevel.proto_common

def _check_srcs_package(target_package, srcs):
"""Makes sure the given srcs live in the given package."""
Expand All @@ -33,7 +32,7 @@ def _proto_library_impl(ctx):

_check_srcs_package(ctx.label.package, ctx.attr.srcs)

proto_info = native_proto_common.create_proto_info(ctx)
proto_info = proto_common.create_proto_info(ctx)

proto_common.write_descriptor_set(ctx, proto_info)

Expand Down Expand Up @@ -64,7 +63,7 @@ proto_library = rule(
"exports": attr.label_list(
providers = [ProtoInfo],
),
"strip_import_prefix": attr.string(),
"strip_import_prefix": attr.string(default = "DO_NOT_STRIP"),
"data": attr.label_list(
allow_files = True,
flags = ["SKIP_CONSTRAINTS_OVERRIDE"],
Expand Down

0 comments on commit a6c8818

Please sign in to comment.