@@ -533,26 +533,97 @@ def test_metadata(self):
533533 self .assertIn ("keras_version" , metadata )
534534 self .assertIn ("date_saved" , metadata )
535535
536- def test_gfile_copy_called (self ):
537- temp_filepath = Path (
538- os .path .join (self .get_temp_dir (), "my_model.keras" )
536+ def test_save_keras_gfile_copy_called (self ):
537+ path = Path (os .path .join (self .get_temp_dir (), "my_model.keras" ))
538+ model = keras .Sequential (
539+ [
540+ keras .Input (shape = (1 , 1 )),
541+ keras .layers .Dense (4 ),
542+ ]
539543 )
540- model = CompileOverridingModel ()
541544 with mock .patch (
542545 "re.match" , autospec = True
543546 ) as mock_re_match , mock .patch .object (
544547 tf .io .gfile , "copy"
545548 ) as mock_gfile_copy :
546549 # Check regex matching
547550 mock_re_match .return_value = True
548- model .save (temp_filepath , save_format = "keras_v3" )
551+ model .save (path , save_format = "keras_v3" )
549552 mock_re_match .assert_called ()
550- self .assertIn (str (temp_filepath ), mock_re_match .call_args .args )
553+ self .assertIn (str (path ), mock_re_match .call_args .args )
551554
552555 # Check gfile copied with filepath specified as destination
553- self .assertEqual (
554- str (temp_filepath ), str (mock_gfile_copy .call_args .args [1 ])
555- )
556+ mock_gfile_copy .assert_called ()
557+ self .assertEqual (str (path ), str (mock_gfile_copy .call_args .args [1 ]))
558+
559+ def test_save_tf_gfile_copy_not_called (self ):
560+ path = Path (os .path .join (self .get_temp_dir (), "my_model.keras" ))
561+ model = keras .Sequential (
562+ [
563+ keras .Input (shape = (1 , 1 )),
564+ keras .layers .Dense (4 ),
565+ ]
566+ )
567+ with mock .patch (
568+ "re.match" , autospec = True
569+ ) as mock_re_match , mock .patch .object (
570+ tf .io .gfile , "copy"
571+ ) as mock_gfile_copy :
572+ # Check regex matching
573+ mock_re_match .return_value = True
574+ model .save (path , save_format = "tf" )
575+ mock_re_match .assert_called ()
576+ self .assertIn (str (path ), mock_re_match .call_args .args )
577+
578+ # Check gfile.copy was not used.
579+ mock_gfile_copy .assert_not_called ()
580+
581+ def test_save_weights_h5_gfile_copy_called (self ):
582+ path = Path (os .path .join (self .get_temp_dir (), "my_model.weights.h5" ))
583+ model = keras .Sequential (
584+ [
585+ keras .Input (shape = (1 , 1 )),
586+ keras .layers .Dense (4 ),
587+ ]
588+ )
589+ model (tf .constant ([[1.0 ]]))
590+ with mock .patch (
591+ "re.match" , autospec = True
592+ ) as mock_re_match , mock .patch .object (
593+ tf .io .gfile , "copy"
594+ ) as mock_gfile_copy :
595+ # Check regex matching
596+ mock_re_match .return_value = True
597+ model .save_weights (path )
598+ mock_re_match .assert_called ()
599+ self .assertIn (str (path ), mock_re_match .call_args .args )
600+
601+ # Check gfile copied with filepath specified as destination
602+ mock_gfile_copy .assert_called ()
603+ self .assertEqual (str (path ), str (mock_gfile_copy .call_args .args [1 ]))
604+
605+ def test_save_weights_tf_gfile_copy_not_called (self ):
606+ path = Path (os .path .join (self .get_temp_dir (), "my_model.ckpt" ))
607+ model = keras .Sequential (
608+ [
609+ keras .Input (shape = (1 , 1 )),
610+ keras .layers .Dense (4 ),
611+ ]
612+ )
613+ model (tf .constant ([[1.0 ]]))
614+ with mock .patch (
615+ "re.match" , autospec = True
616+ ) as mock_re_match , mock .patch .object (
617+ tf .io .gfile , "copy"
618+ ) as mock_gfile_copy :
619+ # Check regex matching
620+ mock_re_match .return_value = True
621+ model .save_weights (path )
622+ mock_re_match .assert_called ()
623+ self .assertIn (str (path ), mock_re_match .call_args .args )
624+
625+ # Check gfile.copy was not used.
626+ mock_gfile_copy .assert_not_called ()
556627
557628 def test_load_model_api_endpoint (self ):
558629 temp_filepath = Path (os .path .join (self .get_temp_dir (), "mymodel.keras" ))
0 commit comments