Skip to content

Commit 4b458cb

Browse files
authored
Block onHalfClose if onMessage was blocked (#241)
* test that SecurityInterceptor does not propagate onHalfClose after closing the request * set the locale in ValidationTest so it passes on non-English systems * test that ValidatingInterceptor does not propagate onHalfClose after closing the request * also block onHalfClose after blocking onMessage fixes #240
1 parent b56d6e4 commit 4b458cb

File tree

6 files changed

+140
-23
lines changed

6 files changed

+140
-23
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package org.lognet.springboot.grpc;
2+
3+
import io.grpc.ForwardingServerCallListener;
4+
import io.grpc.Metadata;
5+
import io.grpc.ServerCall;
6+
import io.grpc.ServerCallHandler;
7+
import io.grpc.ServerInterceptor;
8+
9+
@GRpcGlobalInterceptor
10+
public class HalfCloseInterceptor implements ServerInterceptor {
11+
@Override
12+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
13+
ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next
14+
) {
15+
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(call, headers)) {
16+
@Override
17+
public void onHalfClose() {
18+
HalfCloseInterceptor.this.onHalfClose();
19+
super.onHalfClose();
20+
}
21+
};
22+
}
23+
24+
public void onHalfClose() {}
25+
}

grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/ValidationTest.java

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,39 @@
11
package org.lognet.springboot.grpc;
22

3+
import static org.hamcrest.MatcherAssert.assertThat;
4+
import static org.hamcrest.Matchers.emptyOrNullString;
5+
import static org.junit.Assert.assertThrows;
6+
import static org.mockito.Mockito.never;
7+
import static org.mockito.Mockito.verify;
8+
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.NONE;
9+
10+
import java.util.Locale;
11+
312
import io.grpc.Metadata;
413
import io.grpc.Status;
514
import io.grpc.StatusRuntimeException;
615
import io.grpc.examples.GreeterGrpc;
716
import io.grpc.examples.GreeterOuterClass;
817
import org.hamcrest.Matchers;
18+
import org.junit.AfterClass;
919
import org.junit.Before;
20+
import org.junit.BeforeClass;
1021
import org.junit.Test;
1122
import org.junit.runner.RunWith;
1223
import org.lognet.springboot.grpc.demo.DemoApp;
1324
import org.springframework.boot.test.context.SpringBootTest;
1425
import org.springframework.boot.test.context.TestConfiguration;
26+
import org.springframework.boot.test.mock.mockito.SpyBean;
1527
import org.springframework.context.annotation.Bean;
1628
import org.springframework.context.annotation.Import;
1729
import org.springframework.test.context.ActiveProfiles;
1830
import org.springframework.test.context.junit4.SpringRunner;
1931

