Skip to content

Commit

Permalink
feat: Validate the Universe Domain (#2330)
Browse files Browse the repository at this point in the history
* feat: Validate the universe domain

* chore: Merge in from origin/main

* chore: Add comments for ApiCallContext

* chore: Add comments

* chore: Address PR comments

* chore: Merge endpoint context in both transports

* chore: Use @throws for the exceptions

* chore: Provide a default EndpointContext

* chore: Address PR comments

* chore: Update error message

* chore: Address PR comments

* chore: Address PR comments

* chore: Address PR comments
  • Loading branch information
lqiu96 authored Jan 12, 2024
1 parent c3d1142 commit 097bc93
Show file tree
Hide file tree
Showing 23 changed files with 822 additions and 120 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ public static <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
channel = ClientInterceptors.intercept(channel, interceptor);
}

// Validate the Universe Domain prior to the call. Only allow the call to go through
// if the Universe Domain is valid.
grpcContext.validateUniverseDomain();

try (Scope ignored = grpcContext.getTracer().inScope()) {
return channel.newCall(descriptor, callOptions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@
import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeMethodDescriptor;
import com.google.api.gax.rpc.ClientContext;
import com.google.api.gax.rpc.EndpointContext;
import com.google.api.gax.rpc.ResponseObserver;
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StreamController;
import com.google.api.gax.rpc.UnaryCallSettings;
import com.google.api.gax.rpc.UnaryCallable;
import com.google.api.gax.util.FakeLogHandler;
import com.google.auth.Credentials;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -628,10 +630,17 @@ public void testReleasingClientCallCancelEarly() throws IOException {
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));
pool = ChannelPool.create(channelPoolSettings, factory);

EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
Mockito.doNothing()
.when(endpointContext)
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));

ClientContext context =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(pool))
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
.setDefaultCallContext(
GrpcCallContext.of(pool, CallOptions.DEFAULT).withEndpointContext(endpointContext))
.build();
ServerStreamingCallSettings settings =
ServerStreamingCallSettings.<Color, Money>newBuilder().build();
Expand Down Expand Up @@ -680,11 +689,19 @@ public void testDoubleRelease() throws Exception {

pool = ChannelPool.create(channelPoolSettings, factory);

EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
Mockito.doNothing()
.when(endpointContext)
.validateUniverseDomain(
Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));

// Construct a fake callable to use the channel pool
ClientContext context =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(pool))
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
.setDefaultCallContext(
GrpcCallContext.of(pool, CallOptions.DEFAULT)
.withEndpointContext(endpointContext))
.build();

UnaryCallSettings<Color, Money> settings =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import io.grpc.CallOptions;
import io.grpc.ManagedChannel;
import io.grpc.Metadata.Key;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -373,7 +374,7 @@ public void testWithOptions() {
}

