Skip to content

Commit 709affc

Browse files
committed
also block onHalfClose after blocking onMessage
fixes #240
1 parent 7d0c162 commit 709affc

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/FailureHandlingServerInterceptor.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package org.lognet.springboot.grpc;
22

3+
import io.grpc.ForwardingServerCallListener;
34
import io.grpc.Metadata;
45
import io.grpc.ServerCall;
56
import io.grpc.ServerInterceptor;
67
import io.grpc.Status;
7-
import io.grpc.StatusRuntimeException;
88

99
public interface FailureHandlingServerInterceptor extends ServerInterceptor {
1010
default void closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall<?, ?> call, Metadata headers, final Status status, Exception exception){
@@ -14,4 +14,25 @@ default void closeCall(Object o, GRpcErrorHandler errorHandler, ServerCall<?, ?
1414
call.close(statusToSend, responseHeaders);
1515

1616
}
17+
18+
class MessageBlockingServerCallListener<R> extends ForwardingServerCallListener.SimpleForwardingServerCallListener<R> {
19+
private volatile boolean messageBlocked = false;
20+
21+
public MessageBlockingServerCallListener(ServerCall.Listener<R> delegate) {
22+
super(delegate);
23+
}
24+
25+
@Override
26+
public void onHalfClose() {
27+
// If the message was blocked, downstream never had a chance to react to it. Hence, the half-close signal would look like
28+
// an error to them. So we do not propagate the signal in that case.
29+
if (!messageBlocked) {
30+
super.onHalfClose();
31+
}
32+
}
33+
34+
protected void blockMessage() {
35+
messageBlocked = true;
36+
}
37+
}
1738
}

grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/security/SecurityInterceptor.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package org.lognet.springboot.grpc.security;
22

3+
import java.nio.ByteBuffer;
4+
import java.nio.charset.StandardCharsets;
5+
import java.util.Optional;
6+
37
import io.grpc.Context;
48
import io.grpc.Contexts;
59
import io.grpc.ForwardingServerCall;
@@ -22,10 +26,6 @@
2226
import org.springframework.security.core.context.SecurityContext;
2327
import org.springframework.security.core.context.SecurityContextHolder;
2428

25-
import java.nio.ByteBuffer;
26-
import java.nio.charset.StandardCharsets;
27-
import java.util.Optional;
28-
2929
@Slf4j
3030
public class SecurityInterceptor extends AbstractSecurityInterceptor implements FailureHandlingServerInterceptor, Ordered {
3131

@@ -193,9 +193,10 @@ private <RespT, ReqT> ServerCall.Listener<ReqT> fail(ServerCallHandler<ReqT, Res
193193

194194

195195
} else {
196-
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(call, headers)) {
196+
return new MessageBlockingServerCallListener<ReqT>(next.startCall(call, headers)) {
197197
@Override
198198
public void onMessage(ReqT message) {
199+
blockMessage();
199200
closeCall(message, errorHandler, call, headers, status, exception);
200201
}
201202
};

grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/validation/ValidatingInterceptor.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
package org.lognet.springboot.grpc.validation;
22

3+
import java.util.Optional;
4+
import java.util.Set;
5+
import javax.validation.ConstraintViolation;
6+
import javax.validation.ConstraintViolationException;
7+
import javax.validation.Validator;
8+
39
import io.grpc.ForwardingServerCall;
4-
import io.grpc.ForwardingServerCallListener;
510
import io.grpc.Metadata;
611
import io.grpc.ServerCall;
712
import io.grpc.ServerCallHandler;
@@ -15,12 +20,6 @@
1520
import org.springframework.beans.factory.annotation.Autowired;
1621
import org.springframework.core.Ordered;
1722

18-
import javax.validation.ConstraintViolation;
19-
import javax.validation.ConstraintViolationException;
20-
import javax.validation.Validator;
21-
import java.util.Optional;
22-
import java.util.Set;
23-
2423

2524
public class ValidatingInterceptor implements FailureHandlingServerInterceptor, Ordered {
2625
private Validator validator;
@@ -53,14 +52,14 @@ public void sendMessage(RespT message) {
5352
}
5453
}
5554
}, headers);
56-
return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(listener) {
55+
return new MessageBlockingServerCallListener<ReqT>(listener) {
5756

5857
@Override
5958
public void onMessage(ReqT message) {
6059
final Set<ConstraintViolation<ReqT>> violations = validator.validate(message, RequestMessage.class);
6160
if (!violations.isEmpty()) {
62-
closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations));
63-
61+
blockMessage();
62+
closeCall(message,errorHandler,call,headers,Status.INVALID_ARGUMENT,new ConstraintViolationException(violations));
6463
} else {
6564
super.onMessage(message);
6665
}

0 commit comments

Comments
 (0)