Skip to content

Commit 4a197e7

Browse files
authored
Merge pull request #18387 from protocolbuffers/cp-lp-25
Add recursion check when parsing unknown fields in Java.
2 parents e673479 + b5a7cf7 commit 4a197e7

File tree

10 files changed

+469
-95
lines changed

10 files changed

+469
-95
lines changed

java/core/src/main/java/com/google/protobuf/ArrayDecoders.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@
2323
*/
2424
@CheckReturnValue
2525
final class ArrayDecoders {
26+
static final int DEFAULT_RECURSION_LIMIT = 100;
2627

27-
private ArrayDecoders() {
28-
}
28+
@SuppressWarnings("NonFinalStaticField")
29+
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
30+
31+
private ArrayDecoders() {}
2932

3033
/**
3134
* A helper used to return multiple values in a Java function. Java doesn't natively support
@@ -38,6 +41,7 @@ static final class Registers {
3841
public long long1;
3942
public Object object1;
4043
public final ExtensionRegistryLite extensionRegistry;
44+
public int recursionDepth;
4145

4246
Registers() {
4347
this.extensionRegistry = ExtensionRegistryLite.getEmptyRegistry();
@@ -245,7 +249,10 @@ static int mergeMessageField(
245249
if (length < 0 || length > limit - position) {
246250
throw InvalidProtocolBufferException.truncatedMessage();
247251
}
252+
registers.recursionDepth++;
253+
checkRecursionLimit(registers.recursionDepth);
248254
schema.mergeFrom(msg, data, position, position + length, registers);
255+
registers.recursionDepth--;
249256
registers.object1 = msg;
250257
return position + length;
251258
}
@@ -263,8 +270,11 @@ static int mergeGroupField(
263270
// A group field must has a MessageSchema (the only other subclass of Schema is MessageSetSchema
264271
// and it can't be used in group fields).
265272
final MessageSchema messageSchema = (MessageSchema) schema;
273+
registers.recursionDepth++;
274+
checkRecursionLimit(registers.recursionDepth);
266275
final int endPosition =
267276
messageSchema.parseMessage(msg, data, position, limit, endGroup, registers);
277+
registers.recursionDepth--;
268278
registers.object1 = msg;
269279
return endPosition;
270280
}
@@ -1025,6 +1035,8 @@ static int decodeUnknownField(
10251035
final UnknownFieldSetLite child = UnknownFieldSetLite.newInstance();
10261036
final int endGroup = (tag & ~0x7) | WireFormat.WIRETYPE_END_GROUP;
10271037
int lastTag = 0;
1038+
registers.recursionDepth++;
1039+
checkRecursionLimit(registers.recursionDepth);
10281040
while (position < limit) {
10291041
position = decodeVarint32(data, position, registers);
10301042
lastTag = registers.int1;
@@ -1033,6 +1045,7 @@ static int decodeUnknownField(
10331045
}
10341046
position = decodeUnknownField(lastTag, data, position, limit, child, registers);
10351047
}
1048+
registers.recursionDepth--;
10361049
if (position > limit || lastTag != endGroup) {
10371050
throw InvalidProtocolBufferException.parseFailure();
10381051
}
@@ -1079,4 +1092,18 @@ static int skipField(int tag, byte[] data, int position, int limit, Registers re
10791092
throw InvalidProtocolBufferException.invalidTag();
10801093
}
10811094
}
1095+
1096+
/**
1097+
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
1098+
* the depth of the message exceeds this limit.
1099+
*/
1100+
public static void setRecursionLimit(int limit) {
1101+
recursionLimit = limit;
1102+
}
1103+
1104+
private static void checkRecursionLimit(int depth) throws InvalidProtocolBufferException {
1105+
if (depth >= recursionLimit) {
1106+
throw InvalidProtocolBufferException.recursionLimitExceeded();
1107+
}
1108+
}
10821109
}

java/core/src/main/java/com/google/protobuf/CodedInputStream.java

Lines changed: 30 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,41 @@ public abstract boolean skipField(final int tag, final CodedOutputStream output)
223223
* Reads and discards an entire message. This will read either until EOF or until an endgroup tag,
224224
* whichever comes first.
225225
*/
226-
public abstract void skipMessage() throws IOException;
226+
public void skipMessage() throws IOException {
227+
while (true) {
228+
final int tag = readTag();
229+
if (tag == 0) {
230+
return;
231+
}
232+
checkRecursionLimit();
233+
++recursionDepth;
234+
boolean fieldSkipped = skipField(tag);
235+
--recursionDepth;
236+
if (!fieldSkipped) {
237+
return;
238+
}
239+
}
240+
}
227241

