Skip to content

Commit

Permalink
AMQ-6596 - Validate size of buffers during unmarshalling
Browse files Browse the repository at this point in the history
Verify that size buffers for arrays and bytesequences will not exceed
the overall frame size
  • Loading branch information
cshannon committed Feb 25, 2025
1 parent 78ee343 commit 3037ce8
Show file tree
Hide file tree
Showing 74 changed files with 796 additions and 259 deletions.
5 changes: 5 additions & 0 deletions activemq-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@
<artifactId>log4j-slf4j2-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.javassist</groupId>
<artifactId>javassist</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public final class OpenWireFormat implements WireFormat {
private static final int MARSHAL_CACHE_SIZE = Short.MAX_VALUE / 2;
private static final int MARSHAL_CACHE_FREE_SPACE = 100;

private DataStreamMarshaller dataMarshallers[];
private DataStreamMarshaller[] dataMarshallers;
private int version;
private boolean stackTraceEnabled;
private boolean tcpNoDelayEnabled;
Expand All @@ -61,13 +61,22 @@ public final class OpenWireFormat implements WireFormat {
// The following fields are used for value caching
private short nextMarshallCacheIndex;
private short nextMarshallCacheEvictionIndex;
private Map<DataStructure, Short> marshallCacheMap = new HashMap<DataStructure, Short>();
private Map<DataStructure, Short> marshallCacheMap = new HashMap<>();
private DataStructure marshallCache[] = null;
private DataStructure unmarshallCache[] = null;
private DataByteArrayOutputStream bytesOut = new DataByteArrayOutputStream();
private DataByteArrayInputStream bytesIn = new DataByteArrayInputStream();
private final DataByteArrayOutputStream bytesOut = new DataByteArrayOutputStream();
private final DataByteArrayInputStream bytesIn = new DataByteArrayInputStream();
private WireFormatInfo preferedWireFormatInfo;

// Used to track the currentFrameSize for validation during unmarshalling
// Ideally we would pass the MarshallingContext directly to the marshalling methods,
// however this would require modifying the DataStreamMarshaller interface which would result
// in hundreds of existing methods having to be updated so this allows avoiding that and
// tracking the state without breaking the existing API.
// Note that while this is currently only used during unmarshalling, but if necessary could
// be extended in the future to be used during marshalling as well.
private final ThreadLocal<MarshallingContext> marshallingContext = new ThreadLocal<>();

public OpenWireFormat() {
this(DEFAULT_STORE_VERSION);
}
Expand Down Expand Up @@ -191,26 +200,23 @@ public synchronized ByteSequence marshal(Object command) throws IOException {
@Override
public synchronized Object unmarshal(ByteSequence sequence) throws IOException {
bytesIn.restart(sequence);
// DataInputStream dis = new DataInputStream(new
// ByteArrayInputStream(sequence));

if (!sizePrefixDisabled) {
int size = bytesIn.readInt();
if (sequence.getLength() - 4 != size) {
// throw new IOException("Packet size does not match marshaled
// size");
}

if (maxFrameSizeEnabled && size > maxFrameSize) {
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
try {
final var context = new MarshallingContext();
marshallingContext.set(context);

if (!sizePrefixDisabled) {
int size = bytesIn.readInt();
if (maxFrameSizeEnabled && size > maxFrameSize) {
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
}
context.setFrameSize(size);
}
return doUnmarshal(bytesIn);
} finally {
// After we unmarshal we can clear the context
marshallingContext.remove();
}

Object command = doUnmarshal(bytesIn);
// if( !cacheEnabled && ((DataStructure)command).isMarshallAware() ) {
// ((MarshallAware) command).setCachedMarshalledForm(this, sequence);
// }
return command;
}

@Override
Expand Down Expand Up @@ -275,19 +281,22 @@ public synchronized void marshal(Object o, DataOutput dataOut) throws IOExceptio

@Override
public Object unmarshal(DataInput dis) throws IOException {
DataInput dataIn = dis;
if (!sizePrefixDisabled) {
int size = dis.readInt();
if (maxFrameSizeEnabled && size > maxFrameSize) {
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
try {
final var context = new MarshallingContext();
marshallingContext.set(context);

if (!sizePrefixDisabled) {
int size = dis.readInt();
if (maxFrameSizeEnabled && size > maxFrameSize) {
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
}
context.setFrameSize(size);
}
// int size = dis.readInt();
// byte[] data = new byte[size];
// dis.readFully(data);
// bytesIn.restart(data);
// dataIn = bytesIn;
return doUnmarshal(dis);
} finally {
// After we unmarshal we can clear
marshallingContext.remove();
}
return doUnmarshal(dataIn);
}

/**
Expand Down Expand Up @@ -363,7 +372,7 @@ public void setVersion(int version) {
this.version = version;
}

public Object doUnmarshal(DataInput dis) throws IOException {
private Object doUnmarshal(DataInput dis) throws IOException {
byte dataType = dis.readByte();
if (dataType != NULL_TYPE) {
DataStreamMarshaller dsm = dataMarshallers[dataType & 0xFF];
Expand Down Expand Up @@ -698,4 +707,47 @@ protected long min(long version1, long version2) {
}
return version2;
}

MarshallingContext getMarshallingContext() {
return marshallingContext.get();
}

// Used to track the estimated allocated buffer sizes to validate
// against the current frame being processed
static class MarshallingContext {
// Use primitives to minimize memory footprint
private int frameSize = -1;
private int estimatedAllocated = 0;

void setFrameSize(int frameSize) throws IOException {
this.frameSize = frameSize;
if (frameSize < 0) {
throw error("Frame size " + frameSize + " can't be negative.");
}
}

void increment(int size) throws IOException {
if (size < 0) {
throw error("Size " + size + " can't be negative.");
}
try {
estimatedAllocated = Math.addExact(estimatedAllocated, size);
} catch (ArithmeticException e) {
throw error("Buffer overflow when incrementing size value: " + size);
}
}

public int getFrameSize() {
return frameSize;
}

public int getEstimatedAllocated() {
return estimatedAllocated;
}

private static IOException error(String errorMessage) {
return new IOException(new IllegalArgumentException(errorMessage));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
*/
package org.apache.activemq.openwire;

import java.io.IOException;
import org.apache.activemq.util.IOExceptionSupport;

public class OpenWireUtil {

private static final String jmsPackageToReplace = "javax.jms";
private static final String jmsPackageToUse = "jakarta.jms";
static final String jmsPackageToReplace = "javax.jms";
static final String jmsPackageToUse = "jakarta.jms";

/**
* Verify that the provided class extends {@link Throwable} and throw an
Expand All @@ -33,6 +36,50 @@ public static void validateIsThrowable(Class<?> clazz) {
}
}

/**
* Verify that the buffer size that will be allocated will not push the total allocated
* size of this frame above the expected frame size. This is an estimate as the current
* size is only tracked when calls to this method are made and is primarily intended
* to prevent large arrays from being created due to an invalid size.
*
* Also verify the size against configured max frame size.
* This check is a sanity check in case of corrupt packets contain invalid size values.
*
* @param wireFormat configured OpenWireFormat
* @param size buffer size to verify
* @throws IOException If size is larger than currentFrameSize or maxFrameSize
*/
public static void validateBufferSize(OpenWireFormat wireFormat, int size) throws IOException {
validateLessThanFrameSize(wireFormat, size);

// if currentFrameSize is set and was checked above then this check should not be needed,
// but it doesn't hurt to verify again in case the max frame size check was missed
// somehow
if (wireFormat.isMaxFrameSizeEnabled() && size > wireFormat.getMaxFrameSize()) {
throw IOExceptionSupport.createFrameSizeException(size, wireFormat.getMaxFrameSize());
}
}

// Verify total tracked sizes will not exceed the overall size of the frame
private static void validateLessThanFrameSize(OpenWireFormat wireFormat, int size)
throws IOException {
final var context = wireFormat.getMarshallingContext();
// No information on current frame size so just return
if (context == null || context.getFrameSize() < 0) {
return;
}

// Increment existing estimated buffer size with new size
context.increment(size);

// We should never be trying to allocate a buffer that is going to push the total
// size greater than the entire frame itself
if (context.getEstimatedAllocated() > context.getFrameSize()) {
throw IOExceptionSupport.createFrameSizeBufferException(
context.getEstimatedAllocated(), context.getFrameSize());
}
}

/**
* This method can be used to convert from javax -> jakarta or
* vice versa depending on the version used by the client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,11 @@ protected void tightMarshalByteArray2(byte[] data, DataOutput dataOut, BooleanSt
}
}

protected byte[] tightUnmarshalByteArray(DataInput dataIn, BooleanStream bs) throws IOException {
protected byte[] tightUnmarshalByteArray(OpenWireFormat wireFormat, DataInput dataIn, BooleanStream bs) throws IOException {
byte rc[] = null;
if (bs.readBoolean()) {
int size = dataIn.readInt();
OpenWireUtil.validateBufferSize(wireFormat, size);
rc = new byte[size];
dataIn.readFully(rc);
}
Expand All @@ -438,10 +439,11 @@ protected void tightMarshalByteSequence2(ByteSequence data, DataOutput dataOut,
}
}

protected ByteSequence tightUnmarshalByteSequence(DataInput dataIn, BooleanStream bs) throws IOException {
protected ByteSequence tightUnmarshalByteSequence(OpenWireFormat wireFormat, DataInput dataIn, BooleanStream bs) throws IOException {
ByteSequence rc = null;
if (bs.readBoolean()) {
int size = dataIn.readInt();
OpenWireUtil.validateBufferSize(wireFormat, size);
byte[] t = new byte[size];
dataIn.readFully(t);
return new ByteSequence(t, 0, size);
Expand Down Expand Up @@ -618,10 +620,11 @@ protected void looseMarshalByteArray(OpenWireFormat wireFormat, byte[] data, Dat
}
}

protected byte[] looseUnmarshalByteArray(DataInput dataIn) throws IOException {
protected byte[] looseUnmarshalByteArray(OpenWireFormat wireFormat, DataInput dataIn) throws IOException {
byte rc[] = null;
if (dataIn.readBoolean()) {
int size = dataIn.readInt();
OpenWireUtil.validateBufferSize(wireFormat, size);
rc = new byte[size];
dataIn.readFully(rc);
}
Expand All @@ -637,10 +640,11 @@ protected void looseMarshalByteSequence(OpenWireFormat wireFormat, ByteSequence
}
}

protected ByteSequence looseUnmarshalByteSequence(DataInput dataIn) throws IOException {
protected ByteSequence looseUnmarshalByteSequence(OpenWireFormat wireFormat, DataInput dataIn) throws IOException {
ByteSequence rc = null;
if (dataIn.readBoolean()) {
int size = dataIn.readInt();
OpenWireUtil.validateBufferSize(wireFormat, size);
byte[] t = new byte[size];
dataIn.readFully(t);
rc = new ByteSequence(t, 0, size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
info.setReplyTo((org.apache.activemq.command.ActiveMQDestination)tightUnmarsalNestedObject(wireFormat, dataIn, bs));
info.setTimestamp(tightUnmarshalLong(wireFormat, dataIn, bs));
info.setType(tightUnmarshalString(dataIn, bs));
info.setContent(tightUnmarshalByteSequence(dataIn, bs));
info.setMarshalledProperties(tightUnmarshalByteSequence(dataIn, bs));
info.setContent(tightUnmarshalByteSequence(wireFormat, dataIn, bs));
info.setMarshalledProperties(tightUnmarshalByteSequence(wireFormat, dataIn, bs));
info.setDataStructure((org.apache.activemq.command.DataStructure)tightUnmarsalNestedObject(wireFormat, dataIn, bs));
info.setTargetConsumerId((org.apache.activemq.command.ConsumerId)tightUnmarsalCachedObject(wireFormat, dataIn, bs));
info.setCompressed(bs.readBoolean());
Expand Down Expand Up @@ -196,8 +196,8 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
info.setReplyTo((org.apache.activemq.command.ActiveMQDestination)looseUnmarsalNestedObject(wireFormat, dataIn));
info.setTimestamp(looseUnmarshalLong(wireFormat, dataIn));
info.setType(looseUnmarshalString(dataIn));
info.setContent(looseUnmarshalByteSequence(dataIn));
info.setMarshalledProperties(looseUnmarshalByteSequence(dataIn));
info.setContent(looseUnmarshalByteSequence(wireFormat, dataIn));
info.setMarshalledProperties(looseUnmarshalByteSequence(wireFormat, dataIn));
info.setDataStructure((org.apache.activemq.command.DataStructure)looseUnmarsalNestedObject(wireFormat, dataIn));
info.setTargetConsumerId((org.apache.activemq.command.ConsumerId)looseUnmarsalCachedObject(wireFormat, dataIn));
info.setCompressed(dataIn.readBoolean());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

PartialCommand info = (PartialCommand)o;
info.setCommandId(dataIn.readInt());
info.setData(tightUnmarshalByteArray(dataIn, bs));
info.setData(tightUnmarshalByteArray(wireFormat, dataIn, bs));

}

Expand Down Expand Up @@ -114,7 +114,7 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

PartialCommand info = (PartialCommand)o;
info.setCommandId(dataIn.readInt());
info.setData(looseUnmarshalByteArray(dataIn));
info.setData(looseUnmarshalByteArray(wireFormat, dataIn));

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

info.setMagic(tightUnmarshalConstByteArray(dataIn, bs, 8));
info.setVersion(dataIn.readInt());
info.setMarshalledProperties(tightUnmarshalByteSequence(dataIn, bs));
info.setMarshalledProperties(tightUnmarshalByteSequence(wireFormat, dataIn, bs));

info.afterUnmarshall(wireFormat);

Expand Down Expand Up @@ -130,7 +130,7 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

info.setMagic(looseUnmarshalConstByteArray(dataIn, 8));
info.setVersion(dataIn.readInt());
info.setMarshalledProperties(looseUnmarshalByteSequence(dataIn));
info.setMarshalledProperties(looseUnmarshalByteSequence(wireFormat, dataIn));

info.afterUnmarshall(wireFormat);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

XATransactionId info = (XATransactionId)o;
info.setFormatId(dataIn.readInt());
info.setGlobalTransactionId(tightUnmarshalByteArray(dataIn, bs));
info.setBranchQualifier(tightUnmarshalByteArray(dataIn, bs));
info.setGlobalTransactionId(tightUnmarshalByteArray(wireFormat, dataIn, bs));
info.setBranchQualifier(tightUnmarshalByteArray(wireFormat, dataIn, bs));

}

Expand Down Expand Up @@ -117,8 +117,8 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

XATransactionId info = (XATransactionId)o;
info.setFormatId(dataIn.readInt());
info.setGlobalTransactionId(looseUnmarshalByteArray(dataIn));
info.setBranchQualifier(looseUnmarshalByteArray(dataIn));
info.setGlobalTransactionId(looseUnmarshalByteArray(wireFormat, dataIn));
info.setBranchQualifier(looseUnmarshalByteArray(wireFormat, dataIn));

}

Expand Down
Loading

0 comments on commit 3037ce8

Please sign in to comment.