@@ -70,7 +70,7 @@ def __init__(
7070 type_vocab_size = 16 ,
7171 type_sequence_label_size = 2 ,
7272 initializer_range = 0.02 ,
73- image_feature_pool_shape = [7 , 7 , 256 ],
73+ image_feature_pool_shape = [7 , 7 , 32 ],
7474 coordinate_size = 6 ,
7575 shape_size = 6 ,
7676 num_labels = 3 ,
@@ -106,6 +106,14 @@ def __init__(
106106 self .num_choices = num_choices
107107 self .scope = scope
108108 self .range_bbox = range_bbox
109+ detectron2_config = LayoutLMv2Config .get_default_detectron2_config ()
110+ # We need to make the model smaller
111+ detectron2_config ["MODEL.RESNETS.DEPTH" ] = 50
112+ detectron2_config ["MODEL.RESNETS.RES2_OUT_CHANNELS" ] = 4
113+ detectron2_config ["MODEL.RESNETS.STEM_OUT_CHANNELS" ] = 4
114+ detectron2_config ["MODEL.FPN.OUT_CHANNELS" ] = 32
115+ detectron2_config ["MODEL.RESNETS.NUM_GROUPS" ] = 1
116+ self .detectron2_config = detectron2_config
109117
110118 def prepare_config_and_inputs (self ):
111119 input_ids = ids_tensor ([self .batch_size , self .seq_length ], self .vocab_size )
@@ -158,13 +166,9 @@ def prepare_config_and_inputs(self):
158166 image_feature_pool_shape = self .image_feature_pool_shape ,
159167 coordinate_size = self .coordinate_size ,
160168 shape_size = self .shape_size ,
169+ detectron2_config_args = self .detectron2_config ,
161170 )
162171
163- # use smaller resnet backbone to make tests faster
164- config .detectron2_config_args ["MODEL.RESNETS.DEPTH" ] = 18
165- config .detectron2_config_args ["MODEL.RESNETS.RES2_OUT_CHANNELS" ] = 64
166- config .detectron2_config_args ["MODEL.RESNETS.NUM_GROUPS" ] = 1
167-
168172 return config , input_ids , bbox , image , token_type_ids , input_mask , sequence_labels , token_labels
169173
170174 def create_and_check_model (
@@ -422,10 +426,6 @@ def check_hidden_states_output(inputs_dict, config, model_class):
422426
423427 check_hidden_states_output (inputs_dict , config , model_class )
424428
425- @unittest .skip (reason = "We cannot configure detectron2 to output a smaller backbone" )
426- def test_model_is_small (self ):
427- pass
428-
429429 @slow
430430 def test_model_from_pretrained (self ):
431431 model_name = "microsoft/layoutlmv2-base-uncased"
0 commit comments