20-
import static org.hamcrest.MatcherAssert.assertThat;
21-
import static org.hamcrest.Matchers.emptyOrNullString;
22-
import static org.junit.Assert.assertThrows;
23-
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.NONE;
24-
2532
@RunWith(SpringRunner.class)
2633
@SpringBootTest(classes = {DemoApp.class}, webEnvironment = NONE, properties = {"grpc.port=0"})
2734
@Import(ValidationTest.TestCfg.class)
2835
@ActiveProfiles("disable-security")
2936
public class ValidationTest extends GrpcServerTestBase {
30-
3137
@TestConfiguration
3238
static class TestCfg {
3339
@Bean
@@ -43,7 +49,21 @@ public Status handle(Object message, Status status, Exception exception, Metadat
4349
}
4450
private GreeterGrpc.GreeterBlockingStub stub;
4551

52+
@SpyBean
53+
HalfCloseInterceptor halfCloseInterceptor;
4654

55+
private static Locale systemDefaultLocale;
56+
57+
@BeforeClass
58+
public static void setLocaleToEnglish() {
59+
systemDefaultLocale = Locale.getDefault();
60+
Locale.setDefault(Locale.ENGLISH);
61+
}
62+
63+
@AfterClass
64+
public static void resetDefaultLocale() {
65+
Locale.setDefault(systemDefaultLocale);
66+
}
4767

4868
@Before
4969
public void setUp() throws Exception {
@@ -147,7 +167,6 @@ public void validMessageValidationTest() {
147167
@Test
148168
public void invalidResponseMessageValidationTest() {
149169
StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> {
150-
151170
stub.helloPersonInvalidResponse(GreeterOuterClass.Person.newBuilder()
152171
.setAge(3)//valid
153172
.setName("Dexter")//valid
@@ -164,6 +183,17 @@ public void invalidResponseMessageValidationTest() {
164183

165184
}
166185

186+
@Test
187+
public void noHalfCloseAfterFailedValidation() {
188+
StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> {
189+
stub.helloPersonValidResponse(GreeterOuterClass.Person.newBuilder()
190+
.setAge(49)// valid
191+
.clearName() //invalid
192+
.build());
193+
});
194+
assertThat(e.getStatus().getCode(), Matchers.is(Status.Code.INVALID_ARGUMENT));
195+
verify(halfCloseInterceptor, never()).onHalfClose();
196+
}
167197

168198
String getFieldName(int fieldNumber) {
169199
return GreeterOuterClass.Person.getDescriptor().findFieldByNumber(fieldNumber).getName();
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package org.lognet.springboot.grpc.auth;
2+
3+
4+
import static org.hamcrest.MatcherAssert.assertThat;
5+
import static org.junit.Assert.assertThrows;
6+
import static org.mockito.Mockito.never;
7+
import static org.mockito.Mockito.verify;
8+
9+
import com.google.protobuf.Empty;
10+
import io.grpc.Status;
11+
import io.grpc.StatusRuntimeException;
12+
import io.grpc.examples.SecuredGreeterGrpc;
13+
import org.hamcrest.Matchers;
14+
import org.junit.Test;
15+
import org.junit.runner.RunWith;
16+
import org.lognet.springboot.grpc.GrpcServerTestBase;
17+
import org.lognet.springboot.grpc.HalfCloseInterceptor;
18+
import org.lognet.springboot.grpc.demo.DemoApp;
19+
import org.springframework.boot.test.context.SpringBootTest;
20+
import org.springframework.boot.test.mock.mockito.SpyBean;
21+
import org.springframework.test.context.junit4.SpringRunner;
22+
23+
@SpringBootTest(
24+
classes = DemoApp.class,
25+
properties = "grpc.security.auth.fail-fast=false"
26+
)
27+
@RunWith(SpringRunner.class)
28+
public class FailLateSecurityInterceptorTest extends GrpcServerTestBase {
29+
@SpyBean
30+
HalfCloseInterceptor halfCloseInterceptor;
31+
32+
@Test
33+
public void noHalfCloseOnFailedAuth() {
34+
final StatusRuntimeException statusRuntimeException = assertThrows(
35+
StatusRuntimeException.class,
36+
() -> SecuredGreeterGrpc.newBlockingStub(selectedChanel).sayAuthHello2(Empty.newBuilder().build()).getMessage()
37+
);
38+
assertThat(statusRuntimeException.getStatus().getCode(), Matchers.is(Status.Code.UNAUTHENTICATED));
39+
verify(halfCloseInterceptor, never()).onHalfClose();
40+
}
41+
}

grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/FailureHandlingServerInterceptor.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package org.lognet.springboot.grpc;
22

3+
import io.grpc.ForwardingServerCallListener;
34
import io.grpc.Metadata;
45
import io.grpc.ServerCall;
56
import io.grpc.ServerInterceptor;
67
import io.grpc.Status;
7-
import io.grpc.StatusRuntimeException;
88

99
public interface FailureHandlingServerInterceptor extends ServerInterceptor {
1010
default void closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall<?, ?> call, Metadata headers, final Status status, Exception exception){
@@ -14,4 +14,25 @@ default void closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall<?, ?
1414
call.close(statusToSend, responseHeaders);
1515

1616
}
17+
18+
class MessageBlockingServerCallListener<R> extends ForwardingServerCallListener.SimpleForwardingServerCallListener<R> {
19+
private volatile boolean messageBlocked = false;
20+
21+
public MessageBlockingServerCallListener(ServerCall.Listener<R> delegate) {
22+
super(delegate);
23+
}
24+
25+
@Override
26+
public void onHalfClose() {
27+
// If the message was blocked, downstream never had a chance to react to it. Hence, the half-close signal would look like
28+
// an error to them. So we do not propagate the signal in that case.
29+
if (!messageBlocked) {
30+
super.onHalfClose();
31+
}
32+
}
33+
34+
protected void blockMessage() {
35+
messageBlocked = true;
36+
}
37+
}
1738
}

grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/security/SecurityInterceptor.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package org.lognet.springboot.grpc.security;
22

3+
import java.nio.ByteBuffer;
4+
import java.nio.charset.StandardCharsets;
5+
import java.util.Optional;
6+
37
import io.grpc.Context;
48
import io.grpc.Contexts;
59
import io.grpc.ForwardingServerCall;
@@ -22,10 +26,6 @@
2226
import org.springframework.security.core.context.SecurityContext;
2327
import org.springframework.security.core.context.SecurityContextHolder;
2428

25-
import java.nio.ByteBuffer;
26-
import java.nio.charset.StandardCharsets;
27-
import java.util.Optional;
28-
2929
@Slf4j
3030
public class SecurityInterceptor extends AbstractSecurityInterceptor implements FailureHandlingServerInterceptor, Ordered {
3131

@@ -193,9 +193,10 @@ private <RespT, ReqT> ServerCall.Listener<ReqT> fail(ServerCallHandler<ReqT, Res
193193

194194

195195
} else {
196-
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(call, headers)) {
196+
return new MessageBlockingServerCallListener<ReqT>(next.startCall(call, headers)) {
197197
@Override
198198
public void onMessage(ReqT message) {
199+
blockMessage();
199200
closeCall(message, errorHandler, call, headers, status, exception);
200201
}
201202
};

grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/validation/ValidatingInterceptor.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
package org.lognet.springboot.grpc.validation;
22

3+
import java.util.Optional;
4+
import java.util.Set;
5+
import javax.validation.ConstraintViolation;
6+
import javax.validation.ConstraintViolationException;
7+
import javax.validation.Validator;
8+
39
import io.grpc.ForwardingServerCall;
4-
import io.grpc.ForwardingServerCallListener;
510
import io.grpc.Metadata;
611
import io.grpc.ServerCall;
712
import io.grpc.ServerCallHandler;
@@ -15,12 +20,6 @@
1520
import org.springframework.beans.factory.annotation.Autowired;
1621
import org.springframework.core.Ordered;
1722

18-
import javax.validation.ConstraintViolation;
19-
import javax.validation.ConstraintViolationException;
20-
import javax.validation.Validator;
21-
import java.util.Optional;
22-
import java.util.Set;
23-
2423

2524
public class ValidatingInterceptor implements FailureHandlingServerInterceptor, Ordered {
2625
private Validator validator;
@@ -53,14 +52,14 @@ public void sendMessage(RespT message) {
5352
}
5453
}
5554
}, headers);
56-
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(listener) {
55+
return new MessageBlockingServerCallListener<ReqT>(listener) {
5756

5857
@Override
5958
public void onMessage(ReqT message) {
6059
final Set<ConstraintViolation<ReqT>> violations = validator.validate(message, RequestMessage.class);
6160
if (!violations.isEmpty()) {
62-
closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations));
63-
61+
blockMessage();
62+
closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations));
6463
} else {
6564
super.onMessage(message);
6665
}

0 commit comments

Comments
 (0)