228242
/**
229243
* Reads an entire message and writes it to output in wire format. This will read either until EOF
230244
* or until an endgroup tag, whichever comes first.
231245
*/
232-
public abstract void skipMessage(CodedOutputStream output) throws IOException;
246+
public void skipMessage(CodedOutputStream output) throws IOException {
247+
while (true) {
248+
final int tag = readTag();
249+
if (tag == 0) {
250+
return;
251+
}
252+
checkRecursionLimit();
253+
++recursionDepth;
254+
boolean fieldSkipped = skipField(tag, output);
255+
--recursionDepth;
256+
if (!fieldSkipped) {
257+
return;
258+
}
259+
}
260+
}
233261

234262
// -----------------------------------------------------------------
235263

@@ -699,26 +727,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
699727
}
700728
}
701729

702-
@Override
703-
public void skipMessage() throws IOException {
704-
while (true) {
705-
final int tag = readTag();
706-
if (tag == 0 || !skipField(tag)) {
707-
return;
708-
}
709-
}
710-
}
711-
712-
@Override
713-
public void skipMessage(CodedOutputStream output) throws IOException {
714-
while (true) {
715-
final int tag = readTag();
716-
if (tag == 0 || !skipField(tag, output)) {
717-
return;
718-
}
719-
}
720-
}
721-
722730
// -----------------------------------------------------------------
723731

724732
@Override
@@ -1411,26 +1419,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
14111419
}
14121420
}
14131421

1414-
@Override
1415-
public void skipMessage() throws IOException {
1416-
while (true) {
1417-
final int tag = readTag();
1418-
if (tag == 0 || !skipField(tag)) {
1419-
return;
1420-
}
1421-
}
1422-
}
1423-
1424-
@Override
1425-
public void skipMessage(CodedOutputStream output) throws IOException {
1426-
while (true) {
1427-
final int tag = readTag();
1428-
if (tag == 0 || !skipField(tag, output)) {
1429-
return;
1430-
}
1431-
}
1432-
}
1433-
14341422
// -----------------------------------------------------------------
14351423

14361424
@Override
@@ -2176,26 +2164,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
21762164
}
21772165
}
21782166

2179-
@Override
2180-
public void skipMessage() throws IOException {
2181-
while (true) {
2182-
final int tag = readTag();
2183-
if (tag == 0 || !skipField(tag)) {
2184-
return;
2185-
}
2186-
}
2187-
}
2188-
2189-
@Override
2190-
public void skipMessage(CodedOutputStream output) throws IOException {
2191-
while (true) {
2192-
final int tag = readTag();
2193-
if (tag == 0 || !skipField(tag, output)) {
2194-
return;
2195-
}
2196-
}
2197-
}
2198-
21992167
/** Collects the bytes skipped and returns the data in a ByteBuffer. */
22002168
private class SkippedDataSink implements RefillCallback {
22012169
private int lastPos = pos;
@@ -3307,26 +3275,6 @@ public boolean skipField(final int tag, final CodedOutputStream output) throws I
33073275
}
33083276
}
33093277

3310-
@Override
3311-
public void skipMessage() throws IOException {
3312-
while (true) {
3313-
final int tag = readTag();
3314-
if (tag == 0 || !skipField(tag)) {
3315-
return;
3316-
}
3317-
}
3318-
}
3319-
3320-
@Override
3321-
public void skipMessage(CodedOutputStream output) throws IOException {
3322-
while (true) {
3323-
final int tag = readTag();
3324-
if (tag == 0 || !skipField(tag, output)) {
3325-
return;
3326-
}
3327-
}
3328-
}
3329-
33303278
// -----------------------------------------------------------------
33313279

33323280
@Override

java/core/src/main/java/com/google/protobuf/InvalidProtocolBufferException.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ public InvalidWireTypeException(String description) {
132132
static InvalidProtocolBufferException recursionLimitExceeded() {
133133
return new InvalidProtocolBufferException(
134134
"Protocol message had too many levels of nesting. May be malicious. "
135-
+ "Use CodedInputStream.setRecursionLimit() to increase the depth limit.");
135+
+ "Use setRecursionLimit() to increase the recursion depth limit.");
136136
}
137137

