Skip to content

Commit c167ead

Browse files
xds: Per-rpc rewriting of the authority header based on the selected route. (#11631)
Implementation of A81.
1 parent 3562380 commit c167ead

28 files changed

+875
-309
lines changed

api/src/main/java/io/grpc/LoadBalancer.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,7 @@ public static final class PickResult {
552552
private final Status status;
553553
// True if the result is created by withDrop()
554554
private final boolean drop;
555+
@Nullable private final String authorityOverride;
555556

556557
private PickResult(
557558
@Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory,
@@ -560,6 +561,17 @@ private PickResult(
560561
this.streamTracerFactory = streamTracerFactory;
561562
this.status = checkNotNull(status, "status");
562563
this.drop = drop;
564+
this.authorityOverride = null;
565+
}
566+
567+
private PickResult(
568+
@Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory,
569+
Status status, boolean drop, @Nullable String authorityOverride) {
570+
this.subchannel = subchannel;
571+
this.streamTracerFactory = streamTracerFactory;
572+
this.status = checkNotNull(status, "status");
573+
this.drop = drop;
574+
this.authorityOverride = authorityOverride;
563575
}
564576

565577
/**
@@ -639,6 +651,19 @@ public static PickResult withSubchannel(
639651
false);
640652
}
641653

654+
/**
655+
* Same as {@code withSubchannel(subchannel, streamTracerFactory)} but with an authority name
656+
* to override in the host header.
657+
*/
658+
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11656")
659+
public static PickResult withSubchannel(
660+
Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory,
661+
@Nullable String authorityOverride) {
662+
return new PickResult(
663+
checkNotNull(subchannel, "subchannel"), streamTracerFactory, Status.OK,
664+
false, authorityOverride);
665+
}
666+
642667
/**
643668
* Equivalent to {@code withSubchannel(subchannel, null)}.
644669
*
@@ -682,6 +707,13 @@ public static PickResult withNoResult() {
682707
return NO_RESULT;
683708
}
684709

710+
/** Returns the authority override if any. */
711+
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11656")
712+
@Nullable
713+
public String getAuthorityOverride() {
714+
return authorityOverride;
715+
}
716+
685717
/**
686718
* The Subchannel if this result was created by {@link #withSubchannel withSubchannel()}, or
687719
* null otherwise.
@@ -736,6 +768,7 @@ public String toString() {
736768
.add("streamTracerFactory", streamTracerFactory)
737769
.add("status", status)
738770
.add("drop", drop)
771+
.add("authority-override", authorityOverride)
739772
.toString();
740773
}
741774

core/src/main/java/io/grpc/internal/DelayedClientTransport.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,17 @@ public final ClientStream newStream(
131131
}
132132
if (state.lastPicker != null) {
133133
PickResult pickResult = state.lastPicker.pickSubchannel(args);
134+
callOptions = args.getCallOptions();
135+
// User code provided authority takes precedence over the LB provided one.
136+
if (callOptions.getAuthority() == null
137+
&& pickResult.getAuthorityOverride() != null) {
138+
callOptions = callOptions.withAuthority(pickResult.getAuthorityOverride());
139+
}
134140
ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult,
135141
callOptions.isWaitForReady());
136142
if (transport != null) {
137143
return transport.newStream(
138-
args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(),
144+
args.getMethodDescriptor(), args.getHeaders(), callOptions,
139145
tracers);
140146
}
141147
}
@@ -281,6 +287,10 @@ final void reprocess(@Nullable SubchannelPicker picker) {
281287
for (final PendingStream stream : toProcess) {
282288
PickResult pickResult = picker.pickSubchannel(stream.args);
283289
CallOptions callOptions = stream.args.getCallOptions();
290+
// User code provided authority takes precedence over the LB provided one.
291+
if (callOptions.getAuthority() == null && pickResult.getAuthorityOverride() != null) {
292+
stream.setAuthority(pickResult.getAuthorityOverride());
293+
}
284294
final ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult,
285295
callOptions.isWaitForReady());
286296
if (transport != null) {

core/src/main/java/io/grpc/internal/DelayedStream.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ private void delayOrExecute(Runnable runnable) {
208208

209209
@Override
210210
public void setAuthority(final String authority) {
211-
checkState(listener == null, "May only be called before start");
212211
checkNotNull(authority, "authority");
213212
preStartPendingCalls.add(new Runnable() {
214213
@Override

core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,43 @@ public void uncaughtException(Thread t, Throwable e) {
502502
verify(transportListener).transportTerminated();
503503
}
504504

505+
@Test
506+
public void reprocess_authorityOverridePresentInCallOptions_authorityOverrideFromLbIsIgnored() {
507+
DelayedStream delayedStream = (DelayedStream) delayedTransport.newStream(
508+
method, headers, callOptions, tracers);
509+
delayedStream.start(mock(ClientStreamListener.class));
510+
SubchannelPicker picker = mock(SubchannelPicker.class);
511+
PickResult pickResult = PickResult.withSubchannel(
512+
mockSubchannel, null, "authority-override-hostname-from-lb");
513+
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
514+
515+
delayedTransport.reprocess(picker);
516+
fakeExecutor.runDueTasks();
517+
518+
verify(mockRealStream, never()).setAuthority("authority-override-hostname-from-lb");
519+
}
520+
521+
@Test
522+
public void
523+
reprocess_authorityOverrideNotInCallOptions_authorityOverrideFromLbIsSetIntoStream() {
524+
DelayedStream delayedStream = (DelayedStream) delayedTransport.newStream(
525+
method, headers, callOptions.withAuthority(null), tracers);
526+
delayedStream.start(mock(ClientStreamListener.class));
527+
SubchannelPicker picker = mock(SubchannelPicker.class);
528+
PickResult pickResult = PickResult.withSubchannel(
529+
mockSubchannel, null, "authority-override-hostname-from-lb");
530+
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
531+
when(mockRealTransport.newStream(
532+
same(method), same(headers), any(CallOptions.class),
533+
ArgumentMatchers.any()))
534+
.thenReturn(mockRealStream);
535+
536+
delayedTransport.reprocess(picker);
537+
fakeExecutor.runDueTasks();
538+
539+
verify(mockRealStream).setAuthority("authority-override-hostname-from-lb");
540+
}
541+
505542
@Test
506543
public void reprocess_NoPendingStream() {
507544
SubchannelPicker picker = mock(SubchannelPicker.class);
@@ -525,6 +562,55 @@ public void reprocess_NoPendingStream() {
525562
assertSame(mockRealStream, stream);
526563
}
527564

565+
@Test
566+
public void newStream_assignsTransport_authorityFromCallOptionsSupersedesAuthorityFromLB() {
567+
SubchannelPicker picker = mock(SubchannelPicker.class);
568+
AbstractSubchannel subchannel = mock(AbstractSubchannel.class);
569+
when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel);
570+
PickResult pickResult = PickResult.withSubchannel(
571+
subchannel, null, "authority-override-hostname-from-lb");
572+
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
573+
ArgumentCaptor<CallOptions> callOptionsArgumentCaptor =
574+
ArgumentCaptor.forClass(CallOptions.class);
575+
when(mockRealTransport.newStream(
576+
any(MethodDescriptor.class), any(Metadata.class), callOptionsArgumentCaptor.capture(),
577+
ArgumentMatchers.<ClientStreamTracer[]>any()))
578+
.thenReturn(mockRealStream);
579+
delayedTransport.reprocess(picker);
580+
verifyNoMoreInteractions(picker);
581+
verifyNoMoreInteractions(transportListener);
582+
583+
CallOptions callOptions =
584+
CallOptions.DEFAULT.withAuthority("authority-override-hosstname-from-calloptions");
585+
delayedTransport.newStream(method, headers, callOptions, tracers);
586+
assertThat(callOptionsArgumentCaptor.getValue().getAuthority()).isEqualTo(
587+
"authority-override-hosstname-from-calloptions");
588+
}
589+
590+
@Test
591+
public void newStream_assignsTransport_authorityFromLB() {
592+
SubchannelPicker picker = mock(SubchannelPicker.class);
593+
AbstractSubchannel subchannel = mock(AbstractSubchannel.class);
594+
when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel);
595+
PickResult pickResult = PickResult.withSubchannel(
596+
subchannel, null, "authority-override-hostname-from-lb");
597+
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
598+
ArgumentCaptor<CallOptions> callOptionsArgumentCaptor =
599+
ArgumentCaptor.forClass(CallOptions.class);
600+
when(mockRealTransport.newStream(
601+
any(MethodDescriptor.class), any(Metadata.class), callOptionsArgumentCaptor.capture(),
602+
ArgumentMatchers.<ClientStreamTracer[]>any()))
603+
.thenReturn(mockRealStream);
604+
delayedTransport.reprocess(picker);
605+
verifyNoMoreInteractions(picker);
606+
verifyNoMoreInteractions(transportListener);
607+
608+
CallOptions callOptions = CallOptions.DEFAULT;
609+
delayedTransport.newStream(method, headers, callOptions, tracers);
610+
assertThat(callOptionsArgumentCaptor.getValue().getAuthority()).isEqualTo(
611+
"authority-override-hostname-from-lb");
612+
}
613+
528614
@Test
529615
public void reprocess_newStreamRacesWithReprocess() throws Exception {
530616
final CyclicBarrier barrier = new CyclicBarrier(2);

core/src/test/java/io/grpc/internal/DelayedStreamTest.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,6 @@ public void setStream_setAuthority() {
8484
inOrder.verify(realStream).start(any(ClientStreamListener.class));
8585
}
8686

87-
@Test(expected = IllegalStateException.class)
88-
public void setAuthority_afterStart() {
89-
stream.start(listener);
90-
stream.setAuthority("notgonnawork");
91-
}
92-
9387
@Test(expected = IllegalStateException.class)
9488
public void start_afterStart() {
9589
stream.start(listener);

xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import io.grpc.Metadata;
3535
import io.grpc.Status;
3636
import io.grpc.internal.ForwardingClientStreamTracer;
37+
import io.grpc.internal.GrpcUtil;
3738
import io.grpc.internal.ObjectPool;
3839
import io.grpc.services.MetricReport;
3940
import io.grpc.util.ForwardingLoadBalancerHelper;
@@ -231,10 +232,16 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) {
231232
args.getAddresses().get(0).getAttributes());
232233
AtomicReference<ClusterLocality> localityAtomicReference = new AtomicReference<>(
233234
clusterLocality);
234-
Attributes attrs = args.getAttributes().toBuilder()
235-
.set(ATTR_CLUSTER_LOCALITY, localityAtomicReference)
236-
.build();
237-
args = args.toBuilder().setAddresses(addresses).setAttributes(attrs).build();
235+
Attributes.Builder attrsBuilder = args.getAttributes().toBuilder()
236+
.set(ATTR_CLUSTER_LOCALITY, localityAtomicReference);
237+
if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)) {
238+
String hostname = args.getAddresses().get(0).getAttributes()
239+
.get(InternalXdsAttributes.ATTR_ADDRESS_NAME);
240+
if (hostname != null) {
241+
attrsBuilder.set(InternalXdsAttributes.ATTR_ADDRESS_NAME, hostname);
242+
}
243+
}
244+
args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build();
238245
final Subchannel subchannel = delegate().createSubchannel(args);
239246

240247
return new ForwardingSubchannel() {
@@ -389,7 +396,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
389396
Status.UNAVAILABLE.withDescription("Dropped: " + dropOverload.category()));
390397
}
391398
}
392-
final PickResult result = delegate.pickSubchannel(args);
399+
PickResult result = delegate.pickSubchannel(args);
393400
if (result.getStatus().isOk() && result.getSubchannel() != null) {
394401
if (enableCircuitBreaking) {
395402
if (inFlights.get() >= maxConcurrentRequests) {
@@ -415,9 +422,17 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
415422
stats, inFlights, result.getStreamTracerFactory());
416423
ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance()
417424
.newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats));
418-
return PickResult.withSubchannel(result.getSubchannel(), orcaTracerFactory);
425+
result = PickResult.withSubchannel(result.getSubchannel(),
426+
orcaTracerFactory);
419427
}
420428
}
429+
if (args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY) != null
430+
&& args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY)) {
431+
result = PickResult.withSubchannel(result.getSubchannel(),
432+
result.getStreamTracerFactory(),
433+
result.getSubchannel().getAttributes().get(
434+
InternalXdsAttributes.ATTR_ADDRESS_NAME));
435+
}
421436
}
422437
return result;
423438
}

xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ public void run() {
428428
.set(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT,
429429
localityLbInfo.localityWeight())
430430
.set(InternalXdsAttributes.ATTR_SERVER_WEIGHT, weight)
431+
.set(InternalXdsAttributes.ATTR_ADDRESS_NAME, endpoint.hostname())
431432
.build();
432433
EquivalentAddressGroup eag = new EquivalentAddressGroup(
433434
endpoint.eag().getAddresses(), attr);
@@ -567,7 +568,7 @@ void start() {
567568
handleEndpointResolutionError();
568569
return;
569570
}
570-
resolver.start(new NameResolverListener());
571+
resolver.start(new NameResolverListener(dnsHostName));
571572
}
572573

573574
void refresh() {
@@ -606,6 +607,12 @@ public void run() {
606607
}
607608

608609
private class NameResolverListener extends NameResolver.Listener2 {
610+
private final String dnsHostName;
611+
612+
NameResolverListener(String dnsHostName) {
613+
this.dnsHostName = dnsHostName;
614+
}
615+
609616
@Override
610617
public void onResult(final ResolutionResult resolutionResult) {
611618
class NameResolved implements Runnable {
@@ -625,6 +632,7 @@ public void run() {
625632
Attributes attr = eag.getAttributes().toBuilder()
626633
.set(InternalXdsAttributes.ATTR_LOCALITY, LOGICAL_DNS_CLUSTER_LOCALITY)
627634
.set(InternalXdsAttributes.ATTR_LOCALITY_NAME, localityName)
635+
.set(InternalXdsAttributes.ATTR_ADDRESS_NAME, dnsHostName)
628636
.build();
629637
eag = new EquivalentAddressGroup(eag.getAddresses(), attr);
630638
eag = AddressFilter.setPathFilter(eag, Arrays.asList(priorityName, localityName));

xds/src/main/java/io/grpc/xds/Endpoints.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,19 @@ abstract static class LbEndpoint {
6161
// Whether the endpoint is healthy.
6262
abstract boolean isHealthy();
6363

64+
abstract String hostname();
65+
6466
static LbEndpoint create(EquivalentAddressGroup eag, int loadBalancingWeight,
65-
boolean isHealthy) {
66-
return new AutoValue_Endpoints_LbEndpoint(eag, loadBalancingWeight, isHealthy);
67+
boolean isHealthy, String hostname) {
68+
return new AutoValue_Endpoints_LbEndpoint(eag, loadBalancingWeight, isHealthy, hostname);
6769
}
6870

6971
// Only for testing.
7072
@VisibleForTesting
7173
static LbEndpoint create(
72-
String address, int port, int loadBalancingWeight, boolean isHealthy) {
74+
String address, int port, int loadBalancingWeight, boolean isHealthy, String hostname) {
7375
return LbEndpoint.create(new EquivalentAddressGroup(new InetSocketAddress(address, port)),
74-
loadBalancingWeight, isHealthy);
76+
loadBalancingWeight, isHealthy, hostname);
7577
}
7678
}
7779

xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ public final class InternalXdsAttributes {
9191
static final Attributes.Key<Long> ATTR_SERVER_WEIGHT =
9292
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.serverWeight");
9393

94+
/** Name associated with individual address, if available (e.g., DNS name). */
95+
@EquivalentAddressGroup.Attr
96+
static final Attributes.Key<String> ATTR_ADDRESS_NAME =
97+
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.addressName");
98+
9499
/**
95100
* Filter chain match for network filters.
96101
*/

0 commit comments

Comments
 (0)