@@ -182,7 +182,6 @@ private void closeInput() throws IOException {
@Override
public void close() throws IOException {
reset();
blob.close();
}

/** Attempts reading at most a full chunk and stores it in the chunkCache buffer */
@@ -197,6 +196,7 @@ private int read() throws IOException {
}
return count;
}
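As an aside on the method above: reading "at most a full chunk" has to account for InputStream.read() returning fewer bytes than requested in a single call. A minimal sketch of that pattern, using an illustrative helper rather than Bazel's actual Chunker internals:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

final class ChunkReadSketch {
  /** Fills {@code buffer} as far as possible; returns the byte count, or -1 at end of stream. */
  static int readChunk(InputStream in, byte[] buffer) throws IOException {
    // readNBytes loops internally until the buffer is full or the stream ends, unlike a
    // single read() call, which may return early with fewer bytes.
    int count = in.readNBytes(buffer, 0, buffer.length);
    return count == 0 ? -1 : count;
  }

  public static void main(String[] args) throws IOException {
    InputStream in = new ByteArrayInputStream(new byte[100]);
    byte[] chunk = new byte[64];
    System.out.println(readChunk(in, chunk)); // 64
    System.out.println(readChunk(in, chunk)); // 36
    System.out.println(readChunk(in, chunk)); // -1
  }
}
```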

/**
* Returns the next {@link Chunk} or throws a {@link NoSuchElementException} if no data is left.
*
@@ -26,7 +26,6 @@

import build.bazel.remote.execution.v2.Digest;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
@@ -48,7 +47,6 @@
import com.google.devtools.build.lib.remote.util.DigestUtil;
import com.google.devtools.build.lib.remote.util.RxUtils.TransferResult;
import com.google.devtools.build.lib.vfs.Path;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import io.reactivex.rxjava3.annotations.NonNull;
import io.reactivex.rxjava3.core.Completable;
@@ -63,8 +61,12 @@
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;

@@ -210,32 +212,37 @@ public ListenableFuture<Void> uploadVirtualActionInput(
context, digest, new VirtualActionInputBlob(virtualActionInput));
}

private static final class VirtualActionInputBlob implements Blob {
private VirtualActionInput virtualActionInput;
// Can be large compared to the retained size of the VirtualActionInput and thus shouldn't be
// kept in memory for an extended period of time.
private volatile ByteString data;

VirtualActionInputBlob(VirtualActionInput virtualActionInput) {
this.virtualActionInput = Preconditions.checkNotNull(virtualActionInput);
}
private record VirtualActionInputBlob(VirtualActionInput virtualActionInput) implements Blob {
private static final ExecutorService VIRTUAL_ACTION_INPUT_PIPE_EXECUTOR =
Executors.newThreadPerTaskExecutor(
Thread.ofVirtual().name("virtual-action-input-pipe-", 0).factory());

@Override
public InputStream get() throws IOException {
if (data == null) {
synchronized (this) {
if (data == null) {
data = Preconditions.checkNotNull(virtualActionInput, "used after close()").getBytes();
}
}
public InputStream get() {
// Avoid materializing and retaining VirtualActionInput.getBytes() during the upload. This
// can result in high memory usage when many parallel actions have large virtual inputs. Limit
// this memory usage to the fixed buffer size by using a piped stream.
var pipedIn = new PipedInputStream(Chunker.getDefaultChunkSize());
PipedOutputStream pipedOut;
try {
pipedOut = new PipedOutputStream(pipedIn);
} catch (IOException e) {
throw new IllegalStateException(
"PipedOutputStream constructor is not expected to throw", e);
}
return data.newInput();
}

@Override
public void close() {
virtualActionInput = null;
data = null;
// Note that while Piped{Input,Output}Stream are not directly I/O-bound, bytes read from
// pipedIn are sent out via gRPC before more bytes are read. As a result, pipedOut is expected
// to block frequently enough to make virtual threads suitable here.
VIRTUAL_ACTION_INPUT_PIPE_EXECUTOR.submit(
() -> {
try (pipedOut) {
virtualActionInput.writeTo(pipedOut);
} catch (IOException e) {
throw new IllegalStateException(
"writeTo is not expected to throw as pipedOut doesn't", e);
}
});
return pipedIn;
}
}
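For illustration, a self-contained sketch of the pattern used by the record above (hypothetical names, not the Bazel implementation): a virtual thread pushes the producer's bytes into a PipedOutputStream while the consumer drains the connected PipedInputStream, so only one pipe buffer's worth of data is resident at a time.

```java
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/** Sketch: stream a producer's output through a fixed-size pipe instead of materializing it. */
final class PipedBlobSketch {
  /** Stand-in for anything that can write its content to a stream (e.g. a virtual input). */
  interface Producer {
    void writeTo(OutputStream out) throws IOException;
  }

  private static final ExecutorService WRITERS =
      Executors.newThreadPerTaskExecutor(Thread.ofVirtual().name("pipe-writer-", 0).factory());

  /** Returns a stream over the producer's content; memory use is bounded by the pipe size. */
  static InputStream stream(Producer producer, int pipeSize) throws IOException {
    var in = new PipedInputStream(pipeSize);
    var out = new PipedOutputStream(in);
    WRITERS.submit(
        () -> {
          // The writer blocks whenever the pipe buffer is full, i.e. whenever the consumer
          // (in the change above: the gRPC upload) hasn't drained it yet.
          try (out) {
            producer.writeTo(out);
          } catch (IOException e) {
            throw new IllegalStateException(e);
          }
        });
    return in;
  }

  public static void main(String[] args) throws IOException {
    Producer producer = out -> out.write("hello".getBytes());
    System.out.println(new String(stream(producer, 4096).readAllBytes()));
  }
}
```

Virtual threads fit this writer role because it spends most of its time blocked waiting for the pipe to drain; a platform thread pool would also work, but each blocked writer would then tie up an OS thread.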

@@ -22,7 +22,6 @@
import com.google.common.util.concurrent.ListenableFuture;
import com.google.devtools.build.lib.vfs.Path;
import com.google.protobuf.ByteString;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@@ -100,7 +99,7 @@ ListenableFuture<Void> downloadBlob(
* as late as possible and close the blob as soon as they are done with it.
*/
@FunctionalInterface
interface Blob extends Closeable {
interface Blob {
/** Get an input stream for the blob's data. Can be called multiple times. */
InputStream get() throws IOException;

@@ -109,9 +108,6 @@ interface Blob extends Closeable {
default String description() {
return null;
}

@Override
default void close() {}
}
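With close() removed, Blob is back to a single abstract method, so callers can supply it as a lambda. A hypothetical usage sketch against a simplified stand-in for the interface:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;

final class BlobLambdaSketch {
  // Simplified stand-in for RemoteCacheClient.Blob after this change: nothing to close.
  @FunctionalInterface
  interface Blob {
    InputStream get() throws IOException;
  }

  public static void main(String[] args) throws IOException {
    byte[] bytes = "hello".getBytes();
    // Each call to get() returns a fresh stream, so the blob can be read multiple times
    // (for example, once per retry attempt).
    Blob blob = () -> new ByteArrayInputStream(bytes);
    System.out.println(new String(blob.get().readAllBytes()));
  }
}
```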

/**
@@ -720,13 +720,9 @@ private void uploadAfterCredentialRefresh(UploadCommand upload, SettableFuture<V
public ListenableFuture<Void> uploadBlob(
RemoteActionExecutionContext context, Digest digest, Blob blob) {
return retrier.executeAsync(
() -> {
var result =
uploadAsync(
digest.getHash(), digest.getSizeBytes(), blob.get(), /* casUpload= */ true);
result.addListener(blob::close, MoreExecutors.directExecutor());
return result;
});
() ->
uploadAsync(
digest.getHash(), digest.getSizeBytes(), blob.get(), /* casUpload= */ true));
}

@Override
@@ -50,18 +50,17 @@
import com.google.bytestream.ByteStreamProto.ReadResponse;
import com.google.bytestream.ByteStreamProto.WriteRequest;
import com.google.bytestream.ByteStreamProto.WriteResponse;
import com.google.common.collect.ImmutableClassToInstanceMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.io.CountingOutputStream;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.devtools.build.lib.actions.ActionInputHelper;
import com.google.devtools.build.lib.actions.Spawn;
import com.google.devtools.build.lib.actions.cache.VirtualActionInput;
import com.google.devtools.build.lib.actions.util.ActionsTestUtil;
import com.google.devtools.build.lib.authandtls.AuthAndTLSOptions;
import com.google.devtools.build.lib.authandtls.CallCredentialsProvider;
import com.google.devtools.build.lib.authandtls.GoogleAuthUtils;
@@ -77,8 +76,8 @@
import com.google.devtools.build.lib.remote.merkletree.MerkleTree;
import com.google.devtools.build.lib.remote.merkletree.MerkleTreeComputer;
import com.google.devtools.build.lib.remote.options.RemoteOptions;
import com.google.devtools.build.lib.remote.util.DigestOutputStream;
import com.google.devtools.build.lib.remote.util.DigestUtil;
import com.google.devtools.build.lib.remote.util.FakeSpawnExecutionContext;
import com.google.devtools.build.lib.remote.util.TestUtils;
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
import com.google.devtools.build.lib.testutil.Scratch;
@@ -120,14 +119,17 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import org.junit.After;
import org.junit.Before;
@@ -292,34 +294,45 @@ public void testVirtualActionInputSupport() throws Exception {
/* symlinkTemplate= */ null,
DIGEST_UTIL);
PathFragment execPath = PathFragment.create("my/exec/path");
VirtualActionInput virtualActionInput =
ActionsTestUtil.createVirtualActionInput(execPath, "hello");
Spawn spawn =
new SpawnBuilder("unused").withInputs(virtualActionInput).withOutputs("foo").build();
SpawnExecutionContext spawnExecutionContext =
new FakeSpawnExecutionContext(
spawn,
/* inputMetadataProvider= */ null,
execRoot,
/* outErr= */ null,
ImmutableClassToInstanceMap.of(),
/* actionFileSystem= */ null);
var virtualActionInput =
new VirtualActionInput() {
@Override
public String getExecPathString() {
return execPath.getPathString();
}

@Override
public PathFragment getExecPath() {
return execPath;
}

@Override
public void writeTo(OutputStream out) throws IOException {
// Use a fixed seed to ensure deterministic content across multiple calls.
var random = new Random(123456);
// Use primes to exercise chunking logic. Keeping the full output in memory requires at
// least 64MB of heap.
for (int i = 0; i < 1031; i++) {
byte[] bytes = new byte[65537];
random.nextBytes(bytes);
out.write(bytes);
}
}
};
var merkleTreeComputer =
new MerkleTreeComputer(
DIGEST_UTIL, client, "buildRequestId", "commandId", TestConstants.WORKSPACE_NAME);
var spawn = new SpawnBuilder().withInput(virtualActionInput).build();
var merkleTree =
(MerkleTree.Uploadable)
new MerkleTreeComputer(
DIGEST_UTIL,
client,
"buildRequestId",
"commandId",
TestConstants.WORKSPACE_NAME)
.buildForSpawn(
spawn,
ImmutableSet.of(),
/* scrubber= */ null,
spawnExecutionContext,
remotePathResolver,
MerkleTreeComputer.BlobPolicy.KEEP);
Digest digest = DIGEST_UTIL.compute(virtualActionInput.getBytes().toByteArray());
merkleTreeComputer.buildForSpawn(
spawn,
ImmutableSet.of(),
/* scrubber= */ null,
context.getSpawnExecutionContext(),
remotePathResolver,
MerkleTreeComputer.BlobPolicy.KEEP);
Digest digest = DIGEST_UTIL.compute(virtualActionInput);

// Add a fake CAS that responds saying that the above virtual action input is missing
serviceRegistry.addService(
@@ -334,39 +347,96 @@ public void findMissingBlobs(
}
});

// Mock a byte stream and assert that we see the virtual action input with contents 'hello'
AtomicBoolean writeOccurred = new AtomicBoolean();
var serviceError = new AtomicReference<Throwable>();
var countingOut = new CountingOutputStream(OutputStream.nullOutputStream());
var digestOut =
new DigestOutputStream(DigestHashFunction.SHA256.getHashFunction(), countingOut);
var sawFinalChunk = new CountDownLatch(1);
var delayFinalChunk = new CountDownLatch(1);
serviceRegistry.addService(
new ByteStreamImplBase() {
@Override
public StreamObserver<WriteRequest> write(
final StreamObserver<WriteResponse> responseObserver) {
return new StreamObserver<WriteRequest>() {
return new StreamObserver<>() {
final AtomicBoolean firstRequest = new AtomicBoolean(true);

@Override
public void onNext(WriteRequest request) {
assertThat(request.getResourceName()).contains(digest.getHash());
assertThat(request.getFinishWrite()).isTrue();
assertThat(request.getData().toStringUtf8()).isEqualTo("hello");
writeOccurred.set(true);
try {
if (firstRequest.getAndSet(false)) {
assertThat(request.getResourceName()).contains(digest.getHash());
}
assertThat(request.getWriteOffset()).isEqualTo(countingOut.getCount());
try {
request.getData().newInput().transferTo(digestOut);
} catch (IOException e) {
throw new IllegalStateException(e);
}
if (countingOut.getCount() == digest.getSizeBytes()) {
sawFinalChunk.countDown();
delayFinalChunk.await();
assertThat(request.getFinishWrite()).isTrue();
} else {
assertThat(request.getFinishWrite()).isFalse();
}
} catch (Throwable t) {
serviceError.set(t);
responseObserver.onError(Status.INTERNAL.withCause(t).asRuntimeException());
}
}

@Override
public void onCompleted() {
responseObserver.onNext(WriteResponse.newBuilder().setCommittedSize(5).build());
responseObserver.onNext(
WriteResponse.newBuilder().setCommittedSize(digest.getSizeBytes()).build());
responseObserver.onCompleted();
}

@Override
public void onError(Throwable t) {
fail("An error occurred: " + t);
serviceError.set(t);
}
};
}
});

// Upload all missing inputs (that is, the virtual action input from above)
client.ensureInputsPresent(
context, merkleTree, ImmutableMap.of(), /* force= */ true, remotePathResolver);
System.gc();
var usedMemoryBefore = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();

var uploadError = new AtomicReference<Throwable>();
var uploadThread =
Thread.ofPlatform()
.start(
() -> {
try {
client.ensureInputsPresent(
context,
merkleTree,
ImmutableMap.of(),
/* force= */ true,
remotePathResolver);
} catch (Throwable e) {
uploadError.set(e);
}
});

sawFinalChunk.await();
System.gc();
var usedMemoryAfter = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();

delayFinalChunk.countDown();
uploadThread.join();

if (uploadError.get() != null) {
throw new AssertionError(uploadError.get());
}
if (serviceError.get() != null) {
throw new AssertionError(serviceError.get());
}
assertThat(digestOut.digest()).isEqualTo(digest);
// Ensure that memory usage didn't spike by the size of the virtual input (about 64MB).
assertThat(usedMemoryAfter - usedMemoryBefore).isLessThan(10 * 1024 * 1024);
}
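The memory assertion at the end of this test follows a general pattern: pause the operation at a known checkpoint with latches, force a garbage collection, and compare used-heap snapshots taken before and during the operation. A stripped-down sketch of that pattern, with illustrative names and no connection to the test's actual fixtures:

```java
import java.util.concurrent.CountDownLatch;

final class HeapDeltaSketch {
  static long usedHeap() {
    System.gc(); // best effort; good enough for a coarse upper-bound assertion
    return Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
  }

  public static void main(String[] args) throws InterruptedException {
    var reachedCheckpoint = new CountDownLatch(1);
    var allowCompletion = new CountDownLatch(1);

    long before = usedHeap();
    var worker =
        Thread.ofPlatform()
            .start(
                () -> {
                  // Simulate an operation with a bounded transient footprint that pauses
                  // at a checkpoint so the test can sample heap usage mid-flight.
                  byte[] scratch = new byte[64 * 1024];
                  reachedCheckpoint.countDown();
                  try {
                    allowCompletion.await();
                  } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                  }
                  scratch[0] = 1; // keep the buffer reachable until released
                });

    reachedCheckpoint.await(); // worker is now mid-operation
    long during = usedHeap();
    allowCompletion.countDown();
    worker.join();

    // The operation's transient footprint should stay well below the full input size.
    System.out.println("delta bytes: " + (during - before));
  }
}
```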

@Test
@@ -142,7 +142,7 @@ public ListenableFuture<Void> uploadActionResult(
@Override
public ListenableFuture<Void> uploadBlob(
RemoteActionExecutionContext context, Digest digest, Blob blob) {
try (blob) {
try {
cas.put(digest, blob.get().readAllBytes());
} catch (IOException e) {
return Futures.immediateFailedFuture(e);