@@ -260,66 +260,85 @@ def test_get_detect_run_exception(self, mock_validate):
@patch("skyflow.vault.controller._detect.base64")
@patch("skyflow.vault.controller._detect.os.path.basename")
@patch("skyflow.vault.controller._detect.open", create=True)
@patch.object(Detect, "_Detect__poll_for_processed_file")
def test_deidentify_file_all_branches(self, mock_poll, mock_open, mock_basename, mock_base64, mock_validate):
    """Check that ``deidentify_file`` dispatches each file type to its API method.

    Polling (``__poll_for_processed_file``) and response parsing
    (``__parse_deidentify_file_response``) are both replaced with canned
    results, so the returned response is the same "pdf" payload for every
    case. Because of that, the result assertions alone cannot tell the
    branches apart; each sub-test therefore also asserts that the API method
    named in the test case was the one invoked.
    """
    file_content = b"test content"
    mock_base64.b64encode.return_value = b"dGVzdCBjb250ZW50"

    # Canned polling result shared by every branch.
    processed_response = Mock()
    processed_response.status = "SUCCESS"
    processed_response.output = [
        {"processedFile": "dGVzdCBjb250ZW50", "processedFileType": "pdf", "processedFileExtension": "pdf"}
    ]
    processed_response.word_character_count = Mock(word_count=1, character_count=1)
    processed_response.size = 1
    processed_response.duration = 1
    processed_response.pages = 1
    processed_response.slides = 1
    processed_response.message = ""
    processed_response.run_id = "runid123"
    mock_poll.return_value = processed_response

    # Canned parse result; its values are what the result assertions below see.
    parsed_response = DeidentifyFileResponse(
        file="dGVzdCBjb250ZW50", type="pdf", extension="pdf",
        word_count=1, char_count=1, size_in_kb=1,
        duration_in_seconds=1, page_count=1, slide_count=1,
        entities=[], run_id="runid123", status="SUCCESS", errors=[]
    )

    # (file name, extension used as the sub-test label, expected API method)
    test_cases = [
        ("test.pdf", "pdf", "deidentify_pdf"),
        ("test.jpg", "jpg", "deidentify_image"),
        ("test.pptx", "pptx", "deidentify_presentation"),
        ("test.csv", "csv", "deidentify_spreadsheet"),
        ("test.docx", "docx", "deidentify_document"),
        ("test.json", "json", "deidentify_structured_text"),
        ("test.xml", "xml", "deidentify_structured_text"),
        ("test.unknown", "unknown", "deidentify_file"),
    ]

    with patch.object(self.detect, "_Detect__parse_deidentify_file_response",
                      return_value=parsed_response):
        for file_name, extension, api_method in test_cases:
            with self.subTest(file_type=extension):
                # File-like object handed to the request.
                file_obj = Mock()
                file_obj.read.return_value = file_content
                file_obj.name = file_name
                mock_basename.return_value = file_name

                req = DeidentifyFileRequest(file=file_obj)
                req.entities = []
                req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[])
                req.allow_regex_list = []
                req.restrict_regex_list = []
                req.transformations = None
                req.output_directory = "/tmp"

                # Fresh API mock per case so call counts never leak between
                # sub-tests (json and xml share the same method name).
                files_api = Mock()
                files_api.with_raw_response = files_api
                api_method_mock = Mock()
                setattr(files_api, api_method, api_method_mock)
                self.vault_client.get_detect_file_api.return_value = files_api

                api_response = Mock()
                api_response.data = Mock(run_id="runid123")
                api_method_mock.return_value = api_response

                result = self.detect.deidentify_file(req)

                # This is the only check that distinguishes the branches:
                # the expected endpoint must have been hit exactly once.
                api_method_mock.assert_called_once()

                self.assertIsInstance(result, DeidentifyFileResponse)
                self.assertEqual(result.status, "SUCCESS")
                self.assertEqual(result.file, "dGVzdCBjb250ZW50")
                self.assertEqual(result.type, "pdf")
                self.assertEqual(result.extension, "pdf")
323342 @patch ("skyflow.vault.controller._detect.validate_deidentify_file_request" )
324343 @patch ("skyflow.vault.controller._detect.base64" )
325344 def test_deidentify_file_exception (self , mock_base64 , mock_validate ):
@@ -339,38 +358,47 @@ def test_deidentify_file_exception(self, mock_base64, mock_validate):
@patch("skyflow.vault.controller._detect.time.sleep", return_value=None)
def test_poll_for_processed_file_success(self, mock_sleep):
    """Polling returns the SUCCESS payload after one IN_PROGRESS reply.

    ``get_run`` is stubbed so that its first reply carries an IN_PROGRESS
    status and every later reply carries SUCCESS; each reply exposes the
    status object through its ``data`` attribute.
    """
    files_api = Mock()
    files_api.with_raw_response = files_api
    self.vault_client.get_detect_file_api.return_value = files_api

    # Replies handed out before the stub switches over to SUCCESS.
    queued_replies = [Mock(data=Mock(status="IN_PROGRESS", message=""))]

    def fake_get_run(*_args, **_kwargs):
        if queued_replies:
            return queued_replies.pop()
        return Mock(data=Mock(status="SUCCESS"))

    files_api.get_run.side_effect = fake_get_run

    # max_wait_time is above 1 so the loop can get past the first
    # IN_PROGRESS reply and reach the SUCCESS one.
    result = self.detect._Detect__poll_for_processed_file("runid123", max_wait_time=2)
    self.assertEqual(result.status, "SUCCESS")
352383
@patch("skyflow.vault.controller._detect.time.sleep", return_value=None)
def test_poll_for_processed_file_failed(self, mock_sleep):
    """Polling hands back the FAILED payload together with its message.

    ``get_run`` is stubbed to answer every call with a reply whose ``data``
    has status FAILED and message "fail"; the test checks that both fields
    are present on the value returned by the poller.
    """
    files_api = Mock()
    files_api.with_raw_response = files_api
    self.vault_client.get_detect_file_api.return_value = files_api

    # Every call, including the first, reports FAILED.
    def fake_get_run(*_args, **_kwargs):
        return Mock(data=Mock(status="FAILED", message="fail"))

    files_api.get_run.side_effect = fake_get_run

    result = self.detect._Detect__poll_for_processed_file("runid123", max_wait_time=1)
    self.assertEqual(result.status, "FAILED")
    self.assertEqual(result.message, "fail")
374402
375403 def test_parse_deidentify_file_response_dict_and_obj (self ):
376404 # Dict input
@@ -413,6 +441,7 @@ class DummyData:
413441 obj_data = DummyData ()
414442 result = self .detect ._Detect__parse_deidentify_file_response (obj_data , "runid" , "SUCCESS" )
415443 self .assertIsInstance (result , DeidentifyFileResponse )
444+
416445 def test_get_token_format_missing_attribute (self ):
417446 """Test __get_token_format when token_format attribute is missing"""
418447 class DummyRequest :
0 commit comments