Skip to content

Commit e577bd0

Browse files
Use native TF checkpoints for the BLIP TF tests (#22593)
* Use native TF checkpoints for the TF tests * Remove unneeded exceptions
1 parent 176ceff commit e577bd0

File tree

2 files changed

+9
-21
lines changed

2 files changed

+9
-21
lines changed

tests/models/blip/test_modeling_tf_blip.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,7 @@ def test_save_load_fast_init_to_base(self):
189189
@slow
190190
def test_model_from_pretrained(self):
191191
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
192-
try:
193-
model = TFBlipVisionModel.from_pretrained(model_name)
194-
except OSError:
195-
model = TFBlipVisionModel.from_pretrained(model_name, from_pt=True)
192+
model = TFBlipVisionModel.from_pretrained(model_name)
196193
self.assertIsNotNone(model)
197194

198195

@@ -320,10 +317,7 @@ def test_save_load_fast_init_to_base(self):
320317
@slow
321318
def test_model_from_pretrained(self):
322319
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
323-
try:
324-
model = TFBlipTextModel.from_pretrained(model_name)
325-
except OSError:
326-
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
320+
model = TFBlipTextModel.from_pretrained(model_name)
327321
self.assertIsNotNone(model)
328322

329323
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
@@ -432,7 +426,7 @@ def test_load_vision_text_config(self):
432426
@slow
433427
def test_model_from_pretrained(self):
434428
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
435-
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
429+
model = TFBlipModel.from_pretrained(model_name)
436430
self.assertIsNotNone(model)
437431

438432
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):
@@ -635,7 +629,7 @@ def test_load_vision_text_config(self):
635629
@slow
636630
def test_model_from_pretrained(self):
637631
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
638-
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
632+
model = TFBlipModel.from_pretrained(model_name)
639633
self.assertIsNotNone(model)
640634

641635
@unittest.skip(reason="Tested in individual model tests")
@@ -750,10 +744,7 @@ def test_load_vision_text_config(self):
750744
@slow
751745
def test_model_from_pretrained(self):
752746
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
753-
try:
754-
model = TFBlipModel.from_pretrained(model_name)
755-
except OSError:
756-
model = TFBlipModel.from_pretrained(model_name, from_pt=True)
747+
model = TFBlipModel.from_pretrained(model_name)
757748
self.assertIsNotNone(model)
758749

759750

@@ -769,7 +760,7 @@ def prepare_img():
769760
@slow
770761
class TFBlipModelIntegrationTest(unittest.TestCase):
771762
def test_inference_image_captioning(self):
772-
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", from_pt=True)
763+
model = TFBlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
773764
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
774765
image = prepare_img()
775766

@@ -796,7 +787,7 @@ def test_inference_image_captioning(self):
796787
)
797788

798789
def test_inference_vqa(self):
799-
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", from_pt=True)
790+
model = TFBlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
800791
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
801792

802793
image = prepare_img()
@@ -808,7 +799,7 @@ def test_inference_vqa(self):
808799
self.assertEqual(out[0].numpy().tolist(), [30522, 1015, 102])
809800

810801
def test_inference_itm(self):
811-
model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco", from_pt=True)
802+
model = TFBlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
812803
processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
813804

814805
image = prepare_img()

tests/models/blip/test_modeling_tf_blip_text.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,7 @@ def test_save_load_fast_init_to_base(self):
160160
@slow
161161
def test_model_from_pretrained(self):
162162
for model_name in TF_BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
163-
try:
164-
model = TFBlipTextModel.from_pretrained(model_name)
165-
except OSError:
166-
model = TFBlipTextModel.from_pretrained(model_name, from_pt=True)
163+
model = TFBlipTextModel.from_pretrained(model_name)
167164
self.assertIsNotNone(model)
168165

169166
def test_pt_tf_model_equivalence(self, allow_missing_keys=True):

0 commit comments

Comments
 (0)