Skip to content

Commit 4728531

Browse files
protobuf-github-botzhangskz
authored andcommitted
Add recursion check when parsing unknown fields in Java.
PiperOrigin-RevId: 675657198
1 parent 850fcce commit 4728531

File tree

7 files changed

+456
-12
lines changed

7 files changed

+456
-12
lines changed

java/core/src/main/java/com/google/protobuf/ArrayDecoders.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
*/
2424
@CheckReturnValue
2525
final class ArrayDecoders {
26+
static final int DEFAULT_RECURSION_LIMIT = 100;
27+
28+
@SuppressWarnings("NonFinalStaticField")
29+
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
2630

2731
private ArrayDecoders() {}
2832

@@ -37,6 +41,7 @@ static final class Registers {
3741
public long long1;
3842
public Object object1;
3943
public final ExtensionRegistryLite extensionRegistry;
44+
public int recursionDepth;
4045

4146
Registers() {
4247
this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry();
@@ -244,7 +249,10 @@ static int mergeMessageField(
244249
if (length < 0 || length > limit - position) {
245250
throw InvalidProtocolBufferException.truncatedMessage();
246251
}
252+
registers.recursionDepth++;
253+
checkRecursionLimit(registers.recursionDepth);
247254
schema.mergeFrom(msg, data, position, position + length, registers);
255+
registers.recursionDepth--;
248256
registers.object1 = msg;
249257
return position + length;
250258
}
@@ -262,8 +270,11 @@ static int mergeGroupField(
262270
// A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
263271
// and it can't be used in group fields).
264272
final MessageSchema messageSchema = (MessageSchema) schema;
273+
registers.recursionDepth++;
274+
checkRecursionLimit(registers.recursionDepth);
265275
final int endPosition =
266276
messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
277+
registers.recursionDepth--;
267278
registers.object1 = msg;
268279
return endPosition;
269280
}
@@ -1024,6 +1035,8 @@ static int decodeUnknownField(
10241035
final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance();
10251036
final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP;
10261037
int lastTag = 0;
1038+
registers.recursionDepth++;
1039+
checkRecursionLimit(registers.recursionDepth);
10271040
while (position < limit) {
10281041
position = decodeVarint32(data, position, registers);
10291042
lastTag = registers.int1;
@@ -1032,6 +1045,7 @@ static int decodeUnknownField(
10321045
}
10331046
position = decodeUnknownField(lastTag, data, position, limit, child, registers);
10341047
}
1048+
registers.recursionDepth--;
10351049
if (position > limit || lastTag != endGroup) {
10361050
throw InvalidProtocolBufferException.parseFailure();
10371051
}
@@ -1078,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
10781092
throw InvalidProtocolBufferException.invalidTag();
10791093
}
10801094
}
1095+
1096+
/**
1097+
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
1098+
* the depth of the message exceeds this limit.
1099+
*/
1100+
public static void setRecursionLimit(int limit) {
1101+
recursionLimit = limit;
1102+
}
1103+
1104+
private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException {
1105+
if (depth >= recursionLimit) {
1106+
throw InvalidProtocolBufferException.recursionLimitExceeded();
1107+
}
1108+
}
10811109
}

java/core/src/main/java/com/google/protobuf/CodedInputStream.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,10 @@ public void skipMessage() throws IOException {
229229
if (tag == 0) {
230230
return;
231231
}
232+
checkRecursionLimit();
233+
++recursionDepth;
232234
boolean fieldSkipped = skipField(tag);
235+
--recursionDepth;
233236
if (!fieldSkipped) {
234237
return;
235238
}
@@ -246,7 +249,10 @@ public void skipMessage(CodedOutputStream output) throws IOException {
246249
if (tag == 0) {
247250
return;
248251
}
252+
checkRecursionLimit();
253+
++recursionDepth;
249254
boolean fieldSkipped = skipField(tag, output);
255+
--recursionDepth;
250256
if (!fieldSkipped) {
251257
return;
252258
}

java/core/src/main/java/com/google/protobuf/MessageSchema.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3006,8 +3006,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
30063006
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
30073007
}
30083008
// Unknown field.
3009-
3010-
if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3009+
if (unknownFieldSchema.mergeOneFieldFrom(
3010+
unknownFields, reader, /* currentDepth= */ 0)) {
30113011
continue;
30123012
}
30133013
}
@@ -3382,8 +3382,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
33823382
if (unknownFields == null) {
33833383
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
33843384
}
3385-
3386-
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3385+
if (!unknownFieldSchema.mergeOneFieldFrom(
3386+
unknownFields, reader, /* currentDepth= */ 0)) {
33873387
return;
33883388
}
33893389
break;
@@ -3399,8 +3399,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
33993399
if (unknownFields == null) {
34003400
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
34013401
}
3402-
3403-
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3402+
if (!unknownFieldSchema.mergeOneFieldFrom(
3403+
unknownFields, reader, /* currentDepth= */ 0)) {
34043404
return;
34053405
}
34063406
}

