1616import re
1717import os
1818import shutil
19- import unittest
19+ import random
20+ import tempfile
2021import pytest
2122
2223
3435
3536
3637NAME_ONLY_ARGS = {
37- 'display_name' : 'TestModel123'
38+ 'display_name' : 'TestModel123_{0}' .format (random .randint (1111 , 9999 ))
39+ }
40+ NAME_ONLY_ARGS_UPDATED = {
41+ 'display_name' : 'TestModel123_updated_{0}' .format (random .randint (1111 , 9999 ))
3842}
3943NAME_AND_TAGS_ARGS = {
40- 'display_name' : 'TestModel123_tags' ,
44+ 'display_name' : 'TestModel123_tags_{0}' . format ( random . randint ( 1111 , 9999 )) ,
4145 'tags' : ['test_tag123' ]
42- }
46+ }
4347FULL_MODEL_ARGS = {
44- 'display_name' : 'TestModel123_full' ,
48+ 'display_name' : 'TestModel123_full_{0}' . format ( random . randint ( 1111 , 9999 )) ,
4549 'tags' : ['test_tag567' ],
4650 'file_name' : 'model1.tflite'
47- }
51+ }
4852INVALID_FULL_MODEL_ARGS = {
49- 'display_name' : 'TestModel123_invalid_full' ,
53+ 'display_name' : 'TestModel123_invalid_full_{0}' . format ( random . randint ( 1111 , 9999 )) ,
5054 'tags' : ['test_tag890' ],
5155 'file_name' : 'invalid_model.tflite'
52- }
56+ }
5357
5458@pytest .fixture
5559def firebase_model (request ):
5660 args = request .param
5761 tflite_format = None
58- if args .get ('file_name' ):
59- file_path = testutils .resource_filename (args .get ('file_name' ))
62+ file_name = args .get ('file_name' )
63+ if file_name :
64+ file_path = testutils .resource_filename (file_name )
6065 source = ml .TFLiteGCSModelSource .from_tflite_model_file (file_path )
6166 tflite_format = ml .TFLiteFormat (model_source = source )
6267
@@ -109,35 +114,44 @@ def check_operation_error(excinfo, msg):
109114 assert str (err ) == msg
110115
111116
117+ def check_model (model , args ):
118+ assert model .display_name == args .get ('display_name' )
119+ assert model .tags == args .get ('tags' )
120+ assert model .model_id is not None
121+ assert model .create_time is not None
122+ assert model .update_time is not None
123+ assert model .locked is False
124+ assert model .etag is not None
125+
126+
127+ def check_model_format (model , has_model_format , validation_error ):
128+ if has_model_format :
129+ assert model .validation_error == validation_error
130+ assert model .published is False
131+ assert model .model_format .model_source .gcs_tflite_uri .startswith ('gs://' )
132+ if validation_error :
133+ assert model .model_format .size_bytes is None
134+ assert model .model_hash is None
135+ else :
136+ assert model .model_format .size_bytes is not None
137+ assert model .model_hash is not None
138+ else :
139+ assert model .model_format is None
140+ assert model .validation_error == 'No model file has been uploaded.'
141+ assert model .published is False
142+ assert model .model_hash is None
143+
144+
112145@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
113146def test_create_simple_model (firebase_model ):
114- assert firebase_model .display_name == NAME_AND_TAGS_ARGS .get ('display_name' )
115- assert firebase_model .tags == NAME_AND_TAGS_ARGS .get ('tags' )
116- assert firebase_model .model_id is not None
117- assert firebase_model .create_time is not None
118- assert firebase_model .update_time is not None
119- assert firebase_model .validation_error == 'No model file has been uploaded.'
120- assert firebase_model .locked is False
121- assert firebase_model .published is False
122- assert firebase_model .etag is not None
123- assert firebase_model .model_hash is None
124- assert firebase_model .model_format is None
147+ check_model (firebase_model , NAME_AND_TAGS_ARGS )
148+ check_model_format (firebase_model , False , None )
125149
126150
127151@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
128152def test_create_full_model (firebase_model ):
129- assert firebase_model .display_name == FULL_MODEL_ARGS .get ('display_name' )
130- assert firebase_model .tags == FULL_MODEL_ARGS .get ('tags' )
131- assert firebase_model .model_format .size_bytes is not None
132- assert firebase_model .model_format .model_source .gcs_tflite_uri is not None
133- assert firebase_model .model_id is not None
134- assert firebase_model .create_time is not None
135- assert firebase_model .update_time is not None
136- assert firebase_model .validation_error is None
137- assert firebase_model .locked is False
138- assert firebase_model .published is False
139- assert firebase_model .etag is not None
140- assert firebase_model .model_hash is not None
153+ check_model (firebase_model , FULL_MODEL_ARGS )
154+ check_model_format (firebase_model , True , None )
141155
142156
143157@pytest .mark .parametrize ('firebase_model' , [FULL_MODEL_ARGS ], indirect = True )
@@ -151,33 +165,15 @@ def test_create_already_existing_fails(firebase_model):
151165
152166@pytest .mark .parametrize ('firebase_model' , [INVALID_FULL_MODEL_ARGS ], indirect = True )
153167def test_create_invalid_model (firebase_model ):
154- assert firebase_model .display_name == INVALID_FULL_MODEL_ARGS .get ('display_name' )
155- assert firebase_model .tags == INVALID_FULL_MODEL_ARGS .get ('tags' )
156- assert firebase_model .model_format .size_bytes is None
157- assert firebase_model .model_format .model_source .gcs_tflite_uri is not None
158- assert firebase_model .model_id is not None
159- assert firebase_model .create_time is not None
160- assert firebase_model .update_time is not None
161- assert firebase_model .validation_error == 'Invalid flatbuffer format'
162- assert firebase_model .locked is False
163- assert firebase_model .published is False
164- assert firebase_model .etag is not None
165- assert firebase_model .model_hash is None
168+ check_model (firebase_model , INVALID_FULL_MODEL_ARGS )
169+ check_model_format (firebase_model , True , 'Invalid flatbuffer format' )
166170
167171
168172@pytest .mark .parametrize ('firebase_model' , [NAME_AND_TAGS_ARGS ], indirect = True )
169173def test_get_model (firebase_model ):
170174 get_model = ml .get_model (firebase_model .model_id )
171- assert get_model .display_name == firebase_model .display_name
172- assert get_model .tags == firebase_model .tags
173- assert get_model .model_id is not None
174- assert get_model .create_time is not None
175- assert get_model .update_time is not None
176- assert get_model .validation_error == 'No model file has been uploaded.'
177- assert get_model .etag is not None
178- assert get_model .locked is False
179- assert get_model .published is False
180- assert get_model .model_hash is None
175+ check_model (get_model , NAME_AND_TAGS_ARGS )
176+ check_model_format (get_model , False , None )
181177
182178
183179@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -192,29 +188,16 @@ def test_get_non_existing_model(firebase_model):
192188
193189@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
194190def test_update_model (firebase_model ):
195- new_model_name = 'TestModel123_updated'
191+ new_model_name = NAME_ONLY_ARGS_UPDATED . get ( 'display_name' )
196192 firebase_model .display_name = new_model_name
197-
198193 updated_model = ml .update_model (firebase_model )
199- assert updated_model .display_name == new_model_name
200- assert updated_model .model_id == firebase_model .model_id
201- assert updated_model .create_time == firebase_model .create_time
202- assert updated_model .update_time != firebase_model .update_time
203- assert updated_model .validation_error == firebase_model .validation_error
204- assert updated_model .etag != firebase_model .etag
205- assert updated_model .published == firebase_model .published
206- assert updated_model .locked == firebase_model .locked
194+ check_model (updated_model , NAME_ONLY_ARGS_UPDATED )
195+ check_model_format (updated_model , False , None )
207196
208197 # Second call with same model does not cause error
209198 updated_model2 = ml .update_model (updated_model )
210- assert updated_model2 .display_name == updated_model .display_name
211- assert updated_model2 .model_id == updated_model .model_id
212- assert updated_model2 .create_time == updated_model .create_time
213- assert updated_model2 .update_time != updated_model .update_time
214- assert updated_model2 .validation_error == updated_model .validation_error
215- assert updated_model2 .etag != updated_model .etag
216- assert updated_model2 .published == updated_model .published
217- assert updated_model2 .locked == updated_model .locked
199+ check_model (updated_model2 , NAME_ONLY_ARGS_UPDATED )
200+ check_model_format (updated_model2 , False , None )
218201
219202
220203@pytest .mark .parametrize ('firebase_model' , [NAME_ONLY_ARGS ], indirect = True )
@@ -272,11 +255,10 @@ def test_list_models(model_list):
272255 filter_str = 'displayName={0} OR tags:{1}' .format (
273256 model_list [0 ].display_name , model_list [1 ].tags [0 ])
274257
275- models_list = ml .list_models (list_filter = filter_str )
276- assert len (models_list .models ) == 2
277- for mdl in models_list .models :
278- assert mdl == model_list [0 ] or mdl == model_list [1 ]
279- assert models_list .models [0 ] != models_list .models [1 ]
258+ all_models = ml .list_models (list_filter = filter_str )
259+ all_model_ids = [mdl .model_id for mdl in all_models .iterate_all ()]
260+ for mdl in model_list :
261+ assert mdl .model_id in all_model_ids
280262
281263
282264def test_list_models_invalid_filter ():
@@ -302,12 +284,9 @@ def test_delete_model(firebase_model):
302284#'pip install tensorflow==2.0.0b' for version 2 etc.
303285
304286
305- SAVED_MODEL_DIR = '/tmp/saved_model/1'
306-
307-
308- def _clean_up_tmp_directory ():
309- if os .path .exists (SAVED_MODEL_DIR ):
310- shutil .rmtree (SAVED_MODEL_DIR )
287+ def _clean_up_directory (save_dir ):
288+ if save_dir .startswith (tempfile .gettempdir ()) and os .path .exists (save_dir ):
289+ shutil .rmtree (save_dir )
311290
312291
313292@pytest .fixture
@@ -327,8 +306,8 @@ def saved_model_dir(keras_model):
327306 assert _TF_ENABLED
328307 # different versions have different model conversion capability
329308 # pick something that works for each version
330- save_dir = SAVED_MODEL_DIR
331- _clean_up_tmp_directory () # previous failures may leave files
309+ parent = tempfile . mkdtemp ()
310+ save_dir = os . path . join ( parent , 'child' )
332311 if tf .version .VERSION .startswith ('1.' ):
333312 tf .reset_default_graph ()
334313 x_var = tf .placeholder (tf .float32 , (None , 3 ), name = "x" )
@@ -340,28 +319,29 @@ def saved_model_dir(keras_model):
340319 assert tf .version .VERSION .startswith ('2.' )
341320 tf .saved_model .save (keras_model , save_dir )
342321 yield save_dir
343- _clean_up_tmp_directory ( )
322+ _clean_up_directory ( parent )
344323
345324
346- @unittest . skipUnless ( _TF_ENABLED , 'Tensor flow is required for this test.' )
325+ @pytest . mark . skipif ( not _TF_ENABLED , reason = 'Tensor flow is required for this test.' )
347326def test_from_keras_model (keras_model ):
348327 source = ml .TFLiteGCSModelSource .from_keras_model (keras_model , 'model2.tflite' )
349328 assert re .search (
350329 '^gs://.*/Firebase/ML/Models/model2.tflite$' ,
351330 source .gcs_tflite_uri ) is not None
352331
353332 # Validate the conversion by creating a model
333+ model_format = ml .TFLiteFormat (model_source = source )
334+ model = ml .Model (display_name = "KerasModel1" , model_format = model_format )
335+ created_model = ml .create_model (model )
336+
354337 try :
355- model_format = ml .TFLiteFormat (model_source = source )
356- model = ml .Model (display_name = "KerasModel1" , model_format = model_format )
357- created_model = ml .create_model (model )
358338 assert created_model .model_id is not None
359339 assert created_model .validation_error is None
360340 finally :
361341 _clean_up_model (created_model )
362342
363343
364- @unittest . skipUnless ( _TF_ENABLED , 'Tensor flow is required for this test.' )
344+ @pytest . mark . skipif ( not _TF_ENABLED , reason = 'Tensor flow is required for this test.' )
365345def test_from_saved_model (saved_model_dir ):
366346 # Test the conversion helper
367347 source = ml .TFLiteGCSModelSource .from_saved_model (saved_model_dir , 'model3.tflite' )
@@ -370,10 +350,11 @@ def test_from_saved_model(saved_model_dir):
370350 source .gcs_tflite_uri ) is not None
371351
372352 # Validate the conversion by creating a model
353+ model_format = ml .TFLiteFormat (model_source = source )
354+ model = ml .Model (display_name = "SavedModel1" , model_format = model_format )
355+ created_model = ml .create_model (model )
356+
373357 try :
374- model_format = ml .TFLiteFormat (model_source = source )
375- model = ml .Model (display_name = "SavedModel1" , model_format = model_format )
376- created_model = ml .create_model (model )
377358 assert created_model .model_id is not None
378359 assert created_model .validation_error is None
379360 finally :
0 commit comments