138138
static InvalidProtocolBufferException sizeLimitExceeded() {

java/core/src/main/java/com/google/protobuf/MessageSchema.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3006,7 +3006,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
30063006
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
30073007
}
30083008
// Unknown field.
3009-
if (unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3009+
if (unknownFieldSchema.mergeOneFieldFrom(
3010+
unknownFields, reader, /* currentDepth= */ 0)) {
30103011
continue;
30113012
}
30123013
}
@@ -3381,7 +3382,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
33813382
if (unknownFields == null) {
33823383
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
33833384
}
3384-
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3385+
if (!unknownFieldSchema.mergeOneFieldFrom(
3386+
unknownFields, reader, /* currentDepth= */ 0)) {
33853387
return;
33863388
}
33873389
break;
@@ -3397,7 +3399,8 @@ private <UT, UB, ET extends FieldDescriptorLite<ET>> void mergeFromHelper(
33973399
if (unknownFields == null) {
33983400
unknownFields = unknownFieldSchema.getBuilderFromMessage(message);
33993401
}
3400-
if (!unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader)) {
3402+
if (!unknownFieldSchema.mergeOneFieldFrom(
3403+
unknownFields, reader, /* currentDepth= */ 0)) {
34013404
return;
34023405
}
34033406
}

java/core/src/main/java/com/google/protobuf/MessageSetSchema.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ boolean parseMessageSetItemOrUnknownField(
278278
reader, extension, extensionRegistry, extensions);
279279
return true;
280280
} else {
281-
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader);
281+
return unknownFieldSchema.mergeOneFieldFrom(unknownFields, reader, /* currentDepth= */ 0);
282282
}
283283
} else {
284284
return reader.skipField();

java/core/src/main/java/com/google/protobuf/UnknownFieldSchema.java

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
@CheckReturnValue
1414
abstract class UnknownFieldSchema<T, B> {
1515

16+
static final int DEFAULT_RECURSION_LIMIT = 100;
17+
18+
@SuppressWarnings("NonFinalStaticField")
19+
private static volatile int recursionLimit = DEFAULT_RECURSION_LIMIT;
20+
1621
/** Whether unknown fields should be dropped. */
1722
abstract boolean shouldDiscardUnknownFields(Reader reader);
1823

@@ -56,7 +61,8 @@ abstract class UnknownFieldSchema<T, B> {
5661
abstract void makeImmutable(Object message);
5762

5863
/** Merges one field into the unknown fields. */
59-
final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOException {
64+
final boolean mergeOneFieldFrom(B unknownFields, Reader reader, int currentDepth)
65+
throws IOException {
6066
int tag = reader.getTag();
6167
int fieldNumber = WireFormat.getTagFieldNumber(tag);
6268
switch (WireFormat.getTagWireType(tag)) {
@@ -75,7 +81,12 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
7581
case WireFormat.WIRETYPE_START_GROUP:
7682
final B subFields = newBuilder();
7783
int endGroupTag = WireFormat.makeTag(fieldNumber, WireFormat.WIRETYPE_END_GROUP);
78-
mergeFrom(subFields, reader);
84+
currentDepth++;
85+
if (currentDepth >= recursionLimit) {
86+
throw InvalidProtocolBufferException.recursionLimitExceeded();
87+
}
88+
mergeFrom(subFields, reader, currentDepth);
89+
currentDepth--;
7990
if (endGroupTag != reader.getTag()) {
8091
throw InvalidProtocolBufferException.invalidEndTag();
8192
}
@@ -88,10 +99,11 @@ final boolean mergeOneFieldFrom(B unknownFields, Reader reader) throws IOExcepti
8899
}
89100
}
90101

91-
final void mergeFrom(B unknownFields, Reader reader) throws IOException {
102+
private final void mergeFrom(B unknownFields, Reader reader, int currentDepth)
103+
throws IOException {
92104
while (true) {
93105
if (reader.getFieldNumber() == Reader.READ_DONE
94-
|| !mergeOneFieldFrom(unknownFields, reader)) {
106+
|| !mergeOneFieldFrom(unknownFields, reader, currentDepth)) {
95107
break;
96108
}
97109
}
@@ -108,4 +120,12 @@ final void mergeFrom(B unknownFields, Reader reader) throws IOException {
108120
abstract int getSerializedSizeAsMessageSet(T message);
109121

110122
abstract int getSerializedSize(T unknowns);
123+
124+
/**
125+
* Set the maximum recursion limit that ArrayDecoders will allow. An exception will be thrown if
126+
* the depth of the message exceeds this limit.
127+
*/
128+
public void setRecursionLimit(int limit) {
129+
recursionLimit = limit;
130+
}
111131
}

0 commit comments

Comments
 (0)