@Test
public void testMergeOptions() {
public void testMergeOptions() throws IOException {
GrpcCallContext emptyCallContext = GrpcCallContext.createDefault();
ApiCallContext.Key<String> contextKey1 = ApiCallContext.Key.create("testKey1");
ApiCallContext.Key<String> contextKey2 = ApiCallContext.Key.create("testKey2");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,15 @@
import com.google.api.gax.grpc.testing.FakeServiceImpl;
import com.google.api.gax.grpc.testing.InProcessServer;
import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.ApiCallContext;
import com.google.api.gax.rpc.ClientContext;
import com.google.api.gax.rpc.EndpointContext;
import com.google.api.gax.rpc.InvalidArgumentException;
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StatusCode.Code;
import com.google.api.gax.tracing.SpanName;
import com.google.auth.Credentials;
import com.google.common.collect.ImmutableList;
import com.google.common.truth.Truth;
import com.google.type.Color;
Expand Down Expand Up @@ -74,10 +77,16 @@ public void setUp() throws Exception {
inprocessServer.start();

channel = InProcessChannelBuilder.forName(serverName).directExecutor().usePlaintext().build();
EndpointContext endpointContext = Mockito.mock(EndpointContext.class);
Mockito.doNothing()
.when(endpointContext)
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
clientContext =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(channel))
.setDefaultCallContext(GrpcCallContext.of(channel, CallOptions.DEFAULT))
.setDefaultCallContext(
GrpcCallContext.of(channel, CallOptions.DEFAULT)
.withEndpointContext(endpointContext))
.build();
}

Expand Down Expand Up @@ -106,11 +115,10 @@ public void createServerStreamingCallableRetryableExceptions() {
GrpcCallableFactory.createServerStreamingCallable(
grpcCallSettings, nonRetryableSettings, clientContext);

ApiCallContext defaultCallContext = clientContext.getDefaultCallContext();
Throwable actualError = null;
try {
nonRetryableCallable
.first()
.call(Color.getDefaultInstance(), clientContext.getDefaultCallContext());
nonRetryableCallable.first().call(Color.getDefaultInstance(), defaultCallContext);
} catch (Throwable e) {
actualError = e;
}
Expand All @@ -134,9 +142,7 @@ public void createServerStreamingCallableRetryableExceptions() {

Throwable actualError2 = null;
try {
retryableCallable
.first()
.call(Color.getDefaultInstance(), clientContext.getDefaultCallContext());
retryableCallable.first().call(Color.getDefaultInstance(), defaultCallContext);
} catch (Throwable e) {
actualError2 = e;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,16 @@
package com.google.api.gax.grpc;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.verify;

import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
import com.google.api.gax.rpc.EndpointContext;
import com.google.api.gax.rpc.UnauthenticatedException;
import com.google.api.gax.rpc.UnavailableException;
import com.google.auth.Credentials;
import com.google.auth.Retryable;
import com.google.common.collect.ImmutableList;
import com.google.common.truth.Truth;
import com.google.type.Color;
Expand All @@ -45,18 +51,58 @@
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.threeten.bp.Duration;

public class GrpcClientCallsTest {

// Auth Library's GoogleAuthException is package-private. Copy basic functionality for tests
private static class GoogleAuthException extends IOException implements Retryable {

private final boolean isRetryable;

private GoogleAuthException(boolean isRetryable) {
this.isRetryable = isRetryable;
}

@Override
public boolean isRetryable() {
return isRetryable;
}

@Override
public int getRetryCount() {
return 0;
}
}

private GrpcCallContext defaultCallContext;
private EndpointContext endpointContext;
private Credentials credentials;
private Channel mockChannel;

@Before
public void setUp() throws IOException {
credentials = Mockito.mock(Credentials.class);
endpointContext = Mockito.mock(EndpointContext.class);
mockChannel = Mockito.mock(Channel.class);

defaultCallContext = GrpcCallContext.createDefault().withEndpointContext(endpointContext);
Mockito.doNothing()
.when(endpointContext)
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
}

@Test
public void testAffinity() throws IOException {
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
Expand All @@ -78,7 +124,7 @@ public void testAffinity() throws IOException {
ChannelPool.create(
ChannelPoolSettings.staticallySized(2),
new FakeChannelFactory(Arrays.asList(channel0, channel1)));
GrpcCallContext context = GrpcCallContext.createDefault().withChannel(pool);
GrpcCallContext context = defaultCallContext.withChannel(pool);

ClientCall<Color, Money> gotCallA =
GrpcClientCalls.newCall(descriptor, context.withChannelAffinity(0));
Expand All @@ -92,7 +138,7 @@ public void testAffinity() throws IOException {
}

@Test
public void testExtraHeaders() {
public void testExtraHeaders() throws IOException {
Metadata emptyHeaders = new Metadata();
final Map<String, List<String>> extraHeaders = new HashMap<>();
extraHeaders.put(
Expand Down Expand Up @@ -128,12 +174,12 @@ public void testExtraHeaders() {
.thenReturn(mockClientCall);

GrpcCallContext context =
GrpcCallContext.createDefault().withChannel(mockChannel).withExtraHeaders(extraHeaders);
defaultCallContext.withChannel(mockChannel).withExtraHeaders(extraHeaders);
GrpcClientCalls.newCall(descriptor, context).start(mockListener, emptyHeaders);
}

@Test
public void testTimeoutToDeadlineConversion() {
public void testTimeoutToDeadlineConversion() throws IOException {
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;

@SuppressWarnings("unchecked")
Expand All @@ -152,8 +198,7 @@ public void testTimeoutToDeadlineConversion() {
Duration timeout = Duration.ofSeconds(10);
Deadline minExpectedDeadline = Deadline.after(timeout.getSeconds(), TimeUnit.SECONDS);

GrpcCallContext context =
GrpcCallContext.createDefault().withChannel(mockChannel).withTimeout(timeout);
GrpcCallContext context = defaultCallContext.withChannel(mockChannel).withTimeout(timeout);

GrpcClientCalls.newCall(descriptor, context).start(mockListener, new Metadata());

Expand All @@ -164,7 +209,7 @@ public void testTimeoutToDeadlineConversion() {
}

@Test
public void testTimeoutAfterDeadline() {
public void testTimeoutAfterDeadline() throws IOException {
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;

@SuppressWarnings("unchecked")
Expand All @@ -185,7 +230,7 @@ public void testTimeoutAfterDeadline() {
Duration timeout = Duration.ofSeconds(10);

GrpcCallContext context =
GrpcCallContext.createDefault()
defaultCallContext
.withChannel(mockChannel)
.withCallOptions(CallOptions.DEFAULT.withDeadline(priorDeadline))
.withTimeout(timeout);
Expand All @@ -197,7 +242,7 @@ public void testTimeoutAfterDeadline() {
}

@Test
public void testTimeoutBeforeDeadline() {
public void testTimeoutBeforeDeadline() throws IOException {
MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;

@SuppressWarnings("unchecked")
Expand All @@ -219,7 +264,7 @@ public void testTimeoutBeforeDeadline() {
Deadline minExpectedDeadline = Deadline.after(timeout.getSeconds(), TimeUnit.SECONDS);

GrpcCallContext context =
GrpcCallContext.createDefault()
defaultCallContext
.withChannel(mockChannel)
.withCallOptions(CallOptions.DEFAULT.withDeadline(subsequentDeadline))
.withTimeout(timeout);
Expand All @@ -232,4 +277,66 @@ public void testTimeoutBeforeDeadline() {
Truth.assertThat(capturedCallOptions.getValue().getDeadline()).isAtLeast(minExpectedDeadline);
Truth.assertThat(capturedCallOptions.getValue().getDeadline()).isAtMost(maxExpectedDeadline);
}

@Test
public void testValidUniverseDomain() throws IOException {
GrpcCallContext context =
GrpcCallContext.createDefault()
.withChannel(mockChannel)
.withCredentials(credentials)
.withEndpointContext(endpointContext);

CallOptions callOptions = context.getCallOptions();

MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
GrpcClientCalls.newCall(descriptor, context);
Mockito.verify(mockChannel, Mockito.times(1)).newCall(descriptor, callOptions);
}

// This test is when the universe domain does not match
@Test
public void testInvalidUniverseDomain() throws IOException {
Mockito.doThrow(
new UnauthenticatedException(
null, GrpcStatusCode.of(Status.Code.UNAUTHENTICATED), false))
.when(endpointContext)
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
GrpcCallContext context =
GrpcCallContext.createDefault()
.withChannel(mockChannel)
.withCredentials(credentials)
.withEndpointContext(endpointContext);

CallOptions callOptions = context.getCallOptions();

MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
UnauthenticatedException exception =
assertThrows(
UnauthenticatedException.class, () -> GrpcClientCalls.newCall(descriptor, context));
assertThat(exception.getStatusCode().getCode()).isEqualTo(GrpcStatusCode.Code.UNAUTHENTICATED);
Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions);
}

// This test is when the MDS is unable to return a valid universe domain
@Test
public void testUniverseDomainNotReady_shouldRetry() throws IOException {
Mockito.doThrow(new GoogleAuthException(true))
.when(endpointContext)
.validateUniverseDomain(Mockito.any(Credentials.class), Mockito.any(GrpcStatusCode.class));
GrpcCallContext context =
GrpcCallContext.createDefault()
.withChannel(mockChannel)
.withCredentials(credentials)
.withEndpointContext(endpointContext);

CallOptions callOptions = context.getCallOptions();

MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
UnavailableException exception =
assertThrows(
UnavailableException.class, () -> GrpcClientCalls.newCall(descriptor, context));
assertThat(exception.getStatusCode().getCode()).isEqualTo(GrpcStatusCode.Code.UNAVAILABLE);
Truth.assertThat(exception.isRetryable()).isTrue();
Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions);
}
}
Loading

0 comments on commit 097bc93

Please sign in to comment.