Skip to content

xds: add support for custom per-target credentials on the transport. #11951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
import io.grpc.CallCredentials;
import io.grpc.CallOptions;
import io.grpc.ChannelCredentials;
import io.grpc.ClientCall;
Expand All @@ -34,35 +35,50 @@

final class GrpcXdsTransportFactory implements XdsTransportFactory {

static final GrpcXdsTransportFactory DEFAULT_XDS_TRANSPORT_FACTORY =
new GrpcXdsTransportFactory();
private final CallCredentials callCredentials;

GrpcXdsTransportFactory(CallCredentials callCredentials) {
this.callCredentials = callCredentials;
}

@Override
public XdsTransport create(Bootstrapper.ServerInfo serverInfo) {
return new GrpcXdsTransport(serverInfo);
return new GrpcXdsTransport(serverInfo, callCredentials);
}

@VisibleForTesting
public XdsTransport createForTest(ManagedChannel channel) {
return new GrpcXdsTransport(channel);
return new GrpcXdsTransport(channel, callCredentials);
}

@VisibleForTesting
static class GrpcXdsTransport implements XdsTransport {

private final ManagedChannel channel;
private final CallCredentials callCredentials;

public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) {
this(serverInfo, null);
}

Check warning on line 62 in xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java#L61-L62

Added lines #L61 - L62 were not covered by tests

@VisibleForTesting
public GrpcXdsTransport(ManagedChannel channel) {
this(channel, null);
}

public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) {
String target = serverInfo.target();
ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig();
this.channel = Grpc.newChannelBuilder(target, channelCredentials)
.keepAliveTime(5, TimeUnit.MINUTES)
.build();
this.callCredentials = callCredentials;
}

@VisibleForTesting
public GrpcXdsTransport(ManagedChannel channel) {
public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) {
this.channel = checkNotNull(channel, "channel");
this.callCredentials = callCredentials;
}

