Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: DirectPath calls do not duplicate the request metadata #3663

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import io.grpc.Deadline;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.auth.MoreCallCredentials;
import java.io.IOException;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -185,21 +184,8 @@ public GrpcCallContext nullToSelf(ApiCallContext inputContext) {

@Override
public GrpcCallContext withCredentials(Credentials newCredentials) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the one main concern about this is that this is a behavior breaking change if there are uses that manually create a GrpcCallContext and use withCredentials. I don't think there are really any valid uses cases for this, but I will need to check for this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that makes sense.

What's also concerning is, by looking at the CI failures, it seems many tests have been building the channel from InstantiatingGrpcChannelProvider without actually giving it a credentials. And MoreCallCredentials.from() throws exception for null credentials.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those showcase tests seems to be fixable within this PR. But I'm unsure about the downstream compatibility tests and those on Cloud Build (I lack permission to view logs).

Copy link
Contributor

@lqiu96 lqiu96 Feb 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The showcase tests were following this: https://github.com/googleapis/gapic-showcase?tab=readme-ov-file#example-for-java-grpc

Essentially it's a local server that is not expected to do any authentication. I think we just need a way to be able to hit the local server with no credentials provided.

The downstream tests for our partner teams is a bit concerning. I think we'll need to investigate them to see if they're just a testing configuration issue or a logic issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Essentially it's a local server that is not expected to do any authentication. I think we just need a way to be able to hit the local server with no credentials provided.

Would it make sense if we allow null credentials in the case of one-way TLS? I believe there are no valid use cases to go without credentials here when mTLS or DirectPath are used.

The downstream tests for our partner teams is a bit concerning. I think we'll need to investigate them to see if they're just a testing configuration issue or a deeper issue.

Agreed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize a tricky thing we missed earlier.

The API setChannelConfigurator takes an option that gets applied to the channel builder. In various tests (downstream or showcase), ManagedChannelBuilder::usePlainText is given. While users are not supposed to use insecure channels in production, this is reasonable and often convenient in tests when the server is turned on locally. This change is breaking that behavior.

Basically, either we have to change all tests (setting up TLS can be difficult), or we need a way to allow insecure channels w/ credentials, and TlsChannelCredentials.create() doesn't support that. For the second approach, IIUC, it goes back to the need of supporting adding call credentials to a managed channel builder that is possibly configured to build insecure channels. cc: @ejona86

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using InsecureChannelCredentials.create() instead of TlsChannelCredentials would be fine for those specific tests. But the existing sdk API works poorly for that.

You can add CallCredentials per-RPC, either on the stub, in the CallOptions, or with an interceptor (that adds it to the CallOptions).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes unfortunately the existing SDK API doesn't take InsecureChannelCredentials explicilty.

The client SDK has been attaching call credentials to the CallOptions. Now the problem is that DirectPath (via GoogleDefaultChannelCredentials) gets or creates its own call credentials. So currently, this is causing the same request metadata to appear twice. Because we want to adopt bound credentials in DirectPath, it's not ideal to change this GoogleDefaultChannelCredentials to build without call credentials. Nor would it be good for backward compatibility reasons.

So as this PR is trying to do, we want to stop attaching call credentials to CallOptions in all gRPC cases. But then we need to make sure all gRPC cases can get the call credentials from the channel. ManagedChannelBuilder doesn't allow it. This is where we are stuck.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add an interceptor to the channel that attaches CallCreds to the CallOptions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add an interceptor to the channel that attaches CallCreds to the CallOptions.

Ah. I didn't realize this from your first comment. Will try it. Thanks!

Preconditions.checkNotNull(newCredentials);
CallCredentials callCredentials = MoreCallCredentials.from(newCredentials);
return new GrpcCallContext(
channel,
newCredentials,
callOptions.withCallCredentials(callCredentials),
timeout,
streamWaitTimeout,
streamIdleTimeout,
channelAffinity,
extraHeaders,
options,
retrySettings,
retryableCodes,
endpointContext);
// Credentials will be attached to the gRPC transport channels.
return this;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,9 @@ private ManagedChannel createSingleChannel() throws IOException {
}
if (channelCredentials != null) {
// Create the channel using channel credentials created via DCA.
channelCredentials =
CompositeChannelCredentials.create(
channelCredentials, MoreCallCredentials.from(credentials));
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
} else {
// Could not create channel credentials via DCA. In accordance with
Expand All @@ -665,11 +668,18 @@ private ManagedChannel createSingleChannel() throws IOException {
// which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
channelCredentials =
CompositeChannelCredentials.create(channelCredentials, mtlsS2ACallCredentials);
} else {
channelCredentials =
CompositeChannelCredentials.create(
channelCredentials, MoreCallCredentials.from(credentials));
}
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
} else {
// Use default if we cannot initialize channel credentials via DCA or S2A.
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
// Use default TLS credentials if we cannot initialize channel credentials via DCA or S2A.
channelCredentials =
CompositeChannelCredentials.create(
TlsChannelCredentials.create(), MoreCallCredentials.from(credentials));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq, for this flow, shouldn't it be just TlsChannelCredentials.create() for the channelCredentials?

Why does it need to be done via CompositeChannelCredentials?

Copy link
Contributor

@lqiu96 lqiu96 Feb 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait.. This is because we were sending the GoogleCredentials as part of CallOptions. NVM I think this is correct.

builder = Grpc.newChannelBuilderForAddress(serviceAddress, port, channelCredentials);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void testWithCredentials() {
GrpcCallContext emptyContext = GrpcCallContext.createDefault();
assertNull(emptyContext.getCallOptions().getCredentials());
GrpcCallContext context = emptyContext.withCredentials(credentials);
assertNotNull(context.getCallOptions().getCredentials());
assertNull(context.getCallOptions().getCredentials());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@
package com.google.api.gax.grpc;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.mockito.Mockito.verify;

import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
import com.google.api.gax.rpc.EndpointContext;
import com.google.api.gax.rpc.UnauthenticatedException;
import com.google.api.gax.rpc.UnavailableException;
import com.google.auth.Credentials;
import com.google.auth.Retryable;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -310,11 +309,8 @@ void testInvalidUniverseDomain() throws IOException {
CallOptions callOptions = context.getCallOptions();

MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
UnauthenticatedException exception =
assertThrows(
UnauthenticatedException.class, () -> GrpcClientCalls.newCall(descriptor, context));
assertThat(exception.getStatusCode().getCode()).isEqualTo(GrpcStatusCode.Code.UNAUTHENTICATED);
Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions);
assertDoesNotThrow(() -> GrpcClientCalls.newCall(descriptor, context));
Mockito.verify(mockChannel, Mockito.times(1)).newCall(descriptor, callOptions);
}

// This test is when the MDS is unable to return a valid universe domain
Expand All @@ -332,11 +328,7 @@ void testUniverseDomainNotReady_shouldRetry() throws IOException {
CallOptions callOptions = context.getCallOptions();

MethodDescriptor<Color, Money> descriptor = FakeServiceGrpc.METHOD_RECOGNIZE;
UnavailableException exception =
assertThrows(
UnavailableException.class, () -> GrpcClientCalls.newCall(descriptor, context));
assertThat(exception.getStatusCode().getCode()).isEqualTo(GrpcStatusCode.Code.UNAVAILABLE);
Truth.assertThat(exception.isRetryable()).isTrue();
Mockito.verify(mockChannel, Mockito.never()).newCall(descriptor, callOptions);
assertDoesNotThrow(() -> GrpcClientCalls.newCall(descriptor, context));
Mockito.verify(mockChannel, Mockito.times(1)).newCall(descriptor, callOptions);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,11 @@ void testWithPoolSize() throws IOException {
ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1);
executor.shutdown();

Credentials credentials = Mockito.mock(Credentials.class);

TransportChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setCredentials(credentials)
.build()
.withExecutor((Executor) executor)
.withHeaders(Collections.<String, String>emptyMap())
Expand All @@ -234,6 +237,7 @@ void testToBuilder() {
new ArrayList<>();
hardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS);
hardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A);
Credentials credentials = Mockito.mock(Credentials.class);

InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
Expand All @@ -244,6 +248,7 @@ void testToBuilder() {
.setKeepAliveTimeDuration(keepaliveTime)
.setKeepAliveTimeoutDuration(keepaliveTimeout)
.setKeepAliveWithoutCalls(Boolean.TRUE)
.setCredentials(credentials)
.setChannelConfigurator(channelConfigurator)
.setChannelsPerCpu(2.5)
.setDirectPathServiceConfig(directPathServiceConfig)
Expand Down Expand Up @@ -274,6 +279,7 @@ void testWithInterceptorsAndMultipleChannels() throws Exception {

private void testWithInterceptors(int numChannels) throws Exception {
final GrpcInterceptorProvider interceptorProvider = Mockito.mock(GrpcInterceptorProvider.class);
Credentials credentials = Mockito.mock(Credentials.class);

InstantiatingGrpcChannelProvider channelProvider =
InstantiatingGrpcChannelProvider.newBuilder()
Expand All @@ -282,6 +288,7 @@ private void testWithInterceptors(int numChannels) throws Exception {
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.setInterceptorProvider(interceptorProvider)
.setCredentials(credentials)
.build();

Mockito.verify(interceptorProvider, Mockito.never()).getInterceptors();
Expand All @@ -303,6 +310,7 @@ void testChannelConfigurator() throws IOException {

ManagedChannelBuilder<?> swappedBuilder = Mockito.mock(ManagedChannelBuilder.class);
ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class);
Credentials credentials = Mockito.mock(Credentials.class);
Mockito.when(swappedBuilder.build()).thenReturn(fakeChannel);

Mockito.when(channelConfigurator.apply(channelBuilderCaptor.capture()))
Expand All @@ -315,6 +323,7 @@ void testChannelConfigurator() throws IOException {
.setExecutor(Mockito.mock(Executor.class))
.setChannelConfigurator(channelConfigurator)
.setPoolSize(numChannels)
.setCredentials(credentials)
.build()
.getTransportChannel();

Expand Down Expand Up @@ -486,8 +495,11 @@ void testWithIPv6Address() throws IOException {
ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1);
executor.shutdown();

Credentials credentials = Mockito.mock(Credentials.class);

TransportChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setCredentials(credentials)
.build()
.withExecutor((Executor) executor)
.withHeaders(Collections.<String, String>emptyMap())
Expand All @@ -501,6 +513,8 @@ void testWithIPv6Address() throws IOException {
// Test that if ChannelPrimer is provided, it is called during creation
@Test
void testWithPrimeChannel() throws IOException {
Credentials credentials = Mockito.mock(Credentials.class);

// create channelProvider with different pool sizes to verify ChannelPrimer is called the
// correct number of times
for (int poolSize = 1; poolSize < 5; poolSize++) {
Expand All @@ -512,6 +526,7 @@ void testWithPrimeChannel() throws IOException {
.setPoolSize(poolSize)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.setCredentials(credentials)
.setChannelPrimer(mockChannelPrimer)
.build();

Expand Down Expand Up @@ -600,10 +615,14 @@ void testWithCustomDirectPathServiceConfig() {
@Override
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider)
throws IOException, GeneralSecurityException {

Credentials credentials = Mockito.mock(Credentials.class);

InstantiatingGrpcChannelProvider channelProvider =
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setMtlsProvider(provider)
.setCredentials(credentials)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.build();
Expand All @@ -630,9 +649,14 @@ private void createAndCloseTransportChannel(InstantiatingGrpcChannelProvider pro
testLogDirectPathMisconfig_AttemptDirectPathNotSetAndAttemptDirectPathXdsSetViaBuilder_warns()
throws Exception {
FakeLogHandler logHandler = new FakeLogHandler();
Credentials credentials = Mockito.mock(Credentials.class);

InstantiatingGrpcChannelProvider.LOG.addHandler(logHandler);
InstantiatingGrpcChannelProvider provider =
createChannelProviderBuilderForDirectPathLogTests().setAttemptDirectPathXds().build();
createChannelProviderBuilderForDirectPathLogTests()
.setAttemptDirectPathXds()
.setCredentials(credentials)
.build();
createAndCloseTransportChannel(provider);
assertThat(logHandler.getAllMessages())
.contains(
Expand Down Expand Up @@ -672,11 +696,14 @@ void testLogDirectPathMisconfig_shouldNotLogInTheBuilder() {
@Test
void testLogDirectPathMisconfigWrongCredential() throws Exception {
FakeLogHandler logHandler = new FakeLogHandler();
Credentials credentials = Mockito.mock(Credentials.class);

InstantiatingGrpcChannelProvider.LOG.addHandler(logHandler);
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setAttemptDirectPathXds()
.setAttemptDirectPath(true)
.setCredentials(credentials)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.setEndpoint(DEFAULT_ENDPOINT)
Expand All @@ -697,11 +724,14 @@ void testLogDirectPathMisconfigWrongCredential() throws Exception {
@Test
void testLogDirectPathMisconfigNotOnGCE() throws Exception {
FakeLogHandler logHandler = new FakeLogHandler();
Credentials credentials = Mockito.mock(Credentials.class);

InstantiatingGrpcChannelProvider.LOG.addHandler(logHandler);
InstantiatingGrpcChannelProvider provider =
InstantiatingGrpcChannelProvider.newBuilder()
.setAttemptDirectPathXds()
.setAttemptDirectPath(true)
.setCredentials(credentials)
.setAllowNonDefaultServiceAccount(true)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
Expand Down
Loading