Skip to content

Commit

Permalink
ref #175
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Furer committed Oct 28, 2021
1 parent a0a6583 commit f4e3b5a
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 80 deletions.
15 changes: 11 additions & 4 deletions grpc-spring-boot-starter-demo/src/test/resources/logback-test.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
<configuration>
<include resource="org/springframework/boot/logging/logback/base.xml"/>
<logger name="org.springframework" level="INFO"/>


<include resource="/org/springframework/boot/logging/logback/base.xml"/>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n
</pattern>
</encoder>
</appender>
<root level="error">
<appender-ref ref="STDOUT"/>
</root>
<logger name="org.lognet" level="debug"/>
</configuration>
Original file line number Diff line number Diff line change
@@ -1,30 +1,54 @@
package org.lognet.springboot.grpc;

import io.grpc.BindableService;
import io.grpc.MethodDescriptor;
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import lombok.Builder;
import lombok.Getter;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.MethodIntrospector;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.function.SingletonSupplier;

import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class GRpcServicesRegistry implements InitializingBean, ApplicationContextAware {
private ApplicationContext applicationContext;
@Getter
@Builder
public static class GrpcServiceMethod {
private BindableService service;
private Method method;

}

private ApplicationContext applicationContext;

private Supplier<Map<String, BindableService>> beanNameToServiceBean;

private Supplier<Map<String, BindableService>> serviceNameToServiceBean;

private Supplier<Collection<ServerInterceptor>> grpcGlobalInterceptors;

private Supplier<Map<MethodDescriptor<?, ?>, GrpcServiceMethod>> descriptorToServiceMethod;

private Supplier< Map<Method,MethodDescriptor<?,?>>> methodToDescriptor ;




/**
* @return service name to grpc service bean
Expand All @@ -45,6 +69,14 @@ Collection<ServerInterceptor> getGlobalInterceptors() {
return grpcGlobalInterceptors.get();
}

public GrpcServiceMethod getGrpServiceMethod(MethodDescriptor<?,?> descriptor) {
return descriptorToServiceMethod.get().get(descriptor);
}

public MethodDescriptor<?,?> getMethodDescriptor( Method method) {
return methodToDescriptor.get().get(method);
}

private <T> Map<String, T> getBeanNamesByTypeWithAnnotation(Class<? extends Annotation> annotationType, Class<T> beanType) {

return applicationContext.getBeansWithAnnotation(annotationType)
Expand All @@ -58,7 +90,14 @@ private <T> Map<String, T> getBeanNamesByTypeWithAnnotation(Class<? extends Anno
@Override
public void afterPropertiesSet() throws Exception {

descriptorToServiceMethod = SingletonSupplier.of(this::descriptorToServiceMethod);

methodToDescriptor = SingletonSupplier.of(()->
descriptorToServiceMethod.get()
.entrySet()
.stream()
.collect(Collectors.toMap(e->e.getValue().getMethod(), Map.Entry::getKey))
);
beanNameToServiceBean = SingletonSupplier.of(() ->
getBeanNamesByTypeWithAnnotation(GRpcService.class, BindableService.class)
);
Expand All @@ -78,8 +117,49 @@ public void afterPropertiesSet() throws Exception {
);
}


@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
}

private Map<MethodDescriptor<?, ?>, GrpcServiceMethod> descriptorToServiceMethod (){
final Map<MethodDescriptor<?, ?>, GrpcServiceMethod> map = new HashMap<>();

Function<String, ReflectionUtils.MethodFilter> filterFactory = name ->
method -> method.getName().equalsIgnoreCase(name) ;

for (BindableService service : getBeanNameToServiceBeanMap().values()) {
final ServerServiceDefinition serviceDefinition = service.bindService();
for (MethodDescriptor<?, ?> d : serviceDefinition.getServiceDescriptor().getMethods()) {
Class<?> abstractBaseClass = service.getClass();
while (!Modifier.isAbstract(abstractBaseClass.getModifiers())){
abstractBaseClass = abstractBaseClass.getSuperclass();
}

final Set<Method> methods = MethodIntrospector
.selectMethods(abstractBaseClass, filterFactory.apply(d.getBareMethodName()));


switch (methods.size()){
case 0:
throw new IllegalStateException("Method " +d.getBareMethodName()+ "not found in service "+ serviceDefinition.getServiceDescriptor().getName());
case 1:
map.put(d, GrpcServiceMethod.builder()
.service(service)
.method(methods.iterator().next())
.build());
break;
default:
throw new IllegalStateException("Ambiguous method " +d.getBareMethodName()+ " in service "+ serviceDefinition.getServiceDescriptor().getName());
}





}
}
return Collections.unmodifiableMap(map);
}
}
Original file line number Diff line number Diff line change
@@ -1,46 +1,37 @@
package org.lognet.springboot.grpc.security;

import io.grpc.BindableService;
import io.grpc.MethodDescriptor;
import io.grpc.ServerMethodDefinition;
import org.lognet.springboot.grpc.GRpcServicesRegistry;
import org.springframework.security.access.ConfigAttribute;
import org.springframework.security.access.method.MethodSecurityMetadataSource;

import java.lang.reflect.Method;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class GrpcSecurityMetadataSource implements MethodSecurityMetadataSource {
private Map<MethodDescriptor<?,?>, List<ConfigAttribute>> methodDescriptorMap;
private Map<Method,MethodDescriptor<?,?>> methodMap = new HashMap<>();

public GrpcSecurityMetadataSource(GRpcServicesRegistry registry , Map<MethodDescriptor<?, ?>, List<ConfigAttribute>> methodDescriptorMap) {
this.methodDescriptorMap = methodDescriptorMap;

for(BindableService s:registry.getBeanNameToServiceBeanMap().values()){
for(ServerMethodDefinition<?,?> md :s.bindService().getMethods()){
final Method method = Stream.of(s.getClass().getMethods())
.filter(m -> md.getMethodDescriptor().getBareMethodName().equalsIgnoreCase(m.getName()))
.findFirst().get();
methodMap.put(method,md.getMethodDescriptor());
}
}
private Map<MethodDescriptor<?,?>, List<ConfigAttribute>> methodDescriptorAttributes;
private GRpcServicesRegistry registry;


public GrpcSecurityMetadataSource(GRpcServicesRegistry registry , Map<MethodDescriptor<?, ?>, List<ConfigAttribute>> methodDescriptorAttributes) {
this.methodDescriptorAttributes = methodDescriptorAttributes;
this.registry = registry;


}

@Override
public Collection<ConfigAttribute> getAttributes(Object object) throws IllegalArgumentException {
final MethodDescriptor methodDescriptor = SecurityInterceptor.GrpcMethodInvocation.class.cast(object).getCall().getMethodDescriptor();
return methodDescriptorMap.get(methodDescriptor);
return methodDescriptorAttributes.get(methodDescriptor);
}

@Override
public Collection<ConfigAttribute> getAllConfigAttributes() {
return methodDescriptorMap
return methodDescriptorAttributes
.values()
.stream()
.flatMap(Collection::stream)
Expand All @@ -54,7 +45,7 @@ public boolean supports(Class<?> clazz) {

@Override
public Collection<ConfigAttribute> getAttributes(Method method, Class<?> targetClass) {
final MethodDescriptor<?, ?> methodDescriptor = methodMap.get(method);
return methodDescriptorMap.get(methodDescriptor);
final MethodDescriptor<?, ?> methodDescriptor = registry.getMethodDescriptor(method);
return methodDescriptorAttributes.get(methodDescriptor);
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
package org.lognet.springboot.grpc.security;

import io.grpc.BindableService;
import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ForwardingServerCall;
import io.grpc.ForwardingServerCallListener;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -20,7 +17,6 @@
import org.lognet.springboot.grpc.autoconfigure.GRpcServerProperties;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.Ordered;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.access.SecurityMetadataSource;
Expand All @@ -31,18 +27,10 @@
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.util.SimpleMethodInvocation;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.AbstractMap;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

@Slf4j
public class SecurityInterceptor extends AbstractSecurityInterceptor implements ServerInterceptor, Ordered {
Expand All @@ -57,7 +45,9 @@ public class SecurityInterceptor extends AbstractSecurityInterceptor implements
private GRpcServerProperties.SecurityProperties.Auth authCfg;

private FailureHandlingSupport failureHandlingSupport;
private Map<GrpcServiceMethodKey, Map.Entry<Object, Method>> keyedMethods;

private GRpcServicesRegistry registry;


static class GrpcMethodInvocation<ReqT, RespT> extends SimpleMethodInvocation {
final private ServerCall<ReqT, RespT> call;
Expand All @@ -67,8 +57,8 @@ static class GrpcMethodInvocation<ReqT, RespT> extends SimpleMethodInvocation {
@Setter
private Object[] arguments;

public GrpcMethodInvocation(Map.Entry<Object, Method> handler, ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
super(handler.getKey(), handler.getValue());
public GrpcMethodInvocation(GRpcServicesRegistry.GrpcServiceMethod serviceMethod, ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
super(serviceMethod.getService(), serviceMethod.getMethod());
this.call = call;
this.headers = headers;
this.next = next;
Expand All @@ -84,22 +74,7 @@ ServerCall<ReqT, RespT> getCall() {
}
}

@Getter
@EqualsAndHashCode
static class GrpcServiceMethodKey {

public GrpcServiceMethodKey(MethodDescriptor<?, ?> methodDescriptor) {
this.serviceName = methodDescriptor.getServiceName();
this.methodName = methodDescriptor.getBareMethodName();
}

@EqualsAndHashCode.Include
final private String serviceName;

@EqualsAndHashCode.Include
final private String methodName;

}


public SecurityInterceptor(SecurityMetadataSource securityMetadataSource, AuthenticationSchemeSelector schemeSelector) {
Expand All @@ -110,28 +85,8 @@ public SecurityInterceptor(SecurityMetadataSource securityMetadataSource, Authen

@Autowired
public void setGRpcServicesRegistry(GRpcServicesRegistry registry) {
this.registry = registry;

final Map<GrpcServiceMethodKey, Map.Entry<Object, Method>> map = new HashMap<>();

Function<String, ReflectionUtils.MethodFilter> filterFactory = name ->
method -> method.getName().equalsIgnoreCase(name) ;

for (BindableService service : registry.getBeanNameToServiceBeanMap().values()) {
for (MethodDescriptor<?, ?> d : service.bindService().getServiceDescriptor().getMethods()) {
Class<?> abstractBaseClass = service.getClass();
while (!Modifier.isAbstract(abstractBaseClass.getModifiers())){
abstractBaseClass = abstractBaseClass.getSuperclass();
}

final Method method = MethodIntrospector
.selectMethods(abstractBaseClass, filterFactory.apply(d.getBareMethodName()))
.iterator().next();
map.put(new GrpcServiceMethodKey(d),
new AbstractMap.SimpleImmutableEntry<>(service, method));

}
}
keyedMethods = Collections.unmodifiableMap(map);
}

@Autowired
Expand Down Expand Up @@ -294,9 +249,9 @@ private <RespT, ReqT> Context setupGRpcSecurityContext(ServerCall<RespT, ReqT> c
context.setAuthentication(authentication);
SecurityContextHolder.setContext(context);

final Map.Entry<Object, Method> methodHandler = keyedMethods.get(new GrpcServiceMethodKey(call.getMethodDescriptor()));
final GRpcServicesRegistry.GrpcServiceMethod grpcServiceMethod = registry.getGrpServiceMethod(call.getMethodDescriptor());

final GrpcMethodInvocation<RespT, ReqT> methodInvocation = new GrpcMethodInvocation<>(methodHandler, call, headers, next);
final GrpcMethodInvocation<RespT, ReqT> methodInvocation = new GrpcMethodInvocation<>(grpcServiceMethod , call, headers, next);
final InterceptorStatusToken interceptorStatusToken = beforeInvocation(methodInvocation);

return Context.current()
Expand Down

0 comments on commit f4e3b5a

Please sign in to comment.