From 50274a9f714616d4735a560db7f617e53fb8d01b Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Mon, 15 Nov 2021 20:30:22 +0800 Subject: [PATCH] [5.x] Remote: Add support for compression on gRPC cache (#14277) * Add patch files for zstd-jni Partial commit for third_party/*, see #14203. Closes https://github.com/bazelbuild/bazel/pull/14203 Signed-off-by: Yun Peng * Remote: Add support for compression on gRPC cache Add support for compressed transfers from/to gRPC remote caches with flag --experimental_remote_cache_compression. Fixes #13344. Closes #14041. PiperOrigin-RevId: 409328001 Co-authored-by: Alessandro Patti --- BUILD | 1 + WORKSPACE | 8 + distdir_deps.bzl | 15 + .../google/devtools/build/lib/remote/BUILD | 3 + .../build/lib/remote/ByteStreamUploader.java | 58 ++-- .../devtools/build/lib/remote/Chunker.java | 107 +++++-- .../build/lib/remote/GrpcCacheClient.java | 40 ++- .../lib/remote/RemoteServerCapabilities.java | 8 + .../lib/remote/options/RemoteOptions.java | 8 + .../devtools/build/lib/remote/zstd/BUILD | 24 ++ .../zstd/ZstdCompressingInputStream.java | 104 +++++++ .../zstd/ZstdDecompressingOutputStream.java | 63 ++++ .../google/devtools/build/lib/remote/BUILD | 2 + .../lib/remote/ByteStreamUploaderTest.java | 275 +++++++++++++++++- .../build/lib/remote/ChunkerTest.java | 59 +++- .../build/lib/remote/GrpcCacheClientTest.java | 14 +- .../lib/remote/GrpcCacheClientTestExtra.java | 107 +++++++ ...SpawnRunnerWithGrpcRemoteExecutorTest.java | 9 +- .../devtools/build/lib/remote/zstd/BUILD | 28 ++ .../zstd/ZstdCompressingInputStreamTest.java | 54 ++++ .../ZstdDecompressingOutputStreamTest.java | 43 +++ ...stdDecompressingOutputStreamTestExtra.java | 70 +++++ .../shell/integration/minimal_jdk_test.sh | 6 +- third_party/zstd-jni/Native.java.patch | 11 + third_party/zstd-jni/zstd-jni.BUILD | 62 ++++ 25 files changed, 1093 insertions(+), 86 deletions(-) create mode 100644 src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD create mode 100644 
src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdCompressingInputStream.java create mode 100644 src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java create mode 100644 src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java create mode 100644 src/test/java/com/google/devtools/build/lib/remote/zstd/BUILD create mode 100644 src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdCompressingInputStreamTest.java create mode 100644 src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTest.java create mode 100644 src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java create mode 100644 third_party/zstd-jni/Native.java.patch create mode 100644 third_party/zstd-jni/zstd-jni.BUILD diff --git a/BUILD b/BUILD index 3f42b446918620..a8eb1e7d40bcd3 100644 --- a/BUILD +++ b/BUILD @@ -81,6 +81,7 @@ pkg_tar( "@com_google_protobuf//:protobuf_java", "@com_google_protobuf//:protobuf_java_util", "@com_google_protobuf//:protobuf_javalite", + "@zstd-jni//:zstd-jni", ], package_dir = "derived/jars", strip_prefix = "external", diff --git a/WORKSPACE b/WORKSPACE index 14da81f563f04f..464a7a6f5a4c2b 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -301,6 +301,14 @@ dist_http_archive( patch_cmds_win = EXPORT_WORKSPACE_IN_BUILD_FILE_WIN, ) +dist_http_archive( + name = "zstd-jni", + patch_cmds = EXPORT_WORKSPACE_IN_BUILD_BAZEL_FILE, + patch_cmds_win = EXPORT_WORKSPACE_IN_BUILD_BAZEL_FILE_WIN, + build_file = "//third_party:zstd-jni/zstd-jni.BUILD", + strip_prefix = "zstd-jni-1.5.0-4" +) + http_archive( name = "org_snakeyaml", build_file_content = """ diff --git a/distdir_deps.bzl b/distdir_deps.bzl index d26ac0e145a2e6..10614b868d0a5b 100644 --- a/distdir_deps.bzl +++ b/distdir_deps.bzl @@ -172,6 +172,21 @@ DIST_DEPS = { "test_WORKSPACE_files", ], }, + "zstd-jni": { + "archive": "v1.5.0-4.zip", + "patch_args": ["-p1"], + "patches": [ + 
"//third_party:zstd-jni/Native.java.patch", + ], + "sha256": "d320d59b89a163c5efccbe4915ae6a49883ce653cdc670643dfa21c6063108e4", + "urls": [ + "https://mirror.bazel.build/github.com/luben/zstd-jni/archive/v1.5.0-4.zip", + "https://github.com/luben/zstd-jni/archive/v1.5.0-4.zip", + ], + "used_in": [ + "additional_distfiles", + ], + }, ################################################### # # Build time dependencies for testing and packaging diff --git a/src/main/java/com/google/devtools/build/lib/remote/BUILD b/src/main/java/com/google/devtools/build/lib/remote/BUILD index e2e1b987c3d6f1..562053eec62ba9 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/BUILD +++ b/src/main/java/com/google/devtools/build/lib/remote/BUILD @@ -16,6 +16,7 @@ filegroup( "//src/main/java/com/google/devtools/build/lib/remote/merkletree:srcs", "//src/main/java/com/google/devtools/build/lib/remote/options:srcs", "//src/main/java/com/google/devtools/build/lib/remote/util:srcs", + "//src/main/java/com/google/devtools/build/lib/remote/zstd:srcs", ], visibility = ["//src:__subpackages__"], ) @@ -81,6 +82,7 @@ java_library( "//src/main/java/com/google/devtools/build/lib/remote/merkletree", "//src/main/java/com/google/devtools/build/lib/remote/options", "//src/main/java/com/google/devtools/build/lib/remote/util", + "//src/main/java/com/google/devtools/build/lib/remote/zstd", "//src/main/java/com/google/devtools/build/lib/sandbox", "//src/main/java/com/google/devtools/build/lib/skyframe:mutable_supplier", "//src/main/java/com/google/devtools/build/lib/skyframe:tree_artifact_value", @@ -94,6 +96,7 @@ java_library( "//src/main/java/com/google/devtools/build/lib/vfs:pathfragment", "//src/main/java/com/google/devtools/common/options", "//src/main/protobuf:failure_details_java_proto", + "//third_party:apache_commons_compress", "//third_party:auth", "//third_party:caffeine", "//third_party:flogger", diff --git a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java 
b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java index 28e84497bf89d1..c488f14f397d07 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java +++ b/src/main/java/com/google/devtools/build/lib/remote/ByteStreamUploader.java @@ -15,6 +15,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static java.lang.String.format; import static java.util.Collections.singletonMap; import static java.util.concurrent.TimeUnit.SECONDS; @@ -298,9 +299,11 @@ boolean uploadsInProgress() { } } - private static String buildUploadResourceName(String instanceName, UUID uuid, Digest digest) { - String resourceName = - format("uploads/%s/blobs/%s/%d", uuid, digest.getHash(), digest.getSizeBytes()); + private static String buildUploadResourceName( + String instanceName, UUID uuid, Digest digest, boolean compressed) { + String template = + compressed ? 
"uploads/%s/compressed-blobs/zstd/%s/%d" : "uploads/%s/blobs/%s/%d"; + String resourceName = format(template, uuid, digest.getHash(), digest.getSizeBytes()); if (!Strings.isNullOrEmpty(instanceName)) { resourceName = instanceName + "/" + resourceName; } @@ -325,7 +328,8 @@ private ListenableFuture startAsyncUpload( } UUID uploadId = UUID.randomUUID(); - String resourceName = buildUploadResourceName(instanceName, uploadId, digest); + String resourceName = + buildUploadResourceName(instanceName, uploadId, digest, chunker.isCompressed()); AsyncUpload newUpload = new AsyncUpload( context, @@ -405,7 +409,20 @@ ListenableFuture start() { () -> retrier.executeAsync( () -> { - if (committedOffset.get() < chunker.getSize()) { + if (chunker.getSize() == 0) { + return immediateVoidFuture(); + } + try { + chunker.seek(committedOffset.get()); + } catch (IOException e) { + try { + chunker.reset(); + } catch (IOException resetException) { + e.addSuppressed(resetException); + } + return Futures.immediateFailedFuture(e); + } + if (chunker.hasNext()) { return callAndQueryOnFailure(committedOffset, progressiveBackoff); } return Futures.immediateFuture(null); @@ -416,13 +433,19 @@ ListenableFuture start() { return Futures.transformAsync( callFuture, (result) -> { - long committedSize = committedOffset.get(); - long expected = chunker.getSize(); - if (committedSize != expected) { - String message = - format( - "write incomplete: committed_size %d for %d total", committedSize, expected); - return Futures.immediateFailedFuture(new IOException(message)); + if (!chunker.hasNext()) { + // Only check for matching committed size if we have completed the upload. 
+ // If another client did, they might have used a different compression + // level/algorithm, so we cannot know the expected committed offset + long committedSize = committedOffset.get(); + long expected = chunker.getOffset(); + if (!chunker.hasNext() && committedSize != expected) { + String message = + format( + "write incomplete: committed_size %d for %d total", + committedSize, expected); + return Futures.immediateFailedFuture(new IOException(message)); + } } return Futures.immediateFuture(null); }, @@ -517,17 +540,6 @@ private ListenableFuture call(AtomicLong committedOffset) { .withDeadlineAfter(callTimeoutSecs, SECONDS); call = channel.newCall(ByteStreamGrpc.getWriteMethod(), callOptions); - try { - chunker.seek(committedOffset.get()); - } catch (IOException e) { - try { - chunker.reset(); - } catch (IOException resetException) { - e.addSuppressed(resetException); - } - return Futures.immediateFailedFuture(e); - } - SettableFuture uploadResult = SettableFuture.create(); ClientCall.Listener callListener = new ClientCall.Listener() { diff --git a/src/main/java/com/google/devtools/build/lib/remote/Chunker.java b/src/main/java/com/google/devtools/build/lib/remote/Chunker.java index 7ce80ee24b8294..d1024a3d3143cf 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/Chunker.java +++ b/src/main/java/com/google/devtools/build/lib/remote/Chunker.java @@ -16,6 +16,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static java.lang.Math.min; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Throwables; @@ -23,12 +24,13 @@ import com.google.devtools.build.lib.actions.ActionInput; import com.google.devtools.build.lib.actions.ActionInputHelper; import com.google.devtools.build.lib.actions.cache.VirtualActionInput; +import com.google.devtools.build.lib.remote.zstd.ZstdCompressingInputStream; import com.google.devtools.build.lib.vfs.Path; 
import com.google.protobuf.ByteString; import java.io.ByteArrayInputStream; -import java.io.EOFException; import java.io.IOException; import java.io.InputStream; +import java.io.PushbackInputStream; import java.util.NoSuchElementException; import java.util.Objects; import java.util.function.Supplier; @@ -55,6 +57,10 @@ static int getDefaultChunkSize() { return defaultChunkSize; } + public boolean isCompressed() { + return compressed; + } + /** A piece of a byte[] blob. */ public static final class Chunk { @@ -98,19 +104,22 @@ public int hashCode() { private final int chunkSize; private final Chunk emptyChunk; - private InputStream data; + private ChunkerInputStream data; private long offset; private byte[] chunkCache; + private final boolean compressed; + // Set to true on the first call to next(). This is so that the Chunker can open its data source // lazily on the first call to next(), as opposed to opening it in the constructor or on reset(). private boolean initialized; - Chunker(Supplier dataSupplier, long size, int chunkSize) { + Chunker(Supplier dataSupplier, long size, int chunkSize, boolean compressed) { this.dataSupplier = checkNotNull(dataSupplier); this.size = size; this.chunkSize = chunkSize; this.emptyChunk = new Chunk(ByteString.EMPTY, 0); + this.compressed = compressed; } public long getOffset() { @@ -127,13 +136,9 @@ public long getSize() { *

Closes any open resources (file handles, ...). */ public void reset() throws IOException { - if (data != null) { - data.close(); - } - data = null; + close(); offset = 0; initialized = false; - chunkCache = null; } /** @@ -148,6 +153,9 @@ public void seek(long toOffset) throws IOException { maybeInitialize(); ByteStreams.skipFully(data, toOffset - offset); offset = toOffset; + if (data.finished()) { + close(); + } } /** @@ -157,6 +165,27 @@ public boolean hasNext() { return data != null || !initialized; } + /** Closes the input stream and reset chunk cache */ + private void close() throws IOException { + if (data != null) { + data.close(); + data = null; + } + chunkCache = null; + } + + /** Attempts reading at most a full chunk and stores it in the chunkCache buffer */ + private int read() throws IOException { + int count = 0; + while (count < chunkCache.length) { + int c = data.read(chunkCache, count, chunkCache.length - count); + if (c < 0) { + break; + } + count += c; + } + return count; + } /** * Returns the next {@link Chunk} or throws a {@link NoSuchElementException} if no data is left. * @@ -178,46 +207,40 @@ public Chunk next() throws IOException { return emptyChunk; } - // The cast to int is safe, because the return value is capped at chunkSize. - int bytesToRead = (int) Math.min(bytesLeft(), chunkSize); - if (bytesToRead == 0) { + if (data.finished()) { chunkCache = null; data = null; throw new NoSuchElementException(); } if (chunkCache == null) { + // If the output is compressed we can't know how many bytes there are yet to read, + // so we allocate the whole chunkSize, otherwise we try to compute the smallest possible value + // The cast to int is safe, because the return value is capped at chunkSize. + int cacheSize = compressed ? chunkSize : (int) min(getSize() - getOffset(), chunkSize); // Lazily allocate it in order to save memory on small data. // 1) bytesToRead < chunkSize: There will only ever be one next() call. 
// 2) bytesToRead == chunkSize: chunkCache will be set to its biggest possible value. // 3) bytestoRead > chunkSize: Not possible, due to Math.min above. - chunkCache = new byte[bytesToRead]; + chunkCache = new byte[cacheSize]; } long offsetBefore = offset; - try { - ByteStreams.readFully(data, chunkCache, 0, bytesToRead); - } catch (EOFException e) { - throw new IllegalStateException("Reached EOF, but expected " - + bytesToRead + " bytes.", e); - } - offset += bytesToRead; - ByteString blob = ByteString.copyFrom(chunkCache, 0, bytesToRead); + int bytesRead = read(); - if (bytesLeft() == 0) { - data.close(); - data = null; - chunkCache = null; + ByteString blob = ByteString.copyFrom(chunkCache, 0, bytesRead); + + // This has to happen after actualSize has been updated + // or the guard in getActualSize won't work. + offset += bytesRead; + if (data.finished()) { + close(); } return new Chunk(blob, offsetBefore); } - public long bytesLeft() { - return getSize() - getOffset(); - } - private void maybeInitialize() throws IOException { if (initialized) { return; @@ -226,7 +249,10 @@ private void maybeInitialize() throws IOException { checkState(offset == 0); checkState(chunkCache == null); try { - data = dataSupplier.get(); + data = + compressed + ? 
new ChunkerInputStream(new ZstdCompressingInputStream(dataSupplier.get())) + : new ChunkerInputStream(dataSupplier.get()); } catch (RuntimeException e) { Throwables.propagateIfPossible(e.getCause(), IOException.class); throw e; @@ -242,6 +268,7 @@ public static Builder builder() { public static class Builder { private int chunkSize = getDefaultChunkSize(); private long size; + private boolean compressed; private Supplier inputStream; public Builder setInput(byte[] data) { @@ -251,6 +278,11 @@ public Builder setInput(byte[] data) { return this; } + public Builder setCompressed(boolean compressed) { + this.compressed = compressed; + return this; + } + public Builder setInput(long size, InputStream in) { checkState(inputStream == null); checkNotNull(in); @@ -305,7 +337,22 @@ public Builder setChunkSize(int chunkSize) { public Chunker build() { checkNotNull(inputStream); - return new Chunker(inputStream, size, chunkSize); + return new Chunker(inputStream, size, chunkSize, compressed); + } + } + + static class ChunkerInputStream extends PushbackInputStream { + ChunkerInputStream(InputStream in) { + super(in); + } + + public boolean finished() throws IOException { + int c = super.read(); + if (c == -1) { + return true; + } + super.unread(c); + return false; } } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java index aa163837ea4a7d..e35d4c6f32ce94 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java +++ b/src/main/java/com/google/devtools/build/lib/remote/GrpcCacheClient.java @@ -53,6 +53,7 @@ import com.google.devtools.build.lib.remote.util.DigestUtil; import com.google.devtools.build.lib.remote.util.TracingMetadataUtils; import com.google.devtools.build.lib.remote.util.Utils; +import com.google.devtools.build.lib.remote.zstd.ZstdDecompressingOutputStream; import com.google.devtools.build.lib.vfs.Path; import 
com.google.protobuf.ByteString; import io.grpc.Status; @@ -68,6 +69,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import javax.annotation.Nullable; +import org.apache.commons.compress.utils.CountingOutputStream; /** A RemoteActionCache implementation that uses gRPC calls to a remote cache server. */ @ThreadSafe @@ -294,13 +296,24 @@ public ListenableFuture downloadBlob( out = digestOut; } - return downloadBlob(context, digest, out, digestSupplier); + CountingOutputStream outputStream; + if (options.cacheCompression) { + try { + outputStream = new ZstdDecompressingOutputStream(out); + } catch (IOException e) { + return Futures.immediateFailedFuture(e); + } + } else { + outputStream = new CountingOutputStream(out); + } + + return downloadBlob(context, digest, outputStream, digestSupplier); } private ListenableFuture downloadBlob( RemoteActionExecutionContext context, Digest digest, - OutputStream out, + CountingOutputStream out, @Nullable Supplier digestSupplier) { AtomicLong offset = new AtomicLong(0); ProgressiveBackoff progressiveBackoff = new ProgressiveBackoff(retrier::newBackoff); @@ -321,12 +334,13 @@ private ListenableFuture downloadBlob( MoreExecutors.directExecutor()); } - public static String getResourceName(String instanceName, Digest digest) { + public static String getResourceName(String instanceName, Digest digest, boolean compressed) { String resourceName = ""; if (!instanceName.isEmpty()) { resourceName += instanceName + "/"; } - return resourceName + "blobs/" + DigestUtil.toString(digest); + resourceName += compressed ? 
"compressed-blobs/zstd/" : "blobs/"; + return resourceName + DigestUtil.toString(digest); } private ListenableFuture requestRead( @@ -334,9 +348,10 @@ private ListenableFuture requestRead( AtomicLong offset, ProgressiveBackoff progressiveBackoff, Digest digest, - OutputStream out, + CountingOutputStream out, @Nullable Supplier digestSupplier) { - String resourceName = getResourceName(options.remoteInstanceName, digest); + String resourceName = + getResourceName(options.remoteInstanceName, digest, options.cacheCompression); SettableFuture future = SettableFuture.create(); bsAsyncStub(context) .read( @@ -345,12 +360,13 @@ private ListenableFuture requestRead( .setReadOffset(offset.get()) .build(), new StreamObserver() { + @Override public void onNext(ReadResponse readResponse) { ByteString data = readResponse.getData(); try { data.writeTo(out); - offset.addAndGet(data.size()); + offset.set(out.getBytesWritten()); } catch (IOException e) { // Cancel the call. throw new RuntimeException(e); @@ -402,7 +418,10 @@ public ListenableFuture uploadFile( return uploader.uploadBlobAsync( context, digest, - Chunker.builder().setInput(digest.getSizeBytes(), path).build(), + Chunker.builder() + .setInput(digest.getSizeBytes(), path) + .setCompressed(options.cacheCompression) + .build(), /* forceUpload= */ true); } @@ -412,7 +431,10 @@ public ListenableFuture uploadBlob( return uploader.uploadBlobAsync( context, digest, - Chunker.builder().setInput(data.toByteArray()).build(), + Chunker.builder() + .setInput(data.toByteArray()) + .setCompressed(options.cacheCompression) + .build(), /* forceUpload= */ true); } } diff --git a/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java b/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java index 417d62ade2dff3..6eb03ceb559b87 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java +++ 
b/src/main/java/com/google/devtools/build/lib/remote/RemoteServerCapabilities.java @@ -17,6 +17,7 @@ import build.bazel.remote.execution.v2.CacheCapabilities; import build.bazel.remote.execution.v2.CapabilitiesGrpc; import build.bazel.remote.execution.v2.CapabilitiesGrpc.CapabilitiesBlockingStub; +import build.bazel.remote.execution.v2.Compressor; import build.bazel.remote.execution.v2.DigestFunction; import build.bazel.remote.execution.v2.ExecutionCapabilities; import build.bazel.remote.execution.v2.GetCapabilitiesRequest; @@ -249,6 +250,13 @@ public static ClientServerCompatibilityStatus checkClientServerCompatibility( } } + if (remoteOptions.cacheCompression + && !cacheCap.getSupportedCompressorsList().contains(Compressor.Value.ZSTD)) { + result.addError( + "--experimental_remote_cache_compression requested but remote does not support" + + " compression"); + } + // Check result cache priority is in the supported range. checkPriorityInRange( remoteOptions.remoteResultCachePriority, diff --git a/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java b/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java index 89d6e1c6f5120f..1d2586af78df5e 100644 --- a/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java +++ b/src/main/java/com/google/devtools/build/lib/remote/options/RemoteOptions.java @@ -364,6 +364,14 @@ public String getTypeDescription() { + "symlinks and represent them as files. 
See #6631 for details.") public boolean incompatibleRemoteSymlinks; + @Option( + name = "experimental_remote_cache_compression", + defaultValue = "false", + documentationCategory = OptionDocumentationCategory.REMOTE, + effectTags = {OptionEffectTag.UNKNOWN}, + help = "If enabled, compress/decompress cache blobs with zstd.") + public boolean cacheCompression; + @Option( name = "build_event_upload_max_threads", defaultValue = "100", diff --git a/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD b/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD new file mode 100644 index 00000000000000..6108cddc569f03 --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/zstd/BUILD @@ -0,0 +1,24 @@ +load("@rules_java//java:defs.bzl", "java_library") + +package( + default_visibility = ["//src:__subpackages__"], +) + +licenses(["notice"]) + +filegroup( + name = "srcs", + srcs = glob(["*"]), + visibility = ["//src:__subpackages__"], +) + +java_library( + name = "zstd", + srcs = glob(["*.java"]), + deps = [ + "//third_party:apache_commons_compress", + "//third_party:guava", + "//third_party/protobuf:protobuf_java", + "@zstd-jni", + ], +) diff --git a/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdCompressingInputStream.java b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdCompressingInputStream.java new file mode 100644 index 00000000000000..2f6396f6799c5e --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdCompressingInputStream.java @@ -0,0 +1,104 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.zstd; + +import static java.lang.Math.max; + +import com.github.luben.zstd.ZstdOutputStream; +import com.google.common.base.Preconditions; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; + +/** A {@link FilterInputStream} that uses zstd to compress the content. */ +public class ZstdCompressingInputStream extends FilterInputStream { + // We want the buffer to be able to contain at least: + // - Magic number: 4 bytes + // - Frame header: 14 bytes + // - Block Header: 3 bytes + // - First block byte + // This guarantees that we can always compress at least + // 1 byte and write it to the pipe without blocking. 
+ public static final int MIN_BUFFER_SIZE = 4 + 14 + 3 + 1; + + private final PipedInputStream pis; + private ZstdOutputStream zos; + private final int size; + + public ZstdCompressingInputStream(InputStream in) throws IOException { + this(in, 512); + } + + ZstdCompressingInputStream(InputStream in, int size) throws IOException { + super(in); + Preconditions.checkArgument( + size >= MIN_BUFFER_SIZE, + String.format("The buffer size must be at least %d bytes", MIN_BUFFER_SIZE)); + this.size = size; + this.pis = new PipedInputStream(size); + this.zos = new ZstdOutputStream(new PipedOutputStream(pis)); + } + + private void reFill() throws IOException { + byte[] buf = new byte[size]; + int len = super.read(buf, 0, max(0, size - pis.available() - MIN_BUFFER_SIZE + 1)); + if (len == -1) { + zos.close(); + zos = null; + } else { + zos.write(buf, 0, len); + zos.flush(); + } + } + + @Override + public int read() throws IOException { + if (pis.available() == 0) { + if (zos == null) { + return -1; + } + reFill(); + } + return pis.read(); + } + + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int count = 0; + int n = len > 0 ? -1 : 0; + while (count < len && (pis.available() > 0 || zos != null)) { + if (pis.available() == 0) { + reFill(); + } + n = pis.read(b, count + off, len - count); + count += max(0, n); + } + return count > 0 ? 
count : n; + } + + @Override + public void close() throws IOException { + if (zos != null) { + zos.close(); + } + in.close(); + } +} diff --git a/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java new file mode 100644 index 00000000000000..ad1c333320964c --- /dev/null +++ b/src/main/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStream.java @@ -0,0 +1,63 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.zstd; + +import com.github.luben.zstd.ZstdInputStream; +import com.google.protobuf.ByteString; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.apache.commons.compress.utils.CountingOutputStream; + +/** A {@link CountingOutputStream} that uses zstd to decompress the content. 
*/ +public class ZstdDecompressingOutputStream extends CountingOutputStream { + private ByteArrayInputStream inner; + private final ZstdInputStream zis; + + public ZstdDecompressingOutputStream(OutputStream out) throws IOException { + super(out); + zis = + new ZstdInputStream( + new InputStream() { + @Override + public int read() { + return inner.read(); + } + + @Override + public int read(byte[] b, int off, int len) { + return inner.read(b, off, len); + } + }); + zis.setContinuous(true); + } + + @Override + public void write(int b) throws IOException { + write(new byte[] {(byte) b}, 0, 1); + } + + @Override + public void write(byte[] b) throws IOException { + write(b, 0, b.length); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + inner = new ByteArrayInputStream(b, off, len); + byte[] data = ByteString.readFrom(zis).toByteArray(); + super.write(data, 0, data.length); + } +} diff --git a/src/test/java/com/google/devtools/build/lib/remote/BUILD b/src/test/java/com/google/devtools/build/lib/remote/BUILD index 3adb51704b6343..34a317a3915a1e 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/BUILD +++ b/src/test/java/com/google/devtools/build/lib/remote/BUILD @@ -16,6 +16,7 @@ filegroup( "//src/test/java/com/google/devtools/build/lib/remote/merkletree:srcs", "//src/test/java/com/google/devtools/build/lib/remote/options:srcs", "//src/test/java/com/google/devtools/build/lib/remote/util:srcs", + "//src/test/java/com/google/devtools/build/lib/remote/zstd:srcs", ], visibility = ["//src/test/java/com/google/devtools/build/lib:__pkg__"], ) @@ -115,5 +116,6 @@ java_test( "@remoteapis//:build_bazel_remote_execution_v2_remote_execution_java_grpc", "@remoteapis//:build_bazel_remote_execution_v2_remote_execution_java_proto", "@remoteapis//:build_bazel_semver_semver_java_proto", + "@zstd-jni//:zstd-jni", ], ) diff --git a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java 
b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java index 748a3417db5c25..15cc335a0ace40 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/ByteStreamUploaderTest.java @@ -21,6 +21,8 @@ import build.bazel.remote.execution.v2.Digest; import build.bazel.remote.execution.v2.RequestMetadata; +import com.github.luben.zstd.Zstd; +import com.github.luben.zstd.ZstdInputStream; import com.google.bytestream.ByteStreamGrpc; import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase; import com.google.bytestream.ByteStreamProto.QueryWriteStatusRequest; @@ -63,6 +65,7 @@ import io.grpc.util.MutableHandlerRegistry; import io.reactivex.rxjava3.core.Single; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; @@ -341,6 +344,130 @@ public void queryWriteStatus( blockUntilInternalStateConsistent(uploader); } + @Test + public void progressiveCompressedUploadShouldWork() throws Exception { + Mockito.when(mockBackoff.getRetryAttempts()).thenReturn(0); + RemoteRetrier retrier = + TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService); + ByteStreamUploader uploader = + new ByteStreamUploader( + INSTANCE_NAME, + new ReferenceCountedChannel(channelConnectionFactory), + CallCredentialsProvider.NO_CREDENTIALS, + 300, + retrier); + + byte[] blob = new byte[CHUNK_SIZE * 2 + 1]; + new Random().nextBytes(blob); + + Chunker chunker = + Chunker.builder().setInput(blob).setCompressed(true).setChunkSize(CHUNK_SIZE).build(); + HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash()); + + while (chunker.hasNext()) { + chunker.next(); + } + long expectedSize = chunker.getOffset(); + chunker.reset(); + + serviceRegistry.addService( + new ByteStreamImplBase() { + + byte[] receivedData = new byte[(int) expectedSize]; + String 
receivedResourceName = null; + boolean receivedComplete = false; + long nextOffset = 0; + long initialOffset = 0; + boolean mustQueryWriteStatus = false; + + @Override + public StreamObserver write(StreamObserver streamObserver) { + return new StreamObserver() { + @Override + public void onNext(WriteRequest writeRequest) { + assertThat(mustQueryWriteStatus).isFalse(); + + String resourceName = writeRequest.getResourceName(); + if (nextOffset == initialOffset) { + if (initialOffset == 0) { + receivedResourceName = resourceName; + } + assertThat(resourceName).startsWith(INSTANCE_NAME + "/uploads"); + assertThat(resourceName).endsWith(String.valueOf(blob.length)); + } else { + assertThat(resourceName).isEmpty(); + } + + assertThat(writeRequest.getWriteOffset()).isEqualTo(nextOffset); + + ByteString data = writeRequest.getData(); + + System.arraycopy( + data.toByteArray(), 0, receivedData, (int) nextOffset, data.size()); + + nextOffset += data.size(); + receivedComplete = expectedSize == nextOffset; + assertThat(writeRequest.getFinishWrite()).isEqualTo(receivedComplete); + + if (initialOffset == 0) { + streamObserver.onError(Status.DEADLINE_EXCEEDED.asException()); + mustQueryWriteStatus = true; + initialOffset = nextOffset; + } + } + + @Override + public void onError(Throwable throwable) { + fail("onError should never be called."); + } + + @Override + public void onCompleted() { + assertThat(nextOffset).isEqualTo(expectedSize); + byte[] decompressed = Zstd.decompress(receivedData, blob.length); + assertThat(decompressed).isEqualTo(blob); + + WriteResponse response = + WriteResponse.newBuilder().setCommittedSize(nextOffset).build(); + streamObserver.onNext(response); + streamObserver.onCompleted(); + } + }; + } + + @Override + public void queryWriteStatus( + QueryWriteStatusRequest request, StreamObserver response) { + String resourceName = request.getResourceName(); + final long committedSize; + final boolean complete; + if (receivedResourceName != null && 
receivedResourceName.equals(resourceName)) { + assertThat(mustQueryWriteStatus).isTrue(); + mustQueryWriteStatus = false; + committedSize = nextOffset; + complete = receivedComplete; + } else { + committedSize = 0; + complete = false; + } + response.onNext( + QueryWriteStatusResponse.newBuilder() + .setCommittedSize(committedSize) + .setComplete(complete) + .build()); + response.onCompleted(); + } + }); + + uploader.uploadBlob(context, hash, chunker, true); + + // This test should not have triggered any retries. + Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class)); + Mockito.verify(mockBackoff, Mockito.times(1)).getRetryAttempts(); + + blockUntilInternalStateConsistent(uploader); + } + @Test public void concurrentlyCompletedUploadIsNotRetried() throws Exception { // Test that after an upload has failed and the QueryWriteStatus call returns @@ -512,7 +639,7 @@ public StreamObserver write(StreamObserver streamOb } @Test - public void incorrectCommittedSizeFailsUpload() throws Exception { + public void incorrectCommittedSizeFailsCompletedUpload() throws Exception { RemoteRetrier retrier = TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService); ByteStreamUploader uploader = @@ -533,10 +660,23 @@ public void incorrectCommittedSizeFailsUpload() throws Exception { new ByteStreamImplBase() { @Override public StreamObserver write(StreamObserver streamObserver) { - streamObserver.onNext( - WriteResponse.newBuilder().setCommittedSize(blob.length + 1).build()); - streamObserver.onCompleted(); - return new NoopStreamObserver(); + return new StreamObserver() { + @Override + public void onNext(WriteRequest writeRequest) {} + + @Override + public void onError(Throwable throwable) { + fail("onError should never be called."); + } + + @Override + public void onCompleted() { + WriteResponse response = + WriteResponse.newBuilder().setCommittedSize(blob.length + 1).build(); + streamObserver.onNext(response); + 
streamObserver.onCompleted(); + } + }; } }); @@ -553,6 +693,38 @@ public StreamObserver write(StreamObserver streamOb blockUntilInternalStateConsistent(uploader); } + @Test + public void incorrectCommittedSizeDoesNotFailsIncompleteUpload() throws Exception { + RemoteRetrier retrier = + TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService); + ByteStreamUploader uploader = + new ByteStreamUploader( + INSTANCE_NAME, + new ReferenceCountedChannel(channelConnectionFactory), + CallCredentialsProvider.NO_CREDENTIALS, + 300, + retrier); + + byte[] blob = new byte[CHUNK_SIZE * 2 + 1]; + new Random().nextBytes(blob); + + Chunker chunker = Chunker.builder().setInput(blob).setChunkSize(CHUNK_SIZE).build(); + HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash()); + + serviceRegistry.addService( + new ByteStreamImplBase() { + @Override + public StreamObserver write(StreamObserver streamObserver) { + streamObserver.onNext(WriteResponse.newBuilder().setCommittedSize(CHUNK_SIZE).build()); + streamObserver.onCompleted(); + return new NoopStreamObserver(); + } + }); + + uploader.uploadBlob(context, hash, chunker, true); + blockUntilInternalStateConsistent(uploader); + } + @Test public void multipleBlobsUploadShouldWork() throws Exception { RemoteRetrier retrier = @@ -1345,6 +1517,99 @@ public void onCompleted() { blockUntilInternalStateConsistent(uploader); } + @Test + public void testCompressedUploads() throws Exception { + RemoteRetrier retrier = + TestUtils.newRemoteRetrier(() -> mockBackoff, (e) -> true, retryService); + ByteStreamUploader uploader = + new ByteStreamUploader( + INSTANCE_NAME, + new ReferenceCountedChannel(channelConnectionFactory), + CallCredentialsProvider.NO_CREDENTIALS, + /* callTimeoutSecs= */ 60, + retrier); + + byte[] blob = new byte[CHUNK_SIZE * 2 + 1]; + new Random().nextBytes(blob); + + AtomicInteger numUploads = new AtomicInteger(); + + serviceRegistry.addService( + new ByteStreamImplBase() { + @Override + public 
StreamObserver write(StreamObserver streamObserver) { + return new StreamObserver() { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + String resourceName = null; + + @Override + public void onNext(WriteRequest writeRequest) { + if (!writeRequest.getResourceName().isEmpty()) { + if (resourceName != null) { + assertThat(resourceName).isEqualTo(writeRequest.getResourceName()); + } else { + resourceName = writeRequest.getResourceName(); + assertThat(resourceName).contains("/compressed-blobs/zstd/"); + } + } + try { + writeRequest.getData().writeTo(baos); + if (writeRequest.getFinishWrite()) { + baos.close(); + } + } catch (IOException e) { + throw new AssertionError("I/O error on ByteArrayOutputStream.", e); + } + } + + @Override + public void onError(Throwable throwable) { + fail("onError should never be called."); + } + + @Override + public void onCompleted() { + byte[] data = baos.toByteArray(); + try { + ZstdInputStream zis = new ZstdInputStream(new ByteArrayInputStream(data)); + byte[] decompressed = ByteString.readFrom(zis).toByteArray(); + zis.close(); + Digest digest = DIGEST_UTIL.compute(decompressed); + + assertThat(blob).hasLength(decompressed.length); + assertThat(resourceName).isNotNull(); + assertThat(resourceName) + .endsWith(String.format("/%s/%s", digest.getHash(), digest.getSizeBytes())); + + numUploads.incrementAndGet(); + } catch (IOException e) { + throw new AssertionError("Failed decompressing data.", e); + } finally { + WriteResponse response = + WriteResponse.newBuilder().setCommittedSize(data.length).build(); + + streamObserver.onNext(response); + streamObserver.onCompleted(); + } + } + }; + } + }); + + Chunker chunker = + Chunker.builder().setInput(blob).setCompressed(true).setChunkSize(CHUNK_SIZE).build(); + HashCode hash = HashCode.fromString(DIGEST_UTIL.compute(blob).getHash()); + + uploader.uploadBlob(context, hash, chunker, true); + + // This test should not have triggered any retries. 
+ Mockito.verifyNoInteractions(mockBackoff); + + blockUntilInternalStateConsistent(uploader); + + assertThat(numUploads.get()).isEqualTo(1); + } + private static class NoopStreamObserver implements StreamObserver { @Override public void onNext(WriteRequest writeRequest) {} diff --git a/src/test/java/com/google/devtools/build/lib/remote/ChunkerTest.java b/src/test/java/com/google/devtools/build/lib/remote/ChunkerTest.java index 7413d42b6015ca..ba85971620a32d 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/ChunkerTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/ChunkerTest.java @@ -16,6 +16,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import com.github.luben.zstd.Zstd; import com.google.devtools.build.lib.remote.Chunker.Chunk; import com.google.protobuf.ByteString; import java.io.ByteArrayInputStream; @@ -101,7 +102,7 @@ public void emptyData() throws Exception { @Test public void reset() throws Exception { - byte[] data = new byte[]{1, 2, 3}; + byte[] data = new byte[] {1, 2, 3}; Chunker chunker = Chunker.builder().setInput(data).setChunkSize(1).build(); assertNextEquals(chunker, (byte) 1); @@ -125,12 +126,13 @@ public void resourcesShouldBeReleased() throws IOException { byte[] data = new byte[] {1, 2}; final AtomicReference in = new AtomicReference<>(); - Supplier supplier = () -> { - in.set(Mockito.spy(new ByteArrayInputStream(data))); - return in.get(); - }; + Supplier supplier = + () -> { + in.set(Mockito.spy(new ByteArrayInputStream(data))); + return in.get(); + }; - Chunker chunker = new Chunker(supplier, data.length, 1); + Chunker chunker = new Chunker(supplier, data.length, 1, false); assertThat(in.get()).isNull(); assertNextEquals(chunker, (byte) 1); Mockito.verify(in.get(), Mockito.never()).close(); @@ -173,6 +175,51 @@ public void seekBackwards() throws IOException { assertThat(next.getData()).hasSize(8); } + @Test + public void testSingleChunkCompressed() 
throws IOException { + byte[] data = {72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100, 33}; + Chunker chunker = + Chunker.builder().setInput(data).setChunkSize(data.length * 2).setCompressed(true).build(); + Chunk next = chunker.next(); + assertThat(chunker.hasNext()).isFalse(); + assertThat(Zstd.decompress(next.getData().toByteArray(), data.length)).isEqualTo(data); + } + + @Test + public void testMultiChunkCompressed() throws IOException { + byte[] data = {72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100, 33}; + Chunker chunker = + Chunker.builder().setInput(data).setChunkSize(data.length / 2).setCompressed(true).build(); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + chunker.next().getData().writeTo(baos); + assertThat(chunker.hasNext()).isTrue(); + while (chunker.hasNext()) { + chunker.next().getData().writeTo(baos); + } + baos.close(); + + assertThat(Zstd.decompress(baos.toByteArray(), data.length)).isEqualTo(data); + } + + @Test + public void testActualSizeIsCorrectAfterSeek() throws IOException { + byte[] data = {72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100, 33}; + int[] expectedSizes = {12, 24}; + for (int expected : expectedSizes) { + Chunker chunker = + Chunker.builder() + .setInput(data) + .setChunkSize(data.length * 2) + .setCompressed(expected != data.length) + .build(); + chunker.seek(5); + chunker.next(); + assertThat(chunker.hasNext()).isFalse(); + assertThat(chunker.getOffset()).isEqualTo(expected); + } + } + private void assertNextEquals(Chunker chunker, byte... 
data) throws IOException { assertThat(chunker.hasNext()).isTrue(); ByteString next = chunker.next().getData(); diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java index cbb22d71d5a7a0..5975f5eecd67c2 100644 --- a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTest.java @@ -125,16 +125,16 @@ @RunWith(JUnit4.class) public class GrpcCacheClientTest { - private static final DigestUtil DIGEST_UTIL = new DigestUtil(DigestHashFunction.SHA256); + protected static final DigestUtil DIGEST_UTIL = new DigestUtil(DigestHashFunction.SHA256); private FileSystem fs; private Path execRoot; private FileOutErr outErr; private FakeActionInputFileCache fakeFileCache; - private final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); + protected final MutableHandlerRegistry serviceRegistry = new MutableHandlerRegistry(); private final String fakeServerName = "fake server for " + getClass(); private Server fakeServer; - private RemoteActionExecutionContext context; + protected RemoteActionExecutionContext context; private RemotePathResolver remotePathResolver; private ListeningScheduledExecutorService retryService; @@ -196,12 +196,12 @@ private GrpcCacheClient newClient() throws IOException { return newClient(Options.getDefaults(RemoteOptions.class)); } - private GrpcCacheClient newClient(RemoteOptions remoteOptions) throws IOException { + protected GrpcCacheClient newClient(RemoteOptions remoteOptions) throws IOException { return newClient(remoteOptions, () -> new ExponentialBackoff(remoteOptions)); } - private GrpcCacheClient newClient(RemoteOptions remoteOptions, Supplier backoffSupplier) - throws IOException { + protected GrpcCacheClient newClient( + RemoteOptions remoteOptions, Supplier backoffSupplier) throws IOException { AuthAndTLSOptions 
authTlsOptions = Options.getDefaults(AuthAndTLSOptions.class); authTlsOptions.useGoogleDefaultCredentials = true; authTlsOptions.googleCredentials = "/execroot/main/creds.json"; @@ -256,7 +256,7 @@ public int maxConcurrency() { channel.retain(), callCredentialsProvider, remoteOptions, retrier, DIGEST_UTIL, uploader); } - private static byte[] downloadBlob( + protected static byte[] downloadBlob( RemoteActionExecutionContext context, GrpcCacheClient cacheClient, Digest digest) throws IOException, InterruptedException { try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { diff --git a/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java new file mode 100644 index 00000000000000..51effa08170977 --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/GrpcCacheClientTestExtra.java @@ -0,0 +1,107 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package com.google.devtools.build.lib.remote;
+
+import static com.google.common.truth.Truth.assertThat;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.mockito.ArgumentMatchers.any;
+
+import build.bazel.remote.execution.v2.Digest;
+import com.github.luben.zstd.Zstd;
+import com.google.bytestream.ByteStreamGrpc.ByteStreamImplBase;
+import com.google.bytestream.ByteStreamProto.ReadRequest;
+import com.google.bytestream.ByteStreamProto.ReadResponse;
+import com.google.devtools.build.lib.remote.Retrier.Backoff;
+import com.google.devtools.build.lib.remote.options.RemoteOptions;
+import com.google.devtools.common.options.Options;
+import com.google.protobuf.ByteString;
+import io.grpc.Status;
+import io.grpc.stub.StreamObserver;
+import java.io.IOException;
+import java.util.Arrays;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+/** Extra tests for {@link GrpcCacheClient} that are not tested internally. */
+public class GrpcCacheClientTestExtra extends GrpcCacheClientTest {
+  // A read that fails after its first chunk must still complete without asking for a backoff delay.
+  @Test
+  public void compressedDownloadBlobIsRetriedWithProgress()
+      throws IOException, InterruptedException {
+    Backoff mockBackoff = Mockito.mock(Backoff.class);
+    RemoteOptions options = Options.getDefaults(RemoteOptions.class);
+    options.cacheCompression = true;
+    final GrpcCacheClient client = newClient(options, () -> mockBackoff);
+    final Digest digest = DIGEST_UTIL.computeAsUtf8("abcdefg");
+    ByteString blob = ByteString.copyFrom(Zstd.compress("abcdefg".getBytes(UTF_8)));
+    serviceRegistry.addService(
+        new ByteStreamImplBase() {
+          @Override
+          public void read(ReadRequest request, StreamObserver<ReadResponse> responseObserver) {
+            assertThat(request.getResourceName()).contains(digest.getHash());
+            int off = (int) request.getReadOffset();
+            // Zstd header size is 9 bytes; the first reply carries it plus one payload byte.
+            ByteString data = off == 0 ? blob.substring(0, 9 + 1) : blob.substring(9 + off);
+            responseObserver.onNext(ReadResponse.newBuilder().setData(data).build());
+            if (off == 0) {
+              responseObserver.onError(Status.DEADLINE_EXCEEDED.asException());
+            } else {
+              responseObserver.onCompleted();
+            }
+          }
+        });
+    assertThat(new String(downloadBlob(context, client, digest), UTF_8)).isEqualTo("abcdefg");
+    Mockito.verify(mockBackoff, Mockito.never()).nextDelayMillis(any(Exception.class));
+  }
+  // A blob delivered as three compressed slices must be reassembled and decompressed correctly.
+  @Test
+  public void testCompressedDownload() throws IOException, InterruptedException {
+    RemoteOptions options = Options.getDefaults(RemoteOptions.class);
+    options.cacheCompression = true;
+    final GrpcCacheClient client = newClient(options);
+    final byte[] data = "abcdefg".getBytes(UTF_8);
+    final Digest digest = DIGEST_UTIL.compute(data);
+    final byte[] compressed = Zstd.compress(data);
+
+    serviceRegistry.addService(
+        new ByteStreamImplBase() {
+          @Override
+          public void read(ReadRequest request, StreamObserver<ReadResponse> responseObserver) {
+            assertThat(request.getResourceName()).contains(digest.getHash());
+            responseObserver.onNext(
+                ReadResponse.newBuilder()
+                    .setData(
+                        ByteString.copyFrom(
+                            Arrays.copyOfRange(compressed, 0, compressed.length / 3)))
+                    .build());
+            responseObserver.onNext(
+                ReadResponse.newBuilder()
+                    .setData(
+                        ByteString.copyFrom(
+                            Arrays.copyOfRange(
+                                compressed, compressed.length / 3, compressed.length / 3 * 2)))
+                    .build());
+            responseObserver.onNext(
+                ReadResponse.newBuilder()
+                    .setData(
+                        ByteString.copyFrom(
+                            Arrays.copyOfRange(
+                                compressed, compressed.length / 3 * 2, compressed.length)))
+                    .build());
+            responseObserver.onCompleted();
+          }
+        });
+    assertThat(downloadBlob(context, client, digest)).isEqualTo(data);
+  }
+}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnRunnerWithGrpcRemoteExecutorTest.java b/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnRunnerWithGrpcRemoteExecutorTest.java
index 5529421cae8a65..817394ff742fd4 100644
--- a/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnRunnerWithGrpcRemoteExecutorTest.java +++ b/src/test/java/com/google/devtools/build/lib/remote/RemoteSpawnRunnerWithGrpcRemoteExecutorTest.java @@ -1075,7 +1075,8 @@ public void findMissingBlobs( responseObserver.onCompleted(); } }); - String stdOutResourceName = getResourceName(remoteOptions.remoteInstanceName, stdOutDigest); + String stdOutResourceName = + getResourceName(remoteOptions.remoteInstanceName, stdOutDigest, false); serviceRegistry.addService( new ByteStreamImplBase() { @Override @@ -1136,7 +1137,8 @@ public void findMissingBlobs( responseObserver.onCompleted(); } }); - String stdOutResourceName = getResourceName(remoteOptions.remoteInstanceName, stdOutDigest); + String stdOutResourceName = + getResourceName(remoteOptions.remoteInstanceName, stdOutDigest, false); serviceRegistry.addService( new ByteStreamImplBase() { @Override @@ -1262,7 +1264,8 @@ public void getActionResult( } }); String dummyTreeResourceName = - getResourceName(remoteOptions.remoteInstanceName, DUMMY_OUTPUT_DIRECTORY.getTreeDigest()); + getResourceName( + remoteOptions.remoteInstanceName, DUMMY_OUTPUT_DIRECTORY.getTreeDigest(), false); serviceRegistry.addService( new ByteStreamImplBase() { private boolean first = true; diff --git a/src/test/java/com/google/devtools/build/lib/remote/zstd/BUILD b/src/test/java/com/google/devtools/build/lib/remote/zstd/BUILD new file mode 100644 index 00000000000000..2bb638bb4353f9 --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/zstd/BUILD @@ -0,0 +1,28 @@ +load("@rules_java//java:defs.bzl", "java_test") + +package( + default_testonly = 1, + default_visibility = ["//src:__subpackages__"], +) + +licenses(["notice"]) + +filegroup( + name = "srcs", + testonly = 0, + srcs = glob(["**"]), + visibility = ["//src:__subpackages__"], +) + +java_test( + name = "zstd", + srcs = glob(["*.java"]), + test_class = "com.google.devtools.build.lib.AllTests", + deps = [ + 
"//src/main/java/com/google/devtools/build/lib/remote/zstd", + "//src/test/java/com/google/devtools/build/lib:test_runner", + "//third_party:junit4", + "//third_party:truth", + "@zstd-jni//:zstd-jni", + ], +) diff --git a/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdCompressingInputStreamTest.java b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdCompressingInputStreamTest.java new file mode 100644 index 00000000000000..fb3a869ea03f7d --- /dev/null +++ b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdCompressingInputStreamTest.java @@ -0,0 +1,54 @@ +// Copyright 2021 The Bazel Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.google.devtools.build.lib.remote.zstd; + +import static com.google.common.truth.Truth.assertThat; + +import com.github.luben.zstd.Zstd; +import com.google.common.io.ByteStreams; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.util.Random; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link ZstdCompressingInputStream}. 
*/
+@RunWith(JUnit4.class)
+public class ZstdCompressingInputStreamTest {
+  @Test
+  public void compressionWorks() throws IOException {
+    Random rand = new Random();
+    byte[] data = new byte[50];
+    rand.nextBytes(data);
+    // Reading the wrapper to EOF must yield a zstd frame that decompresses back to the input.
+    ByteArrayInputStream bais = new ByteArrayInputStream(data);
+    ZstdCompressingInputStream zdis = new ZstdCompressingInputStream(bais);
+
+    assertThat(Zstd.decompress(ByteStreams.toByteArray(zdis), data.length)).isEqualTo(data);
+  }
+
+  @Test
+  public void streamCanBeCompressedWithMinimumBufferSize() throws IOException {
+    Random rand = new Random();
+    byte[] data = new byte[50];
+    rand.nextBytes(data);
+    // Same round trip with MIN_BUFFER_SIZE as the second argument (presumably the buffer size).
+    ByteArrayInputStream bais = new ByteArrayInputStream(data);
+    ZstdCompressingInputStream zdis =
+        new ZstdCompressingInputStream(bais, ZstdCompressingInputStream.MIN_BUFFER_SIZE);
+
+    assertThat(Zstd.decompress(ByteStreams.toByteArray(zdis), data.length)).isEqualTo(data);
+  }
+}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTest.java b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTest.java
new file mode 100644
index 00000000000000..7a5dd9d211a71d
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTest.java
@@ -0,0 +1,43 @@
+// Copyright 2021 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.remote.zstd;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.github.luben.zstd.Zstd;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.Random;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link ZstdDecompressingOutputStream}. */
+@RunWith(JUnit4.class)
+public class ZstdDecompressingOutputStreamTest {
+  @Test
+  public void decompressionWorks() throws IOException {
+    Random rand = new Random();
+    byte[] data = new byte[50];
+    rand.nextBytes(data);
+    byte[] compressed = Zstd.compress(data);
+    // Push the whole compressed frame through the stream in a single write(byte[]) call.
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    ZstdDecompressingOutputStream zdos = new ZstdDecompressingOutputStream(baos);
+    zdos.write(compressed);
+    zdos.flush();
+    // The wrapped stream must now hold exactly the original, uncompressed payload.
+    assertThat(baos.toByteArray()).isEqualTo(data);
+  }
+}
diff --git a/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java
new file mode 100644
index 00000000000000..22cba85b8b6f68
--- /dev/null
+++ b/src/test/java/com/google/devtools/build/lib/remote/zstd/ZstdDecompressingOutputStreamTestExtra.java
@@ -0,0 +1,70 @@
+// Copyright 2021 The Bazel Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.google.devtools.build.lib.remote.zstd;
+
+import static com.google.common.truth.Truth.assertThat;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import com.github.luben.zstd.Zstd;
+import com.github.luben.zstd.ZstdOutputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.Random;
+import org.junit.Test;
+
+/** Extra tests for {@link ZstdDecompressingOutputStream} that are not tested internally. */
+public class ZstdDecompressingOutputStreamTestExtra {
+  @Test
+  public void streamCanBeDecompressedOneByteAtATime() throws IOException {
+    Random rand = new Random();
+    byte[] data = new byte[50];
+    rand.nextBytes(data);
+    byte[] compressed = Zstd.compress(data);
+    // Feed the compressed frame through write(int), one byte per call.
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    ZstdDecompressingOutputStream zdos = new ZstdDecompressingOutputStream(baos);
+    for (byte b : compressed) {
+      zdos.write(b);
+    }
+    zdos.flush();
+
+    assertThat(baos.toByteArray()).isEqualTo(data);
+  }
+
+  @Test
+  public void bytesWrittenMatchesDecompressedBytes() throws IOException {
+    byte[] data = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".getBytes(UTF_8);
+    // Build the compressed input as a sequence of small frames rather than a single one.
+    ByteArrayOutputStream compressed = new ByteArrayOutputStream();
+    ZstdOutputStream zos = new ZstdOutputStream(compressed);
+    zos.setCloseFrameOnFlush(true);
+    for (int i = 0; i < data.length; i++) {
+      zos.write(data[i]);
+      if (i % 5 == 0) {
+        // Flush after bytes 0, 5, 10, ... so the stream holds multiple frames of at most 5 bytes.
+        zos.flush();
+      }
+    }
+    zos.close();
+    // After every input byte, the reported count must equal what has reached the sink so far.
+    ByteArrayOutputStream decompressed = new ByteArrayOutputStream();
+    ZstdDecompressingOutputStream zdos = new ZstdDecompressingOutputStream(decompressed);
+    for (byte b : compressed.toByteArray()) {
+      zdos.write(b);
+      zdos.flush();
+      assertThat(zdos.getBytesWritten()).isEqualTo(decompressed.toByteArray().length);
+    }
+    assertThat(decompressed.toByteArray()).isEqualTo(data);
+  }
+}
diff --git a/src/test/shell/integration/minimal_jdk_test.sh b/src/test/shell/integration/minimal_jdk_test.sh
index 0039e2e3e2073d..109d129dc3ad91 100755
--- a/src/test/shell/integration/minimal_jdk_test.sh
+++ b/src/test/shell/integration/minimal_jdk_test.sh
@@ -42,13 +42,13 @@ export BAZEL_SUFFIX="_jdk_minimal"
 source "$(rlocation "io_bazel/src/test/shell/integration_test_setup.sh")" \
   || { echo "integration_test_setup.sh not found!" >&2; exit 1; }
 
-# Bazel's install base is < 310MB with minimal JDK and > 315MB with an all
+# Bazel's install base is < 311MB with minimal JDK and > 315MB with an all
 # modules JDK.
-function test_size_less_than_310MB() {
+function test_size_less_than_311MB() {
   bazel info
   ib=$(bazel info install_base)
   size=$(du -s "$ib" | cut -d\ -f1)
-  maxsize=$((1024*310))
+  maxsize=$((1024*311))
   if [ $size -gt $maxsize ]; then
     echo "$ib was too big:" 1>&2
     du -a "$ib" 1>&2
diff --git a/third_party/zstd-jni/Native.java.patch b/third_party/zstd-jni/Native.java.patch
new file mode 100644
index 00000000000000..4990d14a728b81
--- /dev/null
+++ b/third_party/zstd-jni/Native.java.patch
@@ -0,0 +1,11 @@
+--- a/src/main/java/com/github/luben/zstd/util/Native.java
++++ b/src/main/java/com/github/luben/zstd/util/Native.java
+@@ -59,7 +59,7 @@ public enum Native {
+         if (loaded) {
+             return;
+         }
+-        String resourceName = resourceName();
++        String resourceName = "/libzstd-jni.so";
+ 
+         String overridePath = System.getProperty(nativePathOverride);
+         if (overridePath != null) {
diff --git a/third_party/zstd-jni/zstd-jni.BUILD b/third_party/zstd-jni/zstd-jni.BUILD
new file mode 100644
index 00000000000000..6741c6b200b4f5
--- /dev/null
+++ b/third_party/zstd-jni/zstd-jni.BUILD
@@ -0,0 +1,62 @@
+cc_binary(
+    name = "libzstd-jni.so",  # Must match the resource name hard-coded by Native.java.patch.
+    srcs = glob([
+        "src/main/native/**/*.c",
+        "src/main/native/**/*.h",
+    ]) + select({
+        "@io_bazel//src/conditions:windows": [
+            "src/windows/include/jni_md.h",
+            "jni/jni.h",
+        ],
+        "//conditions:default": [
+            "jni/jni_md.h",
+            "jni/jni.h",
+        ]
+    }),
+    copts = select({
+        "@io_bazel//src/conditions:windows": [],
+        "//conditions:default": [
+            "-std=c99",
+            "-Wno-unused-variable",
+            "-Wno-sometimes-uninitialized",
+        ]
+    }),
+    linkshared = 1,  # Produce a shared library rather than an executable.
+    includes = select({
+        "@io_bazel//src/conditions:windows": ["src/windows/include"],
+        "//conditions:default": [],
+    }) + [
+        "jni",
+        "src/main/native",
+        "src/main/native/common",
+    ],
+    local_defines = [
+        "ZSTD_LEGACY_SUPPORT=4",
+        "ZSTD_MULTITHREAD=1",
+    ] + select({
+        "@io_bazel//src/conditions:windows": ["_JNI_IMPLEMENTATION_"],
+        "//conditions:default": [],
+    }),
+)
+
+# Generates ZstdVersion.java, whose VERSION constant is the content of the "version" file.
+genrule(
+    name = "version-java",
+    cmd_bash = 'echo "package com.github.luben.zstd.util;\n\npublic class ZstdVersion {\n\tpublic static final String VERSION = \\"$$(cat $<)\\";\n}" > $@',
+    cmd_ps = '$$PSDefaultParameterValues.Remove("*:Encoding"); $$version = (Get-Content $<) -join ""; Set-Content -NoNewline -Path $@ -Value "package com.github.luben.zstd.util;\n\npublic class ZstdVersion {\n\tpublic static final String VERSION = `"$${version}`";\n}\n"',
+    srcs = ["version"],
+    outs = ["ZstdVersion.java"],
+)
+# Java bindings compiled from the upstream sources plus the generated ZstdVersion.java.
+java_library(
+    name = "zstd-jni",
+    srcs = glob([
+        "src/main/java/**/*.java",
+    ]) + [
+        ":version-java",
+    ],
+    resources = [":libzstd-jni.so"],  # Ship the native library inside the jar.
+    visibility = [
+        "//visibility:public",
+    ],
+)