Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@
import static org.apache.dubbo.common.constants.LoggerCodeConstants.COMMON_IO_EXCEPTION;

public class SerializeSecurityConfigurator implements ScopeClassLoaderListener<ModuleModel> {
private final SerializeSecurityManager serializeSecurityManager;

private static final ErrorTypeAwareLogger logger =
private static final ErrorTypeAwareLogger LOGGER =
LoggerFactory.getErrorTypeAwareLogger(SerializeSecurityConfigurator.class);

private final Set<Type> markedTypeCache = new HashSet<>();

private final SerializeSecurityManager serializeSecurityManager;

private final ModuleModel moduleModel;

private final ClassHolder classHolder;
Expand Down Expand Up @@ -137,7 +139,7 @@ private void loadAllow(ClassLoader classLoader) {
Set<URL> urls = ClassLoaderResourceLoader.loadResources(SERIALIZE_ALLOW_LIST_FILE_PATH, classLoader);
for (URL u : urls) {
try {
logger.info("Read serialize allow list from " + u);
LOGGER.info("Read serialize allow list from " + u);
String[] lines = IOUtils.readLines(u.openStream());
for (String line : lines) {
line = line.trim();
Expand All @@ -147,7 +149,7 @@ private void loadAllow(ClassLoader classLoader) {
serializeSecurityManager.addToAlwaysAllowed(line);
}
} catch (IOException e) {
logger.error(
LOGGER.error(
COMMON_IO_EXCEPTION,
"",
"",
Expand All @@ -161,7 +163,7 @@ private void loadBlocked(ClassLoader classLoader) {
Set<URL> urls = ClassLoaderResourceLoader.loadResources(SERIALIZE_BLOCKED_LIST_FILE_PATH, classLoader);
for (URL u : urls) {
try {
logger.info("Read serialize blocked list from " + u);
LOGGER.info("Read serialize blocked list from " + u);
String[] lines = IOUtils.readLines(u.openStream());
for (String line : lines) {
line = line.trim();
Expand All @@ -171,7 +173,7 @@ private void loadBlocked(ClassLoader classLoader) {
serializeSecurityManager.addToDisAllowed(line);
}
} catch (IOException e) {
logger.error(
LOGGER.error(
COMMON_IO_EXCEPTION,
"",
"",
Expand Down Expand Up @@ -213,8 +215,9 @@ public synchronized void registerInterface(Class<?> clazz) {
return;
}

Set<Type> markedClass = new HashSet<>();
checkClass(markedClass, clazz);
if (!checkClass(clazz)) {
return;
}

addToAllow(clazz);

Expand All @@ -223,111 +226,111 @@ public synchronized void registerInterface(Class<?> clazz) {
for (Method method : methodsToExport) {
Class<?>[] parameterTypes = method.getParameterTypes();
for (Class<?> parameterType : parameterTypes) {
checkClass(markedClass, parameterType);
checkClass(parameterType);
}

Type[] genericParameterTypes = method.getGenericParameterTypes();
for (Type genericParameterType : genericParameterTypes) {
checkType(markedClass, genericParameterType);
checkType(genericParameterType);
}

Class<?> returnType = method.getReturnType();
checkClass(markedClass, returnType);
checkClass(returnType);

Type genericReturnType = method.getGenericReturnType();
checkType(markedClass, genericReturnType);
checkType(genericReturnType);

Class<?>[] exceptionTypes = method.getExceptionTypes();
for (Class<?> exceptionType : exceptionTypes) {
checkClass(markedClass, exceptionType);
checkClass(exceptionType);
}

Type[] genericExceptionTypes = method.getGenericExceptionTypes();
for (Type genericExceptionType : genericExceptionTypes) {
checkType(markedClass, genericExceptionType);
checkType(genericExceptionType);
}
}
}

private void checkType(Set<Type> markedClass, Type type) {
private void checkType(Type type) {
if (type == null) {
return;
}

if (type instanceof Class) {
checkClass(markedClass, (Class<?>) type);
checkClass((Class<?>) type);
return;
}

if (!markedClass.add(type)) {
if (!markedTypeCache.add(type)) {
return;
}

if (type instanceof ParameterizedType) {
ParameterizedType parameterizedType = (ParameterizedType) type;
checkClass(markedClass, (Class<?>) parameterizedType.getRawType());
checkClass((Class<?>) parameterizedType.getRawType());
for (Type actualTypeArgument : parameterizedType.getActualTypeArguments()) {
checkType(markedClass, actualTypeArgument);
checkType(actualTypeArgument);
}
} else if (type instanceof GenericArrayType) {
GenericArrayType genericArrayType = (GenericArrayType) type;
checkType(markedClass, genericArrayType.getGenericComponentType());
checkType(genericArrayType.getGenericComponentType());
} else if (type instanceof TypeVariable) {
TypeVariable typeVariable = (TypeVariable) type;
for (Type bound : typeVariable.getBounds()) {
checkType(markedClass, bound);
checkType(bound);
}
} else if (type instanceof WildcardType) {
WildcardType wildcardType = (WildcardType) type;
for (Type bound : wildcardType.getUpperBounds()) {
checkType(markedClass, bound);
checkType(bound);
}
for (Type bound : wildcardType.getLowerBounds()) {
checkType(markedClass, bound);
checkType(bound);
}
}
}

private void checkClass(Set<Type> markedClass, Class<?> clazz) {
private boolean checkClass(Class<?> clazz) {
if (clazz == null) {
return;
return false;
}

if (!markedClass.add(clazz)) {
return;
if (!markedTypeCache.add(clazz)) {
return false;
}

addToAllow(clazz);

if (ClassUtils.isSimpleType(clazz) || clazz.isPrimitive() || clazz.isArray()) {
return;
return true;
}
String className = clazz.getName();
if (className.startsWith("java.")
|| className.startsWith("javax.")
|| className.startsWith("com.sun.")
|| className.startsWith("sun.")
|| className.startsWith("jdk.")) {
return;
return true;
}

Class<?>[] interfaces = clazz.getInterfaces();
for (Class<?> interfaceClass : interfaces) {
checkClass(markedClass, interfaceClass);
checkClass(interfaceClass);
}

for (Type genericInterface : clazz.getGenericInterfaces()) {
checkType(markedClass, genericInterface);
checkType(genericInterface);
}

Class<?> superclass = clazz.getSuperclass();
if (superclass != null) {
checkClass(markedClass, superclass);
checkClass(superclass);
}

Type genericSuperclass = clazz.getGenericSuperclass();
if (genericSuperclass != null) {
checkType(markedClass, genericSuperclass);
checkType(genericSuperclass);
}

Field[] fields = clazz.getDeclaredFields();
Expand All @@ -338,9 +341,11 @@ private void checkClass(Set<Type> markedClass, Class<?> clazz) {
}

Class<?> fieldClass = field.getType();
checkClass(markedClass, fieldClass);
checkType(markedClass, field.getGenericType());
checkClass(fieldClass);
checkType(field.getGenericType());
}

return true;
}

private void addToAllow(Class<?> clazz) {
Expand Down
Loading