Skip to content

Commit 6dbd1b9

Browse files
authored
Add newAttachMetadataServerInterceptor() MetadataUtil (#11458)
1 parent 6a9bc3b commit 6dbd1b9

File tree

2 files changed

+239
-0
lines changed

2 files changed

+239
-0
lines changed

stub/src/main/java/io/grpc/stub/MetadataUtils.java

+64
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@
2222
import io.grpc.Channel;
2323
import io.grpc.ClientCall;
2424
import io.grpc.ClientInterceptor;
25+
import io.grpc.ExperimentalApi;
2526
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
2627
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
28+
import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
2729
import io.grpc.Metadata;
2830
import io.grpc.MethodDescriptor;
31+
import io.grpc.ServerCall;
32+
import io.grpc.ServerCallHandler;
33+
import io.grpc.ServerInterceptor;
2934
import io.grpc.Status;
3035
import java.util.concurrent.atomic.AtomicReference;
3136

@@ -143,4 +148,63 @@ public void onClose(Status status, Metadata trailers) {
143148
}
144149
}
145150
}
151+
152+
/**
153+
* Returns a ServerInterceptor that adds the specified Metadata to every response stream, one way
154+
* or another.
155+
*
156+
* <p>If, absent this interceptor, a stream would have headers, 'extras' will be added to those
157+
* headers. Otherwise, 'extras' will be sent as trailers. This pattern is useful when you have
158+
* some fixed information, server identity say, that should be included no matter how the call
159+
* turns out. The fallback to trailers avoids artificially committing clients to error responses
160+
* that could otherwise be retried (see https://grpc.io/docs/guides/retry/ for more).
161+
*
162+
* <p>For correct operation, be sure to arrange for this interceptor to run *before* any others
163+
* that might add headers.
164+
*
165+
* @param extras the Metadata to be added to each stream. Caller gives up ownership.
166+
*/
167+
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11462")
168+
public static ServerInterceptor newAttachMetadataServerInterceptor(Metadata extras) {
169+
return new MetadataAttachingServerInterceptor(extras);
170+
}
171+
172+
private static final class MetadataAttachingServerInterceptor implements ServerInterceptor {
173+
174+
private final Metadata extras;
175+
176+
MetadataAttachingServerInterceptor(Metadata extras) {
177+
this.extras = extras;
178+
}
179+
180+
@Override
181+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
182+
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
183+
return next.startCall(new MetadataAttachingServerCall<>(call), headers);
184+
}
185+
186+
final class MetadataAttachingServerCall<ReqT, RespT>
187+
extends SimpleForwardingServerCall<ReqT, RespT> {
188+
boolean headersSent;
189+
190+
MetadataAttachingServerCall(ServerCall<ReqT, RespT> delegate) {
191+
super(delegate);
192+
}
193+
194+
@Override
195+
public void sendHeaders(Metadata headers) {
196+
headers.merge(extras);
197+
headersSent = true;
198+
super.sendHeaders(headers);
199+
}
200+
201+
@Override
202+
public void close(Status status, Metadata trailers) {
203+
if (!headersSent) {
204+
trailers.merge(extras);
205+
}
206+
super.close(status, trailers);
207+
}
208+
}
209+
}
146210
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*
2+
* Copyright 2024 The gRPC Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.grpc.stub;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static io.grpc.stub.MetadataUtils.newAttachMetadataServerInterceptor;
21+
import static io.grpc.stub.MetadataUtils.newCaptureMetadataInterceptor;
22+
import static org.junit.Assert.fail;
23+
24+
import com.google.common.collect.ImmutableList;
25+
import io.grpc.CallOptions;
26+
import io.grpc.ManagedChannel;
27+
import io.grpc.Metadata;
28+
import io.grpc.MethodDescriptor;
29+
import io.grpc.ServerCallHandler;
30+
import io.grpc.ServerInterceptors;
31+
import io.grpc.ServerMethodDefinition;
32+
import io.grpc.ServerServiceDefinition;
33+
import io.grpc.Status;
34+
import io.grpc.Status.Code;
35+
import io.grpc.StatusRuntimeException;
36+
import io.grpc.StringMarshaller;
37+
import io.grpc.inprocess.InProcessChannelBuilder;
38+
import io.grpc.inprocess.InProcessServerBuilder;
39+
import io.grpc.testing.GrpcCleanupRule;
40+
import java.io.IOException;
41+
import java.util.Iterator;
42+
import java.util.concurrent.atomic.AtomicReference;
43+
import org.junit.Rule;
44+
import org.junit.Test;
45+
import org.junit.runner.RunWith;
46+
import org.junit.runners.JUnit4;
47+
48+
@RunWith(JUnit4.class)
49+
public class MetadataUtilsTest {
50+
51+
@Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
52+
53+
private static final String SERVER_NAME = "test";
54+
private static final Metadata.Key<String> FOO_KEY =
55+
Metadata.Key.of("foo-key", Metadata.ASCII_STRING_MARSHALLER);
56+
57+
private final MethodDescriptor<String, String> echoMethod =
58+
MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE)
59+
.setFullMethodName("test/echo")
60+
.setType(MethodDescriptor.MethodType.UNARY)
61+
.build();
62+
63+
private final ServerCallHandler<String, String> echoCallHandler =
64+
ServerCalls.asyncUnaryCall(
65+
(req, respObserver) -> {
66+
respObserver.onNext(req);
67+
respObserver.onCompleted();
68+
});
69+
70+
MethodDescriptor<String, String> echoServerStreamingMethod =
71+
MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE)
72+
.setFullMethodName("test/echoStream")
73+
.setType(MethodDescriptor.MethodType.SERVER_STREAMING)
74+
.build();
75+
76+
private final AtomicReference<Metadata> trailersCapture = new AtomicReference<>();
77+
private final AtomicReference<Metadata> headersCapture = new AtomicReference<>();
78+
79+
@Test
80+
public void shouldAttachHeadersToResponse() throws IOException {
81+
Metadata extras = new Metadata();
82+
extras.put(FOO_KEY, "foo-value");
83+
84+
ServerServiceDefinition serviceDef =
85+
ServerInterceptors.intercept(
86+
ServerServiceDefinition.builder("test").addMethod(echoMethod, echoCallHandler).build(),
87+
ImmutableList.of(newAttachMetadataServerInterceptor(extras)));
88+
89+
grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start());
90+
ManagedChannel channel =
91+
grpcCleanup.register(
92+
newInProcessChannelBuilder()
93+
.intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture))
94+
.build());
95+
96+
String response =
97+
ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello");
98+
assertThat(response).isEqualTo("hello");
99+
assertThat(trailersCapture.get() == null || !trailersCapture.get().containsKey(FOO_KEY))
100+
.isTrue();
101+
assertThat(headersCapture.get().get(FOO_KEY)).isEqualTo("foo-value");
102+
}
103+
104+
@Test
105+
public void shouldAttachTrailersWhenNoResponse() throws IOException {
106+
Metadata extras = new Metadata();
107+
extras.put(FOO_KEY, "foo-value");
108+
109+
ServerServiceDefinition serviceDef =
110+
ServerInterceptors.intercept(
111+
ServerServiceDefinition.builder("test")
112+
.addMethod(
113+
ServerMethodDefinition.create(
114+
echoServerStreamingMethod,
115+
ServerCalls.asyncUnaryCall(
116+
(req, respObserver) -> respObserver.onCompleted())))
117+
.build(),
118+
ImmutableList.of(newAttachMetadataServerInterceptor(extras)));
119+
grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start());
120+
121+
ManagedChannel channel =
122+
grpcCleanup.register(
123+
newInProcessChannelBuilder()
124+
.intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture))
125+
.build());
126+
127+
Iterator<String> response =
128+
ClientCalls.blockingServerStreamingCall(
129+
channel, echoServerStreamingMethod, CallOptions.DEFAULT, "hello");
130+
assertThat(response.hasNext()).isFalse();
131+
assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue();
132+
assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value");
133+
}
134+
135+
@Test
136+
public void shouldAttachTrailersToErrorResponse() throws IOException {
137+
Metadata extras = new Metadata();
138+
extras.put(FOO_KEY, "foo-value");
139+
140+
ServerServiceDefinition serviceDef =
141+
ServerInterceptors.intercept(
142+
ServerServiceDefinition.builder("test")
143+
.addMethod(
144+
echoMethod,
145+
ServerCalls.asyncUnaryCall(
146+
(req, respObserver) ->
147+
respObserver.onError(Status.INVALID_ARGUMENT.asRuntimeException())))
148+
.build(),
149+
ImmutableList.of(newAttachMetadataServerInterceptor(extras)));
150+
grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start());
151+
152+
ManagedChannel channel =
153+
grpcCleanup.register(
154+
newInProcessChannelBuilder()
155+
.intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture))
156+
.build());
157+
try {
158+
ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello");
159+
fail();
160+
} catch (StatusRuntimeException e) {
161+
assertThat(e.getStatus()).isNotNull();
162+
assertThat(e.getStatus().getCode()).isEqualTo(Code.INVALID_ARGUMENT);
163+
}
164+
assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue();
165+
assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value");
166+
}
167+
168+
private static InProcessServerBuilder newInProcessServerBuilder() {
169+
return InProcessServerBuilder.forName(SERVER_NAME).directExecutor();
170+
}
171+
172+
private static InProcessChannelBuilder newInProcessChannelBuilder() {
173+
return InProcessChannelBuilder.forName(SERVER_NAME).directExecutor();
174+
}
175+
}

0 commit comments

Comments
 (0)