Skip to content

Commit

Permalink
Add support for wrapping system streams in WorkRequestHandler
Browse files Browse the repository at this point in the history
There are often [places](https://github.com/bazelbuild/bazel/blob/ea19c17075478092eb77580e6d3825d480126d3a/src/tools/android/java/com/google/devtools/build/android/ResourceProcessorBusyBox.java#L188) where persistent workers need to swap out the standard system streams to avoid tools poisoning the worker communication streams by writing logs/exceptions to it.

This pull request extracts that pattern into an optional WorkerIO wrapper can be used to swap in and out the standard streams without the added boilerplate.

Closes bazelbuild#14201.

PiperOrigin-RevId: 498983983
Change-Id: Iefb956d38a5887d9e5bbf0821551eb0efa14fce9
  • Loading branch information
Bencodes authored and copybara-github committed Jan 2, 2023
1 parent fd93888 commit e05345d
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
import com.google.devtools.build.lib.worker.WorkerProtocol.WorkResponse;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.sun.management.OperatingSystemMXBean;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.management.ManagementFactory;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -317,8 +321,17 @@ public WorkRequestHandler build() {
* then writing the corresponding {@link WorkResponse} to {@code out}. If there is an error
* reading or writing the requests or responses, it writes an error message on {@code err} and
* returns. If {@code in} reaches EOF, it also returns.
*
* <p>This function also wraps the system streams in a {@link WorkerIO} instance that prevents the
* underlying tool from writing to {@link System#out} or reading from {@link System#in}, which
* would corrupt the worker worker protocol. When the while loop exits, the original system
* streams will be swapped back into {@link System}.
*/
public void processRequests() throws IOException {
// Wrap the system streams into a WorkerIO instance to prevent unexpected reads and writes on
// stdin/stdout.
WorkerIO workerIO = WorkerIO.capture();

try {
while (!shutdownWorker.get()) {
WorkRequest request = messageProcessor.readWorkRequest();
Expand All @@ -328,31 +341,39 @@ public void processRequests() throws IOException {
if (request.getCancel()) {
respondToCancelRequest(request);
} else {
startResponseThread(request);
startResponseThread(workerIO, request);
}
}
} catch (IOException e) {
stderr.println("Error reading next WorkRequest: " + e);
e.printStackTrace(stderr);
}
// TODO(b/220878242): Give the outstanding requests a chance to send a "shutdown" response,
// but also try to kill stuck threads. For now, we just interrupt the remaining threads.
// We considered doing System.exit here, but that is hard to test and would deny the callers
// of this method a chance to clean up. Instead, we initiate the cleanup of our resources here
// and the caller can decide whether to wait for an orderly shutdown or now.
for (RequestInfo ri : activeRequests.values()) {
if (ri.thread.isAlive()) {
try {
ri.thread.interrupt();
} catch (RuntimeException e) {
// If we can't interrupt, we can't do much else.
} finally {
// TODO(b/220878242): Give the outstanding requests a chance to send a "shutdown" response,
// but also try to kill stuck threads. For now, we just interrupt the remaining threads.
// We considered doing System.exit here, but that is hard to test and would deny the callers
// of this method a chance to clean up. Instead, we initiate the cleanup of our resources here
// and the caller can decide whether to wait for an orderly shutdown or now.
for (RequestInfo ri : activeRequests.values()) {
if (ri.thread.isAlive()) {
try {
ri.thread.interrupt();
} catch (RuntimeException e) {
// If we can't interrupt, we can't do much else.
}
}
}

try {
// Unwrap the system streams placing the original streams back
workerIO.close();
} catch (Exception e) {
stderr.println(e.getMessage());
}
}
}

/** Starts a thread for the given request. */
void startResponseThread(WorkRequest request) {
void startResponseThread(WorkerIO workerIO, WorkRequest request) {
Thread currentThread = Thread.currentThread();
String threadName =
request.getRequestId() > 0
Expand Down Expand Up @@ -381,7 +402,7 @@ void startResponseThread(WorkRequest request) {
return;
}
try {
respondToRequest(request, requestInfo);
respondToRequest(workerIO, request, requestInfo);
} catch (IOException e) {
// IOExceptions here means a problem talking to the server, so we must shut down.
if (!shutdownWorker.compareAndSet(false, true)) {
Expand Down Expand Up @@ -419,7 +440,8 @@ void startResponseThread(WorkRequest request) {
* #callback} are reported with exit code 1.
*/
@VisibleForTesting
void respondToRequest(WorkRequest request, RequestInfo requestInfo) throws IOException {
void respondToRequest(WorkerIO workerIO, WorkRequest request, RequestInfo requestInfo)
throws IOException {
int exitCode;
StringWriter sw = new StringWriter();
try (PrintWriter pw = new PrintWriter(sw)) {
Expand All @@ -431,6 +453,16 @@ void respondToRequest(WorkRequest request, RequestInfo requestInfo) throws IOExc
e.printStackTrace(pw);
exitCode = 1;
}

try {
// Read out the captured string for the final WorkResponse output
String captured = workerIO.readCapturedAsUtf8String().trim();
if (!captured.isEmpty()) {
pw.write(captured);
}
} catch (IOException e) {
stderr.println(e.getMessage());
}
}
Optional<WorkResponse.Builder> optBuilder = requestInfo.takeBuilder();
if (optBuilder.isPresent()) {
Expand Down Expand Up @@ -541,4 +573,104 @@ private void maybePerformGc() {
}
}
}

/**
* Class that wraps the standard {@link System#in}, {@link System#out}, and {@link System#err}
* with our own ByteArrayOutputStream that allows {@link WorkRequestHandler} to safely capture
* outputs that can't be directly captured by the PrintStream associated with the work request.
*
* <p>This is most useful when integrating JVM tools that write exceptions and logs directly to
* {@link System#out} and {@link System#err}, which would corrupt the persistent worker protocol.
* We also redirect {@link System#in}, just in case a tool should attempt to read it.
*
* <p>WorkerIO implements {@link AutoCloseable} and will swap the original streams back into
* {@link System} once close has been called.
*/
public static class WorkerIO implements AutoCloseable {
private final InputStream originalInputStream;
private final PrintStream originalOutputStream;
private final PrintStream originalErrorStream;
private final ByteArrayOutputStream capturedStream;
private final AutoCloseable restore;

/**
* Creates a new {@link WorkerIO} that allows {@link WorkRequestHandler} to capture standard
* output and error streams that can't be directly captured by the PrintStream associated with
* the work request.
*/
@VisibleForTesting
WorkerIO(
InputStream originalInputStream,
PrintStream originalOutputStream,
PrintStream originalErrorStream,
ByteArrayOutputStream capturedStream,
AutoCloseable restore) {
this.originalInputStream = originalInputStream;
this.originalOutputStream = originalOutputStream;
this.originalErrorStream = originalErrorStream;
this.capturedStream = capturedStream;
this.restore = restore;
}

/** Wraps the standard System streams and WorkerIO instance */
public static WorkerIO capture() {
// Save the original streams
InputStream originalInputStream = System.in;
PrintStream originalOutputStream = System.out;
PrintStream originalErrorStream = System.err;

// Replace the original streams with our own instances
ByteArrayOutputStream capturedStream = new ByteArrayOutputStream();
PrintStream outputBuffer = new PrintStream(capturedStream, true);
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(new byte[0]);
System.setIn(byteArrayInputStream);
System.setOut(outputBuffer);
System.setErr(outputBuffer);

return new WorkerIO(
originalInputStream,
originalOutputStream,
originalErrorStream,
capturedStream,
() -> {
System.setIn(originalInputStream);
System.setOut(originalOutputStream);
System.setErr(originalErrorStream);
outputBuffer.close();
byteArrayInputStream.close();
});
}

/** Returns the original input stream most commonly provided by {@link System#in} */
@VisibleForTesting
InputStream getOriginalInputStream() {
return originalInputStream;
}

/** Returns the original output stream most commonly provided by {@link System#out} */
@VisibleForTesting
PrintStream getOriginalOutputStream() {
return originalOutputStream;
}

/** Returns the original error stream most commonly provided by {@link System#err} */
@VisibleForTesting
PrintStream getOriginalErrorStream() {
return originalErrorStream;
}

/** Returns the captured outputs as a UTF-8 string */
@VisibleForTesting
String readCapturedAsUtf8String() throws IOException {
capturedStream.flush();
String captureOutput = capturedStream.toString(StandardCharsets.UTF_8);
capturedStream.reset();
return captureOutput;
}

@Override
public void close() throws Exception {
restore.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public final class ExampleWorker {
static final Pattern FLAG_FILE_PATTERN = Pattern.compile("(?:@|--?flagfile=)(.+)");

// A UUID that uniquely identifies this running worker process.
static final UUID workerUuid = UUID.randomUUID();
static final UUID WORKER_UUID = UUID.randomUUID();

// A counter that increases with each work unit processed.
static int workUnitCounter = 1;
Expand Down Expand Up @@ -83,6 +83,9 @@ private static class InterruptableWorkRequestHandler extends WorkRequestHandler
@Override
@SuppressWarnings("SystemExitOutsideMain")
public void processRequests() throws IOException {
ByteArrayOutputStream captured = new ByteArrayOutputStream();
WorkerIO workerIO = new WorkerIO(System.in, System.out, System.err, captured, captured);

while (true) {
WorkRequest request = messageProcessor.readWorkRequest();
if (request == null) {
Expand All @@ -100,12 +103,19 @@ public void processRequests() throws IOException {
if (request.getCancel()) {
respondToCancelRequest(request);
} else {
startResponseThread(request);
startResponseThread(workerIO, request);
}
if (workerOptions.exitAfter > 0 && workUnitCounter > workerOptions.exitAfter) {
System.exit(0);
}
}

try {
// Unwrap the system streams placing the original streams back
workerIO.close();
} catch (Exception e) {
workerIO.getOriginalErrorStream().println(e.getMessage());
}
}
}

Expand Down Expand Up @@ -241,7 +251,7 @@ private static void parseOptionsAndLog(List<String> args) throws Exception {
List<String> outputs = new ArrayList<>();

if (options.writeUUID) {
outputs.add("UUID " + workerUuid);
outputs.add("UUID " + WORKER_UUID);
}

if (options.writeCounter) {
Expand Down
Loading

0 comments on commit e05345d

Please sign in to comment.