java/core/src/main/java/com/google/protobuf/MessageSetSchema.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
278278
reader, extension, extensionRegistry, extensions);
279279
return true;
280280
} else {
281-
282-
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
281+
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0);
283282
}
284283
} else {
285284
return reader.skipField();

java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
@CheckReturnValue
1414
abstract class UnknownFieldSchema<T, B> {
1515

16+
static final int DEFAULT_RECURSION_LIMIT = 100;
17+
18+
@SuppressWarnings("NonFinalStaticField")
19+
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
20+
1621
/** Whether unknown fields should be dropped. */
1722
abstract boolean shouldDiscardUnknownFields(Reader reader);
1823

@@ -55,7 +60,9 @@ abstract class UnknownFieldSchema<T, B> {
5560
/** Marks unknown fields as immutable. */
5661
abstract void makeImmutable(Object message);
5762

58-
final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
63+
/** Merges one field into the unknown fields. */
64+
final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth)
65+
throws IOException {
5966
int tag = reader.getTag();
6067
int fieldNumber = WireFormat.getTagFieldNumber(tag);
6168
switch (WireFormat.getTagWireType(tag)) {
@@ -74,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
7481
case WireFormat.WIRETYPE_START_GROUP:
7582
final B subFields = newBuilder();
7683
int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP);
77-
mergeFrom(subFields, reader);
84+
currentDepth++;
85+
if (currentDepth >= recursionLimit) {
86+
throw InvalidProtocolBufferException.recursionLimitExceeded();
87+
}
88+
mergeFrom(subFields, reader, currentDepth);
89+
currentDepth--;
7890
if (endGroupTag != reader.getTag()) {
7991
throw InvalidProtocolBufferException.invalidEndTag();
8092
}
@@ -87,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
8799
}
88100
}
89101

90-
private final void mergeFrom(B unknownFields, Reader reader) throws IOException {
102+
private final void mergeFrom(B unknownFields, Reader reader, int currentDepth)
103+
throws IOException {
91104
while (true) {
92105
if (reader.getFieldNumber() == Reader.READ_DONE
93-
|| !mergeOneFieldFrom(unknownFields, reader)) {
106+
|| !mergeOneFieldFrom(unknownFields, reader, currentDepth)) {
94107
break;
95108
}
96109
}
@@ -107,4 +120,12 @@ private final void mergeFrom(B unknownFields, Reader reader) throws IOException
107120
abstract int getSerializedSizeAsMessageSet(T message);
108121

109122
abstract int getSerializedSize(T unknowns);
123+
124+
/**
125+
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
126+
* the depth of the message exceeds this limit.
127+
*/
128+
public void setRecursionLimit(int limit) {
129+
recursionLimit = limit;
130+
}
110131
}

java/core/src/test/java/com/google/protobuf/CodedInputStreamTest.java

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
import static com.google.common.truth.Truth.assertWithMessage;
1212
import static org.junit.Assert.assertArrayEquals;
1313
import static org.junit.Assert.assertThrows;
14+
15+
import com.google.common.primitives.Bytes;
16+
import map_test.MapTestProto.MapContainer;
1417
import protobuf_unittest.UnittestProto.BoolMessage;
1518
import protobuf_unittest.UnittestProto.Int32Message;
1619
import protobuf_unittest.UnittestProto.Int64Message;
@@ -35,6 +38,13 @@ public class CodedInputStreamTest {
3538

3639
private static final int DEFAULT_BLOCK_SIZE = 4096;
3740

41+
private static final int GROUP_TAP = WireFormat.makeTag(3, WireFormat.WIRETYPE_START_GROUP);
42+
43+
private static final byte[] NESTING_SGROUP = generateSGroupTags();
44+
45+
private static final byte[] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField();
46+
47+
3848
private enum InputType {
3949
ARRAY {
4050
@Override
@@ -117,6 +127,17 @@ private byte[] bytes(int... bytesAsInts) {
117127
return bytes;
118128
}
119129

130+
private static byte[] generateSGroupTags() {
131+
byte[] bytes = new byte[100000];
132+
Arrays.fill(bytes, (byte) GROUP_TAP);
133+
return bytes;
134+
}
135+
136+
private static byte[] generateSGroupTagsForMapField() {
137+
byte[] initialBytes = {18, 1, 75, 26, (byte) 198, (byte) 154, 12};
138+
return Bytes.concat(initialBytes, NESTING_SGROUP);
139+
}
140+
120141
/**
121142
* An InputStream which limits the number of bytes it reads at a time. We use this to make sure
122143
* that CodedInputStream doesn't screw up when reading in small blocks.
@@ -740,6 +761,143 @@ public void testMaliciousRecursion() throws Exception {
740761
}
741762
}
742763

764+
@Test
765+
public void testMaliciousRecursion_unknownFields() throws Exception {
766+
Throwable thrown =
767+
assertThrows(
768+
InvalidProtocolBufferException.class,
769+
() -> TestRecursiveMessage.parseFrom(NESTING_SGROUP));
770+
771+
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
772+
}
773+
774+
@Test
775+
public void testMaliciousRecursion_skippingUnknownField() throws Exception {
776+
Throwable thrown =
777+
assertThrows(
778+
InvalidProtocolBufferException.class,
779+
() ->
780+
DiscardUnknownFieldsParser.wrap(TestRecursiveMessage.parser())
781+
.parseFrom(NESTING_SGROUP));
782+
783+
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
784+
}
785+
786+
@Test
787+
public void testMaliciousSGroupTagsWithMapField_fromInputStream() throws Exception {
788+
Throwable parseFromThrown =
789+
assertThrows(
790+
InvalidProtocolBufferException.class,
791+
() ->
792+
MapContainer.parseFrom(
793+
new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
794+
Throwable mergeFromThrown =
795+
assertThrows(
796+
InvalidProtocolBufferException.class,
797+
() ->
798+
MapContainer.newBuilder()
799+
.mergeFrom(new ByteArrayInputStream(NESTING_SGROUP_WITH_INITIAL_BYTES)));
800+
801+
assertThat(parseFromThrown)
802+
.hasMessageThat()
803+
.contains("Protocol message had too many levels of nesting");
804+
assertThat(mergeFromThrown)
805+
.hasMessageThat()
806+
.contains("Protocol message had too many levels of nesting");
807+
}
808+
809+
@Test
810+
public void testMaliciousSGroupTags_inputStream_skipMessage() throws Exception {
811+
ByteArrayInputStream inputSteam = new ByteArrayInputStream(NESTING_SGROUP);
812+
CodedInputStream input = CodedInputStream.newInstance(inputSteam);
813+
CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
814+
815+
Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
816+
Throwable thrown2 =
817+
assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
818+
819+
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
820+
assertThat(thrown2)
821+
.hasMessageThat()
822+
.contains("Protocol message had too many levels of nesting");
823+
}
824+
825+
@Test
826+
public void testMaliciousSGroupTagsWithMapField_fromByteArray() throws Exception {
827+
Throwable parseFromThrown =
828+
assertThrows(
829+
InvalidProtocolBufferException.class,
830+
() -> MapContainer.parseFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
831+
Throwable mergeFromThrown =
832+
assertThrows(
833+
InvalidProtocolBufferException.class,
834+
() -> MapContainer.newBuilder().mergeFrom(NESTING_SGROUP_WITH_INITIAL_BYTES));
835+
836+
assertThat(parseFromThrown)
837+
.hasMessageThat()
838+
.contains("the input ended unexpectedly in the middle of a field");
839+
assertThat(mergeFromThrown)
840+
.hasMessageThat()
841+
.contains("the input ended unexpectedly in the middle of a field");
842+
}
843+
844+
@Test
845+
public void testMaliciousSGroupTags_arrayDecoder_skipMessage() throws Exception {
846+
CodedInputStream input = CodedInputStream.newInstance(NESTING_SGROUP);
847+
CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
848+
849+
Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
850+
Throwable thrown2 =
851+
assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
852+
853+
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
854+
assertThat(thrown2)
855+
.hasMessageThat()
856+
.contains("Protocol message had too many levels of nesting");
857+
}
858+
859+
@Test
860+
public void testMaliciousSGroupTagsWithMapField_fromByteBuffer() throws Exception {
861+
Throwable thrown =
862+
assertThrows(
863+
InvalidProtocolBufferException.class,
864+
() -> MapContainer.parseFrom(ByteBuffer.wrap(NESTING_SGROUP_WITH_INITIAL_BYTES)));
865+
866+
assertThat(thrown)
867+
.hasMessageThat()
868+
.contains("the input ended unexpectedly in the middle of a field");
869+
}
870+
871+
@Test
872+
public void testMaliciousSGroupTags_byteBuffer_skipMessage() throws Exception {
873+
CodedInputStream input = InputType.NIO_DIRECT.newDecoder(NESTING_SGROUP);
874+
CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
875+
876+
Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
877+
Throwable thrown2 =
878+
assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
879+
880+
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
881+
assertThat(thrown2)
882+
.hasMessageThat()
883+
.contains("Protocol message had too many levels of nesting");
884+
}
885+
886+
@Test
887+
public void testMaliciousSGroupTags_iterableByteBuffer() throws Exception {
888+
CodedInputStream input = InputType.ITER_DIRECT.newDecoder(NESTING_SGROUP);
889+
CodedOutputStream output = CodedOutputStream.newInstance(new byte[NESTING_SGROUP.length]);
890+
891+
Throwable thrown = assertThrows(InvalidProtocolBufferException.class, input::skipMessage);
892+
Throwable thrown2 =
893+
assertThrows(InvalidProtocolBufferException.class, () -> input.skipMessage(output));
894+
895+
assertThat(thrown).hasMessageThat().contains("Protocol message had too many levels of nesting");
896+
assertThat(thrown2)
897+
.hasMessageThat()
898+
.contains("Protocol message had too many levels of nesting");
899+
}
900+
743901
private void checkSizeLimitExceeded(InvalidProtocolBufferException e) {
744902
assertThat(e)
745903
.hasMessageThat()

0 commit comments

Comments
 (0)