11
11
import static com .google .common .truth .Truth .assertWithMessage ;
12
12
import static org .junit .Assert .assertArrayEquals ;
13
13
import static org .junit .Assert .assertThrows ;
14
+
15
+ import com .google .common .primitives .Bytes ;
16
+ import map_test .MapTestProto .MapContainer ;
14
17
import protobuf_unittest .UnittestProto .BoolMessage ;
15
18
import protobuf_unittest .UnittestProto .Int32Message ;
16
19
import protobuf_unittest .UnittestProto .Int64Message ;
@@ -35,6 +38,13 @@ public class CodedInputStreamTest {
35
38
36
39
private static final int DEFAULT_BLOCK_SIZE = 4096 ;
37
40
41
+ private static final int GROUP_TAP = WireFormat .makeTag (3 , WireFormat .WIRETYPE_START_GROUP );
42
+
43
+ private static final byte [] NESTING_SGROUP = generateSGroupTags ();
44
+
45
+ private static final byte [] NESTING_SGROUP_WITH_INITIAL_BYTES = generateSGroupTagsForMapField ();
46
+
47
+
38
48
private enum InputType {
39
49
ARRAY {
40
50
@ Override
@@ -117,6 +127,17 @@ private byte[] bytes(int... bytesAsInts) {
117
127
return bytes ;
118
128
}
119
129
130
+ private static byte [] generateSGroupTags () {
131
+ byte [] bytes = new byte [100000 ];
132
+ Arrays .fill (bytes , (byte ) GROUP_TAP );
133
+ return bytes ;
134
+ }
135
+
136
+ private static byte [] generateSGroupTagsForMapField () {
137
+ byte [] initialBytes = {18 , 1 , 75 , 26 , (byte ) 198 , (byte ) 154 , 12 };
138
+ return Bytes .concat (initialBytes , NESTING_SGROUP );
139
+ }
140
+
120
141
/**
121
142
* An InputStream which limits the number of bytes it reads at a time. We use this to make sure
122
143
* that CodedInputStream doesn't screw up when reading in small blocks.
@@ -740,6 +761,143 @@ public void testMaliciousRecursion() throws Exception {
740
761
}
741
762
}
742
763
764
+ @ Test
765
+ public void testMaliciousRecursion_unknownFields () throws Exception {
766
+ Throwable thrown =
767
+ assertThrows (
768
+ InvalidProtocolBufferException .class ,
769
+ () -> TestRecursiveMessage .parseFrom (NESTING_SGROUP ));
770
+
771
+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
772
+ }
773
+
774
+ @ Test
775
+ public void testMaliciousRecursion_skippingUnknownField () throws Exception {
776
+ Throwable thrown =
777
+ assertThrows (
778
+ InvalidProtocolBufferException .class ,
779
+ () ->
780
+ DiscardUnknownFieldsParser .wrap (TestRecursiveMessage .parser ())
781
+ .parseFrom (NESTING_SGROUP ));
782
+
783
+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
784
+ }
785
+
786
+ @ Test
787
+ public void testMaliciousSGroupTagsWithMapField_fromInputStream () throws Exception {
788
+ Throwable parseFromThrown =
789
+ assertThrows (
790
+ InvalidProtocolBufferException .class ,
791
+ () ->
792
+ MapContainer .parseFrom (
793
+ new ByteArrayInputStream (NESTING_SGROUP_WITH_INITIAL_BYTES )));
794
+ Throwable mergeFromThrown =
795
+ assertThrows (
796
+ InvalidProtocolBufferException .class ,
797
+ () ->
798
+ MapContainer .newBuilder ()
799
+ .mergeFrom (new ByteArrayInputStream (NESTING_SGROUP_WITH_INITIAL_BYTES )));
800
+
801
+ assertThat (parseFromThrown )
802
+ .hasMessageThat ()
803
+ .contains ("Protocol message had too many levels of nesting" );
804
+ assertThat (mergeFromThrown )
805
+ .hasMessageThat ()
806
+ .contains ("Protocol message had too many levels of nesting" );
807
+ }
808
+
809
+ @ Test
810
+ public void testMaliciousSGroupTags_inputStream_skipMessage () throws Exception {
811
+ ByteArrayInputStream inputSteam = new ByteArrayInputStream (NESTING_SGROUP );
812
+ CodedInputStream input = CodedInputStream .newInstance (inputSteam );
813
+ CodedOutputStream output = CodedOutputStream .newInstance (new byte [NESTING_SGROUP .length ]);
814
+
815
+ Throwable thrown = assertThrows (InvalidProtocolBufferException .class , input ::skipMessage );
816
+ Throwable thrown2 =
817
+ assertThrows (InvalidProtocolBufferException .class , () -> input .skipMessage (output ));
818
+
819
+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
820
+ assertThat (thrown2 )
821
+ .hasMessageThat ()
822
+ .contains ("Protocol message had too many levels of nesting" );
823
+ }
824
+
825
+ @ Test
826
+ public void testMaliciousSGroupTagsWithMapField_fromByteArray () throws Exception {
827
+ Throwable parseFromThrown =
828
+ assertThrows (
829
+ InvalidProtocolBufferException .class ,
830
+ () -> MapContainer .parseFrom (NESTING_SGROUP_WITH_INITIAL_BYTES ));
831
+ Throwable mergeFromThrown =
832
+ assertThrows (
833
+ InvalidProtocolBufferException .class ,
834
+ () -> MapContainer .newBuilder ().mergeFrom (NESTING_SGROUP_WITH_INITIAL_BYTES ));
835
+
836
+ assertThat (parseFromThrown )
837
+ .hasMessageThat ()
838
+ .contains ("the input ended unexpectedly in the middle of a field" );
839
+ assertThat (mergeFromThrown )
840
+ .hasMessageThat ()
841
+ .contains ("the input ended unexpectedly in the middle of a field" );
842
+ }
843
+
844
+ @ Test
845
+ public void testMaliciousSGroupTags_arrayDecoder_skipMessage () throws Exception {
846
+ CodedInputStream input = CodedInputStream .newInstance (NESTING_SGROUP );
847
+ CodedOutputStream output = CodedOutputStream .newInstance (new byte [NESTING_SGROUP .length ]);
848
+
849
+ Throwable thrown = assertThrows (InvalidProtocolBufferException .class , input ::skipMessage );
850
+ Throwable thrown2 =
851
+ assertThrows (InvalidProtocolBufferException .class , () -> input .skipMessage (output ));
852
+
853
+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
854
+ assertThat (thrown2 )
855
+ .hasMessageThat ()
856
+ .contains ("Protocol message had too many levels of nesting" );
857
+ }
858
+
859
+ @ Test
860
+ public void testMaliciousSGroupTagsWithMapField_fromByteBuffer () throws Exception {
861
+ Throwable thrown =
862
+ assertThrows (
863
+ InvalidProtocolBufferException .class ,
864
+ () -> MapContainer .parseFrom (ByteBuffer .wrap (NESTING_SGROUP_WITH_INITIAL_BYTES )));
865
+
866
+ assertThat (thrown )
867
+ .hasMessageThat ()
868
+ .contains ("the input ended unexpectedly in the middle of a field" );
869
+ }
870
+
871
+ @ Test
872
+ public void testMaliciousSGroupTags_byteBuffer_skipMessage () throws Exception {
873
+ CodedInputStream input = InputType .NIO_DIRECT .newDecoder (NESTING_SGROUP );
874
+ CodedOutputStream output = CodedOutputStream .newInstance (new byte [NESTING_SGROUP .length ]);
875
+
876
+ Throwable thrown = assertThrows (InvalidProtocolBufferException .class , input ::skipMessage );
877
+ Throwable thrown2 =
878
+ assertThrows (InvalidProtocolBufferException .class , () -> input .skipMessage (output ));
879
+
880
+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
881
+ assertThat (thrown2 )
882
+ .hasMessageThat ()
883
+ .contains ("Protocol message had too many levels of nesting" );
884
+ }
885
+
886
+ @ Test
887
+ public void testMaliciousSGroupTags_iterableByteBuffer () throws Exception {
888
+ CodedInputStream input = InputType .ITER_DIRECT .newDecoder (NESTING_SGROUP );
889
+ CodedOutputStream output = CodedOutputStream .newInstance (new byte [NESTING_SGROUP .length ]);
890
+
891
+ Throwable thrown = assertThrows (InvalidProtocolBufferException .class , input ::skipMessage );
892
+ Throwable thrown2 =
893
+ assertThrows (InvalidProtocolBufferException .class , () -> input .skipMessage (output ));
894
+
895
+ assertThat (thrown ).hasMessageThat ().contains ("Protocol message had too many levels of nesting" );
896
+ assertThat (thrown2 )
897
+ .hasMessageThat ()
898
+ .contains ("Protocol message had too many levels of nesting" );
899
+ }
900
+
743
901
private void checkSizeLimitExceeded (InvalidProtocolBufferException e ) {
744
902
assertThat (e )
745
903
.hasMessageThat ()
0 commit comments