@Override
Expand All @@ -72,7 +88,8 @@
MethodDescriptor.Marshaller<RespT> respMarshaller) {
Context prevContext = Context.ROOT.attach();
try {
return new XdsStreamingCall<>(fullMethodName, reqMarshaller, respMarshaller);
return new XdsStreamingCall<>(
fullMethodName, reqMarshaller, respMarshaller, callCredentials);
} finally {
Context.ROOT.detach(prevContext);
}
Expand All @@ -89,16 +106,21 @@

private final ClientCall<ReqT, RespT> call;

public XdsStreamingCall(String methodName, MethodDescriptor.Marshaller<ReqT> reqMarshaller,
MethodDescriptor.Marshaller<RespT> respMarshaller) {
this.call = channel.newCall(
MethodDescriptor.<ReqT, RespT>newBuilder()
.setFullMethodName(methodName)
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.setRequestMarshaller(reqMarshaller)
.setResponseMarshaller(respMarshaller)
.build(),
CallOptions.DEFAULT); // TODO(zivy): support waitForReady
public XdsStreamingCall(
String methodName,
MethodDescriptor.Marshaller<ReqT> reqMarshaller,
MethodDescriptor.Marshaller<RespT> respMarshaller,
CallCredentials callCredentials) {
this.call =
channel.newCall(
MethodDescriptor.<ReqT, RespT>newBuilder()
.setFullMethodName(methodName)
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.setRequestMarshaller(reqMarshaller)
.setResponseMarshaller(respMarshaller)
.build(),
CallOptions.DEFAULT.withCallCredentials(
callCredentials)); // TODO(zivy): support waitForReady
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.grpc.xds;

import io.grpc.CallCredentials;
import io.grpc.Internal;
import io.grpc.MetricRecorder;
import io.grpc.internal.ObjectPool;
Expand All @@ -42,6 +43,13 @@

public static ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRecorder)
throws XdsInitializationException {
return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(target, metricRecorder);
return getOrCreate(target, metricRecorder, null);

Check warning on line 46 in xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java#L46

Added line #L46 was not covered by tests
}

public static ObjectPool<XdsClient> getOrCreate(
String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials)
throws XdsInitializationException {
return SharedXdsClientPoolProvider.getDefaultProvider()
.getOrCreate(target, metricRecorder, transportCallCredentials);

Check warning on line 53 in xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java

View check run for this annotation

Codecov / codecov/patch

xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java#L52-L53

Added lines #L52 - L53 were not covered by tests
}
}
50 changes: 36 additions & 14 deletions xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
package io.grpc.xds;

import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.grpc.CallCredentials;
import io.grpc.MetricRecorder;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.internal.GrpcUtil;
Expand Down Expand Up @@ -87,6 +87,12 @@ public ObjectPool<XdsClient> get(String target) {
@Override
public ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRecorder)
throws XdsInitializationException {
return getOrCreate(target, metricRecorder, null);
}

public ObjectPool<XdsClient> getOrCreate(
String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials)
throws XdsInitializationException {
ObjectPool<XdsClient> ref = targetToXdsClientMap.get(target);
if (ref == null) {
synchronized (lock) {
Expand All @@ -102,7 +108,9 @@ public ObjectPool<XdsClient> getOrCreate(String target, MetricRecorder metricRec
if (bootstrapInfo.servers().isEmpty()) {
throw new XdsInitializationException("No xDS server provided");
}
ref = new RefCountedXdsClientObjectPool(bootstrapInfo, target, metricRecorder);
ref =
new RefCountedXdsClientObjectPool(
bootstrapInfo, target, metricRecorder, transportCallCredentials);
targetToXdsClientMap.put(target, ref);
}
}
Expand All @@ -126,6 +134,7 @@ class RefCountedXdsClientObjectPool implements ObjectPool<XdsClient> {
private final BootstrapInfo bootstrapInfo;
private final String target; // The target associated with the xDS client.
private final MetricRecorder metricRecorder;
private final CallCredentials transportCallCredentials;
private final Object lock = new Object();
@GuardedBy("lock")
private ScheduledExecutorService scheduler;
Expand All @@ -137,11 +146,21 @@ class RefCountedXdsClientObjectPool implements ObjectPool<XdsClient> {
private XdsClientMetricReporterImpl metricReporter;

@VisibleForTesting
RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo, String target,
MetricRecorder metricRecorder) {
RefCountedXdsClientObjectPool(
BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) {
this(bootstrapInfo, target, metricRecorder, null);
}

@VisibleForTesting
RefCountedXdsClientObjectPool(
BootstrapInfo bootstrapInfo,
String target,
MetricRecorder metricRecorder,
CallCredentials transportCallCredentials) {
this.bootstrapInfo = checkNotNull(bootstrapInfo);
this.target = target;
this.metricRecorder = metricRecorder;
this.transportCallCredentials = transportCallCredentials;
}

@Override
Expand All @@ -153,16 +172,19 @@ public XdsClient getObject() {
}
scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE);
metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target);
xdsClient = new XdsClientImpl(
DEFAULT_XDS_TRANSPORT_FACTORY,
bootstrapInfo,
scheduler,
BACKOFF_POLICY_PROVIDER,
GrpcUtil.STOPWATCH_SUPPLIER,
TimeProvider.SYSTEM_TIME_PROVIDER,
MessagePrinter.INSTANCE,
new TlsContextManagerImpl(bootstrapInfo),
metricReporter);
GrpcXdsTransportFactory xdsTransportFactory =
new GrpcXdsTransportFactory(transportCallCredentials);
xdsClient =
new XdsClientImpl(
xdsTransportFactory,
bootstrapInfo,
scheduler,
BACKOFF_POLICY_PROVIDER,
GrpcUtil.STOPWATCH_SUPPLIER,
TimeProvider.SYSTEM_TIME_PROVIDER,
MessagePrinter.INSTANCE,
new TlsContextManagerImpl(bootstrapInfo),
metricReporter);
metricReporter.setXdsClient(xdsClient);
}
refCount++;
Expand Down
3 changes: 1 addition & 2 deletions xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;
import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
Expand Down Expand Up @@ -4193,7 +4192,7 @@ public void serverFailureMetricReport_forRetryAndBackoff() {
private XdsClientImpl createXdsClient(String serverUri) {
BootstrapInfo bootstrapInfo = buildBootStrap(serverUri);
return new XdsClientImpl(
DEFAULT_XDS_TRANSPORT_FACTORY,
new GrpcXdsTransportFactory(null),
bootstrapInfo,
fakeClock.getScheduledExecutorService(),
backoffPolicyProvider,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ public void onCompleted() {
@Test
public void callApis() throws Exception {
XdsTransportFactory.XdsTransport xdsTransport =
GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.create(
Bootstrapper.ServerInfo.create("localhost:" + server.getPort(),
InsecureChannelCredentials.create()));
new GrpcXdsTransportFactory(null)
.create(
Bootstrapper.ServerInfo.create(
"localhost:" + server.getPort(), InsecureChannelCredentials.create()));
MethodDescriptor<DiscoveryRequest, DiscoveryResponse> methodDescriptor =
AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod();
XdsTransportFactory.StreamingCall<DiscoveryRequest, DiscoveryResponse> streamingCall =
Expand Down
14 changes: 9 additions & 5 deletions xds/src/test/java/io/grpc/xds/LoadReportClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,15 @@ public void cancelled(Context context) {
when(backoffPolicy2.nextBackoffNanos())
.thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L));
addFakeStatsData();
lrsClient = new LoadReportClient(loadStatsManager,
GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.createForTest(channel),
NODE,
syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider,
fakeClock.getStopwatchSupplier());
lrsClient =
new LoadReportClient(
loadStatsManager,
new GrpcXdsTransportFactory(null).createForTest(channel),
NODE,
syncContext,
fakeClock.getScheduledExecutorService(),
backoffPolicyProvider,
fakeClock.getStopwatchSupplier());
syncContext.execute(new Runnable() {
@Override
public void run() {
Expand Down
77 changes: 77 additions & 0 deletions xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,36 @@


import static com.google.common.truth.Truth.assertThat;
import static io.grpc.Metadata.ASCII_STRING_MARSHALLER;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.OAuth2Credentials;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.InsecureServerCredentials;
import io.grpc.Metadata;
import io.grpc.MetricRecorder;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.ObjectPool;
import io.grpc.xds.SharedXdsClientPoolProvider.RefCountedXdsClientObjectPool;
import io.grpc.xds.XdsListenerResource.LdsUpdate;
import io.grpc.xds.client.Bootstrapper.BootstrapInfo;
import io.grpc.xds.client.Bootstrapper.ServerInfo;
import io.grpc.xds.client.EnvoyProtoData.Node;
import io.grpc.xds.client.XdsClient;
import io.grpc.xds.client.XdsClient.ResourceWatcher;
import io.grpc.xds.client.XdsInitializationException;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
Expand All @@ -54,9 +70,12 @@ public class SharedXdsClientPoolProviderTest {
private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build();
private final MetricRecorder metricRecorder = new MetricRecorder() {};
private static final String DUMMY_TARGET = "dummy";
static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY =
Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER);

@Mock
private GrpcBootstrapperImpl bootstrapper;
@Mock private ResourceWatcher<LdsUpdate> ldsResourceWatcher;

@Test
public void noServer() throws XdsInitializationException {
Expand Down Expand Up @@ -138,4 +157,62 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh
assertThat(xdsClient2).isNotSameInstanceAs(xdsClient1);
xdsClientPool.returnObject(xdsClient2);
}

private class CallCredsServerInterceptor implements ServerInterceptor {
private SettableFuture<String> tokenFuture = SettableFuture.create();

@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> serverCall,
Metadata metadata,
ServerCallHandler<ReqT, RespT> next) {
tokenFuture.set(metadata.get(AUTHORIZATION_METADATA_KEY));
return next.startCall(serverCall, metadata);
}

public String getTokenWithTimeout(long timeout, TimeUnit unit) throws Exception {
return tokenFuture.get(timeout, unit);
}
}

@Test
public void xdsClient_usesCallCredentials() throws Exception {
// Set up fake xDS server
XdsTestControlPlaneService fakeXdsService = new XdsTestControlPlaneService();
CallCredsServerInterceptor callCredentialsInterceptor = new CallCredsServerInterceptor();
Server xdsServer =
Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())
.addService(fakeXdsService)
.intercept(callCredentialsInterceptor)
.build()
.start();
String xdsServerUri = "localhost:" + xdsServer.getPort();

// Set up bootstrap & xDS client pool provider
ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create());
BootstrapInfo bootstrapInfo =
BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build();
when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo);
SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper);

// Create custom xDS transport CallCredentials
CallCredentials sampleCreds =
MoreCallCredentials.from(
OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null)));

// Create xDS client that uses the CallCredentials on the transport
ObjectPool<XdsClient> xdsClientPool =
provider.getOrCreate("target", metricRecorder, sampleCreds);
XdsClient xdsClient = xdsClientPool.getObject();
xdsClient.watchXdsResource(
XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher);

// Wait for xDS server to get the request and verify that it received the CallCredentials
assertThat(callCredentialsInterceptor.getTokenWithTimeout(5, TimeUnit.SECONDS))
.isEqualTo("Bearer token");

// Clean up
xdsClientPool.returnObject(xdsClient);
xdsServer.shutdownNow();
}
}
Loading