Skip to content

Commit

Permalink
[8.1.0] Use digest function matching the checksum in gRPC remote down…
Browse files Browse the repository at this point in the history
…loader (#25225)

Fixes https://bazelbuild.slack.com/archives/CA31HN1T3/p1738763759125489

Closes #25206.

PiperOrigin-RevId: 724267755
Change-Id: Ia23bdae310231bd0ee5763311b948f3465aa8ed0

Commit
ef45e02

Co-authored-by: Fabian Meumertzheim <fabian@meumertzhe.im>
  • Loading branch information
bazel-io and fmeum authored Feb 7, 2025
1 parent aa4531d commit 14219c4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public class GrpcRemoteDownloader implements AutoCloseable, Downloader {
private final Optional<CallCredentials> credentials;
private final RemoteRetrier retrier;
private final RemoteCacheClient cacheClient;
private final DigestFunction.Value digestFunction;
private final DigestFunction.Value defaultDigestFunction;
private final RemoteOptions options;
private final boolean verboseFailures;
@Nullable private final Downloader fallbackDownloader;
Expand Down Expand Up @@ -100,7 +100,7 @@ public GrpcRemoteDownloader(
Optional<CallCredentials> credentials,
RemoteRetrier retrier,
RemoteCacheClient cacheClient,
DigestFunction.Value digestFunction,
DigestFunction.Value defaultDigestFunction,
RemoteOptions options,
boolean verboseFailures,
@Nullable Downloader fallbackDownloader) {
Expand All @@ -110,7 +110,7 @@ public GrpcRemoteDownloader(
this.credentials = credentials;
this.retrier = retrier;
this.cacheClient = cacheClient;
this.digestFunction = digestFunction;
this.defaultDigestFunction = defaultDigestFunction;
this.options = options;
this.verboseFailures = verboseFailures;
this.fallbackDownloader = fallbackDownloader;
Expand Down Expand Up @@ -149,7 +149,7 @@ public void download(
urls,
checksum,
canonicalId,
digestFunction,
defaultDigestFunction,
headers,
credentials);
try {
Expand Down Expand Up @@ -200,14 +200,12 @@ static FetchBlobRequest newFetchBlobRequest(
List<URL> urls,
Optional<Checksum> checksum,
String canonicalId,
DigestFunction.Value digestFunction,
DigestFunction.Value defaultDigestFunction,
Map<String, List<String>> headers,
Credentials credentials)
throws IOException {
FetchBlobRequest.Builder requestBuilder =
FetchBlobRequest.newBuilder()
.setInstanceName(instanceName)
.setDigestFunction(digestFunction);
FetchBlobRequest.newBuilder().setInstanceName(instanceName);
for (int i = 0; i < urls.size(); i++) {
var url = urls.get(i);
requestBuilder.addUris(url.toString());
Expand All @@ -233,12 +231,21 @@ static FetchBlobRequest newFetchBlobRequest(
}

if (checksum.isPresent()) {
requestBuilder.setDigestFunction(
switch (checksum.get().getKeyType()) {
case SHA1 -> DigestFunction.Value.SHA1;
case SHA256 -> DigestFunction.Value.SHA256;
case SHA384 -> DigestFunction.Value.SHA384;
case SHA512 -> DigestFunction.Value.SHA512;
case BLAKE3 -> DigestFunction.Value.BLAKE3;
});
requestBuilder.addQualifiers(
Qualifier.newBuilder()
.setName(QUALIFIER_CHECKSUM_SRI)
.setValue(checksum.get().toSubresourceIntegrity())
.build());
} else {
requestBuilder.setDigestFunction(defaultDigestFunction);
// If no checksum is provided, never accept cached content.
// Timestamp is offset by an hour to account for clock skew.
requestBuilder.setOldestContentAccepted(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import build.bazel.remote.asset.v1.FetchGrpc.FetchImplBase;
import build.bazel.remote.asset.v1.Qualifier;
import build.bazel.remote.execution.v2.Digest;
import build.bazel.remote.execution.v2.DigestFunction;
import build.bazel.remote.execution.v2.RequestMetadata;
import build.bazel.remote.execution.v2.ServerCapabilities;
import com.google.auth.Credentials;
Expand Down Expand Up @@ -96,8 +97,10 @@ public class GrpcRemoteDownloaderTest {

private static final ManualClock clock = new ManualClock();

// Use an unusual default to verify that the hash function used to generate a given Checksum is
// propagated correctly.
private static final DigestUtil DIGEST_UTIL =
new DigestUtil(SyscallCache.NO_CACHE, DigestHashFunction.SHA256);
new DigestUtil(SyscallCache.NO_CACHE, DigestHashFunction.SHA1);

private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry();
private final String fakeServerName = "fake server for " + getClass();
Expand Down Expand Up @@ -277,7 +280,8 @@ public void fetchBlob(
@Test
public void testPropagateChecksum() throws Exception {
final byte[] content = "example content".getBytes(UTF_8);
final Digest contentDigest = DIGEST_UTIL.compute(content);
final DigestUtil digestUtil = new DigestUtil(SyscallCache.NO_CACHE, DigestHashFunction.SHA256);
final Digest contentDigest = digestUtil.compute(content);

serviceRegistry.addService(
new FetchImplBase() {
Expand All @@ -287,7 +291,7 @@ public void fetchBlob(
assertThat(request)
.isEqualTo(
FetchBlobRequest.newBuilder()
.setDigestFunction(DIGEST_UTIL.getDigestFunction())
.setDigestFunction(digestUtil.getDigestFunction())
.addUris("http://example.com/content.txt")
.addQualifiers(
Qualifier.newBuilder()
Expand Down Expand Up @@ -316,7 +320,8 @@ public void fetchBlob(
@Test
public void testRejectChecksumMismatch() throws Exception {
final byte[] content = "example content".getBytes(UTF_8);
final Digest contentDigest = DIGEST_UTIL.compute(content);
final DigestUtil digestUtil = new DigestUtil(SyscallCache.NO_CACHE, DigestHashFunction.SHA256);
final Digest contentDigest = digestUtil.compute(content);

serviceRegistry.addService(
new FetchImplBase() {
Expand All @@ -326,7 +331,7 @@ public void fetchBlob(
assertThat(request)
.isEqualTo(
FetchBlobRequest.newBuilder()
.setDigestFunction(DIGEST_UTIL.getDigestFunction())
.setDigestFunction(digestUtil.getDigestFunction())
.addUris("http://example.com/content.txt")
.addQualifiers(
Qualifier.newBuilder()
Expand Down Expand Up @@ -355,7 +360,7 @@ public void fetchBlob(
Optional.of(Checksum.fromString(KeyType.SHA256, contentDigest.getHash()))));

assertThat(e).hasMessageThat().contains(contentDigest.getHash());
assertThat(e).hasMessageThat().contains(DIGEST_UTIL.computeAsUtf8("wrong content").getHash());
assertThat(e).hasMessageThat().contains(digestUtil.computeAsUtf8("wrong content").getHash());
}

@Test
Expand All @@ -382,7 +387,7 @@ public void testFetchBlobRequest() throws Exception {
.isEqualTo(
FetchBlobRequest.newBuilder()
.setInstanceName("instance name")
.setDigestFunction(DIGEST_UTIL.getDigestFunction())
.setDigestFunction(DigestFunction.Value.SHA256)
.addUris("http://example.com/a")
.addUris("http://example.com/b")
.addUris("file:/not/limited/to/http")
Expand Down Expand Up @@ -431,7 +436,7 @@ public void testFetchBlobRequest_withCredentialsPropagation() throws Exception {
.isEqualTo(
FetchBlobRequest.newBuilder()
.setInstanceName("instance name")
.setDigestFunction(DIGEST_UTIL.getDigestFunction())
.setDigestFunction(DigestFunction.Value.SHA256)
.addUris("http://example.com/a")
.addQualifiers(
Qualifier.newBuilder()
Expand Down Expand Up @@ -474,7 +479,7 @@ public void testFetchBlobRequest_withoutCredentialsPropagation() throws Exceptio
.isEqualTo(
FetchBlobRequest.newBuilder()
.setInstanceName("instance name")
.setDigestFunction(DIGEST_UTIL.getDigestFunction())
.setDigestFunction(DigestFunction.Value.SHA256)
.addUris("http://example.com/a")
.addQualifiers(
Qualifier.newBuilder()
Expand Down

0 comments on commit 14219c4

Please sign in to comment.