Skip to content

Commit

Permalink
Implement retries for listPlugins request to allow language servers t…
Browse files Browse the repository at this point in the history
…o ramp up.

PiperOrigin-RevId: 605588300
Change-Id: I8fc52f758d13f53d225f185085af34375b888e73
  • Loading branch information
nttran8 authored and copybara-github committed Feb 9, 2024
1 parent e647c37 commit 84f2f50
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 28 deletions.
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ subprojects {
errorproneVersion = '2.4.0'
floggerVersion = '0.5.1'
googleCloudStorageVersion = '1.103.1'
googleHttpClientVersion = '1.44.1'
guavaVersion = '28.2-jre'
guiceVersion = '4.2.3'
grpcVersion = '1.60.0'
Expand Down Expand Up @@ -60,6 +61,7 @@ subprojects {
flogger_google_ext: "com.google.flogger:google-extensions:${floggerVersion}",
flogger_backend: "com.google.flogger:flogger-system-backend:${floggerVersion}",
google_cloud_storage: "com.google.cloud:google-cloud-storage:${googleCloudStorageVersion}",
google_http_client: "com.google.http-client:google-http-client:${googleHttpClientVersion}",
guava: "com.google.guava:guava:${guavaVersion}",
guice: "com.google.inject:guice:${guiceVersion}",
guice_assisted: "com.google.inject.extensions:guice-assistedinject:${guiceVersion}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
import static com.google.common.base.Preconditions.checkNotNull;
import static java.util.concurrent.TimeUnit.SECONDS;

import com.google.api.client.util.BackOff;
import com.google.api.client.util.ExponentialBackOff;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import com.google.common.flogger.GoogleLogger;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.tsunami.proto.DetectionReportList;
import com.google.tsunami.proto.ListPluginsRequest;
import com.google.tsunami.proto.MatchedPlugin;
Expand All @@ -32,16 +35,29 @@
import io.grpc.Deadline;
import io.grpc.health.v1.HealthCheckRequest;
import io.grpc.health.v1.HealthCheckResponse;
import java.io.IOException;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

final class RemoteVulnDetectorImpl implements RemoteVulnDetector {
/** Facilitates communication with remote detectors. */
public final class RemoteVulnDetectorImpl implements RemoteVulnDetector {
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
// Default duration deadline for all RPC calls
// Remote detectors, especially ones using the callback server, require additional buffer to send
// requests and responses.
private static final Deadline DEFAULT_DEADLINE = Deadline.after(150, SECONDS);

private static final int INITIAL_WAIT_TIME_MS = 200;
private static final int MAX_WAIT_TIME_MS = 30000;
private static final int WAIT_TIME_MULTIPLIER = 3;
private static final int MAX_ATTEMPTS = 3;
private final ExponentialBackOff backoff =
new ExponentialBackOff.Builder()
.setInitialIntervalMillis(INITIAL_WAIT_TIME_MS)
.setRandomizationFactor(0.1)
.setMultiplier(WAIT_TIME_MULTIPLIER)
.setMaxElapsedTimeMillis(MAX_WAIT_TIME_MS)
.build();
private final PluginServiceClient service;
private final Set<MatchedPlugin> pluginsToRun;

Expand All @@ -54,21 +70,14 @@ final class RemoteVulnDetectorImpl implements RemoteVulnDetector {
public DetectionReportList detect(
TargetInfo target, ImmutableList<NetworkService> matchedServices) {
try {
if (service
.checkHealthWithDeadline(HealthCheckRequest.getDefaultInstance(), DEFAULT_DEADLINE)
.get()
.getStatus()
.equals(HealthCheckResponse.ServingStatus.SERVING)) {
if (checkHealthWithBackoffs()) {
logger.atInfo().log("Detecting with language server plugins...");
return service
.runWithDeadline(
RunRequest.newBuilder().setTarget(target).addAllPlugins(pluginsToRun).build(),
DEFAULT_DEADLINE)
.get()
.getReports();
} else {
logger.atWarning().log(
"Server health status is not SERVING. Will not run matched plugins.");
}
} catch (InterruptedException | ExecutionException e) {
throw new LanguageServerException("Failed to get response from language server.", e);
Expand All @@ -79,24 +88,57 @@ public DetectionReportList detect(
@Override
public ImmutableList<PluginDefinition> getAllPlugins() {
try {
if (service
.checkHealthWithDeadline(HealthCheckRequest.getDefaultInstance(), DEFAULT_DEADLINE)
.get()
.getStatus()
.equals(HealthCheckResponse.ServingStatus.SERVING)) {
if (checkHealthWithBackoffs()) {
logger.atInfo().log("Getting language server plugins...");
return ImmutableList.copyOf(
service
.listPluginsWithDeadline(ListPluginsRequest.getDefaultInstance(), DEFAULT_DEADLINE)
.get()
.getPluginsList());
} else {
logger.atWarning().log("Server health status is not SERVING. Will not retrieve plugins.");
return ImmutableList.of();
}
} catch (InterruptedException | ExecutionException e) {
throw new LanguageServerException("Failed to get plugins from language server.", e);
throw new LanguageServerException("Failed to get response from language server.", e);
}
}

private boolean checkHealthWithBackoffs() {
// After starting the language server, this is our first attempt to establish a connection
// between the Java and the language server.
// Sometimes the language server may need longer time to ramp up its health service, so we need
// to implement exponential retries to manage those circumstances.
backoff.reset();
int attempt = 0;
while (attempt < MAX_ATTEMPTS) {
try {
var healthy =
service
.checkHealthWithDeadline(HealthCheckRequest.getDefaultInstance(), DEFAULT_DEADLINE)
.get()
.getStatus()
.equals(HealthCheckResponse.ServingStatus.SERVING);
if (!healthy) {
logger.atWarning().log("Language server is not serving.");
}
return healthy;
} catch (InterruptedException | ExecutionException e) {
attempt++;
try {
long backOffMillis = backoff.nextBackOffMillis();
if (backOffMillis != BackOff.STOP) {
Uninterruptibles.sleepUninterruptibly(backOffMillis, TimeUnit.MILLISECONDS);
}
} catch (IOException ioe) {
// ignore
logger.atWarning().log("Failed to sleep for %s", ioe.getCause().getMessage());
}
if (attempt == MAX_ATTEMPTS) {
throw new LanguageServerException("Language service is not registered.", e.getCause());
}
}
}
return ImmutableList.of();
return false;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void run(RunRequest request, StreamObserver<RunResponse> responseObserver
public void detect_withServingServer_returnsSuccessfulDetectionReportList() throws Exception {
registerHealthCheckWithStatus(ServingStatus.SERVING);
registerSuccessfulRunService();

RemoteVulnDetector pluginToTest = getNewRemoteVulnDetectorInstance();
var endpointToTest = NetworkEndpointUtils.forIpAndPort("1.1.1.1", 80);
var serviceToTest =
Expand All @@ -93,7 +93,7 @@ public void detect_withServingServer_returnsSuccessfulDetectionReportList() thro
.setTransportProtocol(TransportProtocol.TCP)
.setServiceName("http")
.build();

TargetInfo testTarget = TargetInfo.newBuilder().addNetworkEndpoints(endpointToTest).build();
pluginToTest.addMatchedPluginToDetect(
MatchedPlugin.newBuilder()
Expand All @@ -109,10 +109,10 @@ public void detect_withServingServer_returnsSuccessfulDetectionReportList() thro
.build());
}

@Test
@Test(timeout = 11000L)
public void detect_withNonServingServer_returnsEmptyDetectionReportList() throws Exception {
registerHealthCheckWithStatus(ServingStatus.NOT_SERVING);

RemoteVulnDetector pluginToTest = getNewRemoteVulnDetectorInstance();
var endpointToTest = NetworkEndpointUtils.forIpAndPort("1.1.1.1", 80);
var serviceToTest =
Expand All @@ -121,7 +121,7 @@ public void detect_withNonServingServer_returnsEmptyDetectionReportList() throws
.setTransportProtocol(TransportProtocol.TCP)
.setServiceName("http")
.build();

TargetInfo testTarget = TargetInfo.newBuilder().addNetworkEndpoints(endpointToTest).build();
pluginToTest.addMatchedPluginToDetect(
MatchedPlugin.newBuilder()
Expand All @@ -132,7 +132,7 @@ public void detect_withNonServingServer_returnsEmptyDetectionReportList() throws
.isEmpty();
}

@Test
@Test(timeout = 32000L)
public void detect_withRpcError_throwsLanguageServerException() throws Exception {
registerHealthCheckWithError();

Expand All @@ -146,7 +146,7 @@ public void detect_withRpcError_throwsLanguageServerException() throws Exception
@Test
public void getAllPlugins_withServingServer_returnsSuccessfulList() throws Exception {
registerHealthCheckWithStatus(ServingStatus.SERVING);

var plugin = createSinglePluginDefinitionWithName("test");
RemoteVulnDetector pluginToTest = getNewRemoteVulnDetectorInstance();
serviceRegistry.addService(
Expand All @@ -158,22 +158,28 @@ public void listPlugins(
responseObserver.onCompleted();
}
});

assertThat(pluginToTest.getAllPlugins()).containsExactly(plugin);
}

@Test
@Test(timeout = 32000L)
public void getAllPlugins_withNonServingServer_returnsEmptyList() throws Exception {
registerHealthCheckWithStatus(ServingStatus.NOT_SERVING);
assertThat(getNewRemoteVulnDetectorInstance().getAllPlugins()).isEmpty();
}

@Test
@Test(timeout = 32000L)
public void getAllPlugins_withRpcError_throwsLanguageServerException() throws Exception {
registerHealthCheckWithError();
assertThrows(LanguageServerException.class, getNewRemoteVulnDetectorInstance()::getAllPlugins);
}

@Test(timeout = 32000L)
public void getAllPlugins_withUnregisteredHealthService_throwsLanguageServerException()
throws Exception {
assertThrows(LanguageServerException.class, getNewRemoteVulnDetectorInstance()::getAllPlugins);
}

private RemoteVulnDetector getNewRemoteVulnDetectorInstance() throws Exception {
String serverName = InProcessServerBuilder.generateName();
grpcCleanup.register(
Expand Down

0 comments on commit 84f2f50

Please sign in to comment.