Skip to content

Commit

Permalink
Merge pull request #1399 from cshannon/buffer-validation
Browse files Browse the repository at this point in the history
AMQ-6596 - Validate size of buffers during unmarshalling
  • Loading branch information
cshannon authored Feb 25, 2025
2 parents 78ee343 + 3037ce8 commit fc4372b
Show file tree
Hide file tree
Showing 74 changed files with 796 additions and 259 deletions.
5 changes: 5 additions & 0 deletions activemq-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@
<artifactId>log4j-slf4j2-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.javassist</groupId>
<artifactId>javassist</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public final class OpenWireFormat implements WireFormat {
private static final int MARSHAL_CACHE_SIZE = Short.MAX_VALUE / 2;
private static final int MARSHAL_CACHE_FREE_SPACE = 100;

private DataStreamMarshaller dataMarshallers[];
private DataStreamMarshaller[] dataMarshallers;
private int version;
private boolean stackTraceEnabled;
private boolean tcpNoDelayEnabled;
Expand All @@ -61,13 +61,22 @@ public final class OpenWireFormat implements WireFormat {
// The following fields are used for value caching
private short nextMarshallCacheIndex;
private short nextMarshallCacheEvictionIndex;
private Map<DataStructure, Short> marshallCacheMap = new HashMap<DataStructure, Short>();
private Map<DataStructure, Short> marshallCacheMap = new HashMap<>();
private DataStructure marshallCache[] = null;
private DataStructure unmarshallCache[] = null;
private DataByteArrayOutputStream bytesOut = new DataByteArrayOutputStream();
private DataByteArrayInputStream bytesIn = new DataByteArrayInputStream();
private final DataByteArrayOutputStream bytesOut = new DataByteArrayOutputStream();
private final DataByteArrayInputStream bytesIn = new DataByteArrayInputStream();
private WireFormatInfo preferedWireFormatInfo;

// Used to track the currentFrameSize for validation during unmarshalling
// Ideally we would pass the MarshallingContext directly to the marshalling methods,
// however this would require modifying the DataStreamMarshaller interface which would result
// in hundreds of existing methods having to be updated so this allows avoiding that and
// tracking the state without breaking the existing API.
// Note that while this is currently only used during unmarshalling, but if necessary could
// be extended in the future to be used during marshalling as well.
private final ThreadLocal<MarshallingContext> marshallingContext = new ThreadLocal<>();

public OpenWireFormat() {
this(DEFAULT_STORE_VERSION);
}
Expand Down Expand Up @@ -191,26 +200,23 @@ public synchronized ByteSequence marshal(Object command) throws IOException {
@Override
public synchronized Object unmarshal(ByteSequence sequence) throws IOException {
bytesIn.restart(sequence);
// DataInputStream dis = new DataInputStream(new
// ByteArrayInputStream(sequence));

if (!sizePrefixDisabled) {
int size = bytesIn.readInt();
if (sequence.getLength() - 4 != size) {
// throw new IOException("Packet size does not match marshaled
// size");
}

if (maxFrameSizeEnabled && size > maxFrameSize) {
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
try {
final var context = new MarshallingContext();
marshallingContext.set(context);

if (!sizePrefixDisabled) {
int size = bytesIn.readInt();
if (maxFrameSizeEnabled && size > maxFrameSize) {
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
}
context.setFrameSize(size);
}
return doUnmarshal(bytesIn);
} finally {
// After we unmarshal we can clear the context
marshallingContext.remove();
}

Object command = doUnmarshal(bytesIn);
// if( !cacheEnabled && ((DataStructure)command).isMarshallAware() ) {
// ((MarshallAware) command).setCachedMarshalledForm(this, sequence);
// }
return command;
}

@Override
Expand Down Expand Up @@ -275,19 +281,22 @@ public synchronized void marshal(Object o, DataOutput dataOut) throws IOExceptio

@Override
public Object unmarshal(DataInput dis) throws IOException {
DataInput dataIn = dis;
if (!sizePrefixDisabled) {
int size = dis.readInt();
if (maxFrameSizeEnabled && size > maxFrameSize) {
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
try {
final var context = new MarshallingContext();
marshallingContext.set(context);

if (!sizePrefixDisabled) {
int size = dis.readInt();
if (maxFrameSizeEnabled && size > maxFrameSize) {
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
}
context.setFrameSize(size);
}
// int size = dis.readInt();
// byte[] data = new byte[size];
// dis.readFully(data);
// bytesIn.restart(data);
// dataIn = bytesIn;
return doUnmarshal(dis);
} finally {
// After we unmarshal we can clear
marshallingContext.remove();
}
return doUnmarshal(dataIn);
}

/**
Expand Down Expand Up @@ -363,7 +372,7 @@ public void setVersion(int version) {
this.version = version;
}

public Object doUnmarshal(DataInput dis) throws IOException {
private Object doUnmarshal(DataInput dis) throws IOException {
byte dataType = dis.readByte();
if (dataType != NULL_TYPE) {
DataStreamMarshaller dsm = dataMarshallers[dataType & 0xFF];
Expand Down Expand Up @@ -698,4 +707,47 @@ protected long min(long version1, long version2) {
}
return version2;
}

MarshallingContext getMarshallingContext() {
return marshallingContext.get();
}

// Used to track the estimated allocated buffer sizes to validate
// against the current frame being processed
static class MarshallingContext {
// Use primitives to minimize memory footprint
private int frameSize = -1;
private int estimatedAllocated = 0;

void setFrameSize(int frameSize) throws IOException {
this.frameSize = frameSize;
if (frameSize < 0) {
throw error("Frame size " + frameSize + " can't be negative.");
}
}

void increment(int size) throws IOException {
if (size < 0) {
throw error("Size " + size + " can't be negative.");
}
try {
estimatedAllocated = Math.addExact(estimatedAllocated, size);
} catch (ArithmeticException e) {
throw error("Buffer overflow when incrementing size value: " + size);
}
}

public int getFrameSize() {
return frameSize;
}

public int getEstimatedAllocated() {
return estimatedAllocated;
}

private static IOException error(String errorMessage) {
return new IOException(new IllegalArgumentException(errorMessage));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
*/
package org.apache.activemq.openwire;

import java.io.IOException;
import org.apache.activemq.util.IOExceptionSupport;

public class OpenWireUtil {

private static final String jmsPackageToReplace = "javax.jms";
private static final String jmsPackageToUse = "jakarta.jms";
static final String jmsPackageToReplace = "javax.jms";
static final String jmsPackageToUse = "jakarta.jms";

/**
* Verify that the provided class extends {@link Throwable} and throw an
Expand All @@ -33,6 +36,50 @@ public static void validateIsThrowable(Class<?> clazz) {
}
}

/**
* Verify that the buffer size that will be allocated will not push the total allocated
* size of this frame above the expected frame size. This is an estimate as the current
* size is only tracked when calls to this method are made and is primarily intended
* to prevent large arrays from being created due to an invalid size.
*
* Also verify the size against configured max frame size.
* This check is a sanity check in case of corrupt packets contain invalid size values.
*
* @param wireFormat configured OpenWireFormat
* @param size buffer size to verify
* @throws IOException If size is larger than currentFrameSize or maxFrameSize
*/
public static void validateBufferSize(OpenWireFormat wireFormat, int size) throws IOException {
validateLessThanFrameSize(wireFormat, size);

// if currentFrameSize is set and was checked above then this check should not be needed,
// but it doesn't hurt to verify again in case the max frame size check was missed
// somehow
if (wireFormat.isMaxFrameSizeEnabled() && size > wireFormat.getMaxFrameSize()) {
throw IOExceptionSupport.createFrameSizeException(size, wireFormat.getMaxFrameSize());
}
}

// Verify total tracked sizes will not exceed the overall size of the frame
private static void validateLessThanFrameSize(OpenWireFormat wireFormat, int size)
throws IOException {
final var context = wireFormat.getMarshallingContext();
// No information on current frame size so just return
if (context == null || context.getFrameSize() < 0) {
return;
}

// Increment existing estimated buffer size with new size
context.increment(size);

// We should never be trying to allocate a buffer that is going to push the total
// size greater than the entire frame itself
if (context.getEstimatedAllocated() > context.getFrameSize()) {
throw IOExceptionSupport.createFrameSizeBufferException(
context.getEstimatedAllocated(), context.getFrameSize());
}
}

/**
* This method can be used to convert from javax -> jakarta or
* vice versa depending on the version used by the client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,11 @@ protected void tightMarshalByteArray2(byte[] data, DataOutput dataOut, BooleanSt
}
}

protected byte[] tightUnmarshalByteArray(DataInput dataIn, BooleanStream bs) throws IOException {
protected byte[] tightUnmarshalByteArray(OpenWireFormat wireFormat, DataInput dataIn, BooleanStream bs) throws IOException {
byte rc[] = null;
if (bs.readBoolean()) {
int size = dataIn.readInt();
OpenWireUtil.validateBufferSize(wireFormat, size);
rc = new byte[size];
dataIn.readFully(rc);
}
Expand All @@ -438,10 +439,11 @@ protected void tightMarshalByteSequence2(ByteSequence data, DataOutput dataOut,
}
}

protected ByteSequence tightUnmarshalByteSequence(DataInput dataIn, BooleanStream bs) throws IOException {
protected ByteSequence tightUnmarshalByteSequence(OpenWireFormat wireFormat, DataInput dataIn, BooleanStream bs) throws IOException {
ByteSequence rc = null;
if (bs.readBoolean()) {
int size = dataIn.readInt();
OpenWireUtil.validateBufferSize(wireFormat, size);
byte[] t = new byte[size];
dataIn.readFully(t);
return new ByteSequence(t, 0, size);
Expand Down Expand Up @@ -618,10 +620,11 @@ protected void looseMarshalByteArray(OpenWireFormat wireFormat, byte[] data, Dat
}
}

protected byte[] looseUnmarshalByteArray(DataInput dataIn) throws IOException {
protected byte[] looseUnmarshalByteArray(OpenWireFormat wireFormat, DataInput dataIn) throws IOException {
byte rc[] = null;
if (dataIn.readBoolean()) {
int size = dataIn.readInt();
OpenWireUtil.validateBufferSize(wireFormat, size);
rc = new byte[size];
dataIn.readFully(rc);
}
Expand All @@ -637,10 +640,11 @@ protected void looseMarshalByteSequence(OpenWireFormat wireFormat, ByteSequence
}
}

protected ByteSequence looseUnmarshalByteSequence(DataInput dataIn) throws IOException {
protected ByteSequence looseUnmarshalByteSequence(OpenWireFormat wireFormat, DataInput dataIn) throws IOException {
ByteSequence rc = null;
if (dataIn.readBoolean()) {
int size = dataIn.readInt();
OpenWireUtil.validateBufferSize(wireFormat, size);
byte[] t = new byte[size];
dataIn.readFully(t);
rc = new ByteSequence(t, 0, size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
info.setReplyTo((org.apache.activemq.command.ActiveMQDestination)tightUnmarsalNestedObject(wireFormat, dataIn, bs));
info.setTimestamp(tightUnmarshalLong(wireFormat, dataIn, bs));
info.setType(tightUnmarshalString(dataIn, bs));
info.setContent(tightUnmarshalByteSequence(dataIn, bs));
info.setMarshalledProperties(tightUnmarshalByteSequence(dataIn, bs));
info.setContent(tightUnmarshalByteSequence(wireFormat, dataIn, bs));
info.setMarshalledProperties(tightUnmarshalByteSequence(wireFormat, dataIn, bs));
info.setDataStructure((org.apache.activemq.command.DataStructure)tightUnmarsalNestedObject(wireFormat, dataIn, bs));
info.setTargetConsumerId((org.apache.activemq.command.ConsumerId)tightUnmarsalCachedObject(wireFormat, dataIn, bs));
info.setCompressed(bs.readBoolean());
Expand Down Expand Up @@ -196,8 +196,8 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
info.setReplyTo((org.apache.activemq.command.ActiveMQDestination)looseUnmarsalNestedObject(wireFormat, dataIn));
info.setTimestamp(looseUnmarshalLong(wireFormat, dataIn));
info.setType(looseUnmarshalString(dataIn));
info.setContent(looseUnmarshalByteSequence(dataIn));
info.setMarshalledProperties(looseUnmarshalByteSequence(dataIn));
info.setContent(looseUnmarshalByteSequence(wireFormat, dataIn));
info.setMarshalledProperties(looseUnmarshalByteSequence(wireFormat, dataIn));
info.setDataStructure((org.apache.activemq.command.DataStructure)looseUnmarsalNestedObject(wireFormat, dataIn));
info.setTargetConsumerId((org.apache.activemq.command.ConsumerId)looseUnmarsalCachedObject(wireFormat, dataIn));
info.setCompressed(dataIn.readBoolean());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

PartialCommand info = (PartialCommand)o;
info.setCommandId(dataIn.readInt());
info.setData(tightUnmarshalByteArray(dataIn, bs));
info.setData(tightUnmarshalByteArray(wireFormat, dataIn, bs));

}

Expand Down Expand Up @@ -114,7 +114,7 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

PartialCommand info = (PartialCommand)o;
info.setCommandId(dataIn.readInt());
info.setData(looseUnmarshalByteArray(dataIn));
info.setData(looseUnmarshalByteArray(wireFormat, dataIn));

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

info.setMagic(tightUnmarshalConstByteArray(dataIn, bs, 8));
info.setVersion(dataIn.readInt());
info.setMarshalledProperties(tightUnmarshalByteSequence(dataIn, bs));
info.setMarshalledProperties(tightUnmarshalByteSequence(wireFormat, dataIn, bs));

info.afterUnmarshall(wireFormat);

Expand Down Expand Up @@ -130,7 +130,7 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

info.setMagic(looseUnmarshalConstByteArray(dataIn, 8));
info.setVersion(dataIn.readInt());
info.setMarshalledProperties(looseUnmarshalByteSequence(dataIn));
info.setMarshalledProperties(looseUnmarshalByteSequence(wireFormat, dataIn));

info.afterUnmarshall(wireFormat);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

XATransactionId info = (XATransactionId)o;
info.setFormatId(dataIn.readInt());
info.setGlobalTransactionId(tightUnmarshalByteArray(dataIn, bs));
info.setBranchQualifier(tightUnmarshalByteArray(dataIn, bs));
info.setGlobalTransactionId(tightUnmarshalByteArray(wireFormat, dataIn, bs));
info.setBranchQualifier(tightUnmarshalByteArray(wireFormat, dataIn, bs));

}

Expand Down Expand Up @@ -117,8 +117,8 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn

XATransactionId info = (XATransactionId)o;
info.setFormatId(dataIn.readInt());
info.setGlobalTransactionId(looseUnmarshalByteArray(dataIn));
info.setBranchQualifier(looseUnmarshalByteArray(dataIn));
info.setGlobalTransactionId(looseUnmarshalByteArray(wireFormat, dataIn));
info.setBranchQualifier(looseUnmarshalByteArray(wireFormat, dataIn));

}

Expand Down
Loading

0 comments on commit fc4372b

Please sign in to comment.