Skip to content

Commit

Permalink
ARTEMIS-4167 Enhanced deserialization filter
Browse files Browse the repository at this point in the history
Now that Artemis is Java 11+ compatible, there is now the ability to
set an ObjectInputFilter on an ObjectInputStream. This allows us to
write more advanced filter patterns or plug-in our own object filter
implementation.
  • Loading branch information
swerner0 committed Oct 18, 2024
1 parent c78d87d commit 4dcbd6c
Show file tree
Hide file tree
Showing 18 changed files with 1,150 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,13 @@ public String getDeserializationAllowList() {
return this.factoryReference.getDeserializationAllowList();
}

public String getSerialFilter() {
return this.factoryReference.getSerialFilter();
}

public String getSerialFilterClassName() {
return this.factoryReference.getSerialFilterClassName();
}

private static class JMSFailureListener implements SessionFailureListener {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ public class ActiveMQConnectionFactory extends JNDIStorable implements Connectio

private String deserializationAllowList;

private String serialFilter;

private String serialFilterClassName;

private boolean cacheDestinations;

// keeping this field for serialization compatibility only. do not use it
Expand Down Expand Up @@ -209,6 +213,26 @@ public void setDeserializationAllowList(String allowList) {
this.deserializationAllowList = allowList;
}

@Override
public String getSerialFilter() {
return serialFilter;
}

@Override
public void setSerialFilter(String serialFilter) {
this.serialFilter = serialFilter;
}

@Override
public String getSerialFilterClassName() {
return serialFilterClassName;
}

@Override
public void setSerialFilterClassName(String serialFilterClassName) {
this.serialFilterClassName = serialFilterClassName;
}

@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
String url = in.readUTF();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import javax.jms.ObjectMessage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputFilter;
import java.io.ObjectOutputStream;
import java.io.Serializable;

Expand Down Expand Up @@ -145,6 +146,12 @@ public Serializable getObject() throws JMSException {
if (allowList != null) {
ois.setAllowList(allowList);
}

ObjectInputFilter oif = ObjectInputFilterFactory.getObjectInputFilter(options);
if (oif != null) {
ois.setObjectInputFilter(oif);
}

Serializable object = (Serializable) ois.readObject();
return object;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,14 @@ public String getDeserializationAllowList() {
return connection.getDeserializationAllowList();
}

public String getSerialFilter() {
return connection.getSerialFilter();
}

public String getSerialFilterClassName() {
return connection.getSerialFilterClassName();
}

enum ConsumerDurability {
DURABLE, NON_DURABLE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,11 @@ public interface ConnectionFactoryOptions {

void setDeserializationAllowList(String allowList);

String getSerialFilter();

void setSerialFilter(String serialFilter);

String getSerialFilterClassName();

void setSerialFilterClassName(String serialFilterClassName);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.artemis.jms.client;

import java.io.ObjectInputFilter;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

public class ObjectInputFilterFactory {

public static final String SERIAL_FILTER_PROPERTY = "org.apache.activemq.artemis.jms.serialFilter";
public static final String SERIAL_FILTER_CLASS_NAME_PROPERTY = "org.apache.activemq.artemis.jms.serialFilterClassName";

private static Map<Object, ObjectInputFilter> filterCache = new ConcurrentHashMap<>();

public static ObjectInputFilter getObjectInputFilter(ConnectionFactoryOptions options) {
String className = getFilterClassName(options);
if (className != null) {
return getObjectInputFilterForClassName(className);
}

String pattern = getFilterPattern(options);
if (pattern != null) {
return getObjectInputFilterForPattern(pattern);
}

return null;
}

public static ObjectInputFilter getObjectInputFilterForPattern(String pattern) {
if (pattern == null) {
return null;
}

return filterCache.computeIfAbsent(new PatternKey(pattern),
k -> ObjectInputFilter.Config.createFilter(pattern));
}

public static ObjectInputFilter getObjectInputFilterForClassName(String className) {
if (className == null) {
return null;
}

return filterCache.computeIfAbsent(new ClassNameKey(className), k -> {
try {
return (ObjectInputFilter)Class.forName(className).newInstance();
} catch (ClassNotFoundException e) {
throw new RuntimeException("Class " + className + " not found.", e);
} catch (Exception e) {
throw new RuntimeException(e);
}
});
}

private static String getFilterClassName(ConnectionFactoryOptions options) {
if (options != null && options.getSerialFilterClassName() != null) {
return options.getSerialFilterClassName();
}

return System.getProperty(SERIAL_FILTER_CLASS_NAME_PROPERTY);
}

private static String getFilterPattern(ConnectionFactoryOptions options) {
if (options != null && options.getSerialFilter() != null) {
return options.getSerialFilter();
}

return System.getProperty(SERIAL_FILTER_PROPERTY);
}

private static class PatternKey {
private String pattern;

PatternKey(String pattern) {
this.pattern = pattern;
}

public String getPattern() {
return pattern;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

PatternKey that = (PatternKey) o;

if (!Objects.equals(pattern, that.pattern)) return false;

return true;
}

@Override
public int hashCode() {
return pattern != null ? pattern.hashCode() : 0;
}
}

private static class ClassNameKey {
private String className;

ClassNameKey(String className) {
this.className = className;
}

public String getClassName() {
return className;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

ClassNameKey that = (ClassNameKey) o;

if (!Objects.equals(className, that.className)) return false;

return true;
}

@Override
public int hashCode() {
return className != null ? className.hashCode() : 0;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ public interface ConnectionFactoryConfiguration extends EncodingSupport {

void setDeserializationAllowList(String allowList);

String getSerialFilter();

void setSerialFilter(String serialFilter);

String getSerialFilterClassName();

void setSerialFilterClassName(String serialFilterClassName);

int getInitialMessagePacketSize();

ConnectionFactoryConfiguration setInitialMessagePacketSize(int size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ public class ConnectionFactoryConfigurationImpl implements ConnectionFactoryConf

private String deserializationAllowList;

private String serialFilter;

private String serialFilterClassName;

private int initialMessagePacketSize = ActiveMQClient.DEFAULT_INITIAL_MESSAGE_PACKET_SIZE;

private boolean enable1xPrefixes = ActiveMQJMSClient.DEFAULT_ENABLE_1X_PREFIXES;
Expand Down Expand Up @@ -658,6 +662,9 @@ public void decode(final ActiveMQBuffer buffer) {

compressionLevel = buffer.readableBytes() > 0 ? BufferHelper.readNullableInteger(buffer) : ActiveMQClient.DEFAULT_COMPRESSION_LEVEL;

serialFilter = buffer.readableBytes() > 0 ? BufferHelper.readNullableSimpleStringAsString(buffer) : null;

serialFilterClassName = buffer.readableBytes() > 0 ? BufferHelper.readNullableSimpleStringAsString(buffer) : null;
}

@Override
Expand Down Expand Up @@ -756,6 +763,10 @@ public void encode(final ActiveMQBuffer buffer) {
BufferHelper.writeNullableBoolean(buffer, useTopologyForLoadBalancing);

BufferHelper.writeNullableInteger(buffer, compressionLevel);

BufferHelper.writeAsNullableSimpleString(buffer, serialFilter);

BufferHelper.writeAsNullableSimpleString(buffer, serialFilterClassName);
}

@Override
Expand Down Expand Up @@ -878,7 +889,11 @@ public int getEncodeSize() {

BufferHelper.sizeOfNullableBoolean(useTopologyForLoadBalancing) +

BufferHelper.sizeOfNullableInteger(compressionLevel);
BufferHelper.sizeOfNullableInteger(compressionLevel) +

BufferHelper.sizeOfNullableSimpleString(serialFilter) +

BufferHelper.sizeOfNullableSimpleString(serialFilterClassName);

return size;
}
Expand Down Expand Up @@ -934,6 +949,26 @@ public void setDeserializationAllowList(String allowList) {
this.deserializationAllowList = allowList;
}

@Override
public String getSerialFilter() {
return serialFilter;
}

@Override
public void setSerialFilter(String serialFilter) {
this.serialFilter = serialFilter;
}

@Override
public String getSerialFilterClassName() {
return serialFilterClassName;
}

@Override
public void setSerialFilterClassName(String serialFilterClassName) {
this.serialFilterClassName = serialFilterClassName;
}

@Override
public ConnectionFactoryConfiguration setCompressLargeMessages(boolean compressLargeMessage) {
this.compressLargeMessage = compressLargeMessage;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,8 @@ protected ActiveMQConnectionFactory internalCreateCFPOJO(final ConnectionFactory
cf.setProtocolManagerFactoryStr(cfConfig.getProtocolManagerFactoryStr());
cf.setDeserializationDenyList(cfConfig.getDeserializationDenyList());
cf.setDeserializationAllowList(cfConfig.getDeserializationAllowList());
cf.setSerialFilter(cfConfig.getSerialFilter());
cf.setSerialFilterClassName(cfConfig.getSerialFilterClassName());
cf.setInitialMessagePacketSize(cfConfig.getInitialMessagePacketSize());
cf.setEnable1xPrefixes(cfConfig.isEnable1xPrefixes());
cf.setEnableSharedClientID(cfConfig.isEnableSharedClientID());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,30 @@ public void setDeserializationAllowList(String deserializationAllowList) {
raProperties.setDeserializationAllowList(deserializationAllowList);
}

public String getSerialFilter() {
logger.trace("getSerialFilter()");

return raProperties.getSerialFilter();
}

public void setSerialFilter(String serialFilter) {
logger.trace("setSerialFilter({})", serialFilter);

raProperties.setSerialFilter(serialFilter);
}

public String getSerialFilterClassName() {
logger.trace("getSerialFilterClassName()");

return raProperties.getSerialFilterClassName();
}

public void setSerialFilterClassName(String serialFilterClassName) {
logger.trace("setSerialFilterClassName({})", serialFilterClassName);

raProperties.setSerialFilterClassName(serialFilterClassName);
}

/**
* Get min large message size
*
Expand Down Expand Up @@ -1944,6 +1968,14 @@ private void setParams(final ActiveMQConnectionFactory cf, final ConnectionFacto
if (stringVal != null) {
cf.setDeserializationAllowList(stringVal);
}
stringVal = overrideProperties.getSerialFilter() != null ? overrideProperties.getSerialFilter() : raProperties.getSerialFilter();
if (stringVal != null) {
cf.setSerialFilter(stringVal);
}
stringVal = overrideProperties.getSerialFilterClassName() != null ? overrideProperties.getSerialFilterClassName() : raProperties.getSerialFilterClassName();
if (stringVal != null) {
cf.setSerialFilterClassName(stringVal);
}

cf.setIgnoreJTA(isIgnoreJTA());
}
Expand Down
Loading

0 comments on commit 4dcbd6c

Please sign in to comment.