Skip to content

Commit c9bc5b3

Browse files
SK-2043: update detect tests
1 parent b2910a5 commit c9bc5b3

File tree

2 files changed

+113
-88
lines changed

2 files changed

+113
-88
lines changed

skyflow/vault/controller/_detect.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ def _get_file_extension(self, filename: str):
6161
return filename.split('.')[-1].lower() if '.' in filename else ''
6262

6363
def __poll_for_processed_file(self, run_id, max_wait_time=64):
64-
files_api = self.__vault_client.get_detect_file_api()
64+
files_api = self.__vault_client.get_detect_file_api().with_raw_response
6565
current_wait_time = 1 # Start with 1 second
6666
try:
6767
while True:
68-
response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers())
68+
response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data
6969
status = response.status
7070
if status == 'IN_PROGRESS':
7171
if current_wait_time >= max_wait_time:
@@ -79,12 +79,8 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64):
7979
wait_time = next_wait_time
8080
current_wait_time = next_wait_time
8181
time.sleep(wait_time)
82-
elif status == 'SUCCESS':
82+
elif status == 'SUCCESS' or status == 'FAILED':
8383
return response
84-
elif status == 'FAILED':
85-
raise SkyflowError(SkyflowMessages.Error.INTERNAL_SERVER_ERROR.value.format(response.message), 500)
86-
else:
87-
raise SkyflowError(SkyflowMessages.Error.GET_DETECT_RUN_FAILED.value, 500)
8884
except Exception as e:
8985
raise e
9086

tests/vault/controller/test__detect.py

Lines changed: 110 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -260,66 +260,85 @@ def test_get_detect_run_exception(self, mock_validate):
260260
@patch("skyflow.vault.controller._detect.base64")
261261
@patch("skyflow.vault.controller._detect.os.path.basename")
262262
@patch("skyflow.vault.controller._detect.open", create=True)
263-
@patch("skyflow.vault.controller._detect.time.sleep", return_value=None)
264-
def test_deidentify_file_all_branches(self, mock_sleep, mock_open, mock_basename, mock_base64, mock_validate):
265-
# Helper to run a branch
266-
def run_branch(file_name, file_extension, api_call_attr, req_file_class):
267-
file_content = b"test content"
268-
file_obj = Mock()
269-
file_obj.read.return_value = file_content
270-
file_obj.name = file_name
271-
mock_basename.return_value = os.path.basename(file_name)
272-
mock_base64.b64encode.return_value = b"dGVzdCBjb250ZW50"
273-
req = DeidentifyFileRequest(file=file_obj)
274-
req.entities = []
275-
req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[])
276-
req.allow_regex_list = []
277-
req.restrict_regex_list = []
278-
req.transformations = None
279-
req.output_directory = "/tmp"
280-
files_api = Mock()
281-
files_api.with_raw_response = files_api
282-
setattr(files_api, api_call_attr, Mock())
283-
self.vault_client.get_detect_file_api.return_value = files_api
284-
api_response = Mock()
285-
api_response.data = Mock(run_id="runid123")
286-
getattr(files_api, api_call_attr).return_value = api_response
287-
288-
# Patch get_run for polling
289-
poll_response = Mock()
290-
poll_response.status = "SUCCESS"
291-
poll_response.output = [
292-
{"processedFile": "dGVzdCBjb250ZW50", "processedFileType": file_extension,
293-
"processedFileExtension": file_extension}
263+
@patch.object(Detect, "_Detect__poll_for_processed_file")
264+
def test_deidentify_file_all_branches(self, mock_poll, mock_open, mock_basename, mock_base64, mock_validate):
265+
"""Test all file type branches with optimized mocking"""
266+
267+
# Common mocks
268+
file_content = b"test content"
269+
mock_base64.b64encode.return_value = b"dGVzdCBjb250ZW50"
270+
271+
# Prepare a generic processed_response for all branches
272+
processed_response = Mock()
273+
processed_response.status = "SUCCESS"
274+
processed_response.output = [
275+
{"processedFile": "dGVzdCBjb250ZW50", "processedFileType": "pdf", "processedFileExtension": "pdf"}
276+
]
277+
processed_response.word_character_count = Mock(word_count=1, character_count=1)
278+
processed_response.size = 1
279+
processed_response.duration = 1
280+
processed_response.pages = 1
281+
processed_response.slides = 1
282+
processed_response.message = ""
283+
processed_response.run_id = "runid123"
284+
mock_poll.return_value = processed_response
285+
286+
# Patch __parse_deidentify_file_response to return a valid DeidentifyFileResponse
287+
with patch.object(self.detect, "_Detect__parse_deidentify_file_response",
288+
return_value=DeidentifyFileResponse(
289+
file="dGVzdCBjb250ZW50", type="pdf", extension="pdf",
290+
word_count=1, char_count=1, size_in_kb=1,
291+
duration_in_seconds=1, page_count=1, slide_count=1,
292+
entities=[], run_id="runid123", status="SUCCESS", errors=[]
293+
)) as mock_parse:
294+
# Test configuration for different file types
295+
test_cases = [
296+
("test.pdf", "pdf", "deidentify_pdf"),
297+
("test.jpg", "jpg", "deidentify_image"),
298+
("test.pptx", "pptx", "deidentify_presentation"),
299+
("test.csv", "csv", "deidentify_spreadsheet"),
300+
("test.docx", "docx", "deidentify_document"),
301+
("test.json", "json", "deidentify_structured_text"),
302+
("test.xml", "xml", "deidentify_structured_text"),
303+
("test.unknown", "unknown", "deidentify_file")
294304
]
295-
poll_response.word_character_count = Mock(word_count=1, character_count=1)
296-
poll_response.size = 1
297-
poll_response.duration = 1
298-
poll_response.pages = 1
299-
poll_response.slides = 1
300-
poll_response.message = ""
301-
poll_response.run_id = "runid123"
302-
files_api.get_run.return_value = poll_response
303-
304-
# Actually run the method (no patching of __poll_for_processed_file or __parse_deidentify_file_response)
305-
result = self.detect.deidentify_file(req)
306-
self.assertIsInstance(result, DeidentifyFileResponse)
307-
self.assertEqual(result.status, "SUCCESS")
308-
self.assertEqual(result.file, "dGVzdCBjb250ZW50")
309-
self.assertEqual(result.type, file_extension)
310-
self.assertEqual(result.extension, file_extension)
311-
312-
# Test all branches
313-
run_branch("test.pdf", "pdf", "deidentify_pdf", "DeidentifyPdfRequestFile")
314-
run_branch("test.jpg", "jpg", "deidentify_image", "DeidentifyImageRequestFile")
315-
run_branch("test.pptx", "pptx", "deidentify_presentation", "DeidentifyPresentationRequestFile")
316-
run_branch("test.csv", "csv", "deidentify_spreadsheet", "DeidentifySpreadsheetRequestFile")
317-
run_branch("test.docx", "docx", "deidentify_document", "DeidentifyDocumentRequestFile")
318-
run_branch("test.json", "json", "deidentify_structured_text", "DeidentifyStructuredTextRequestFile")
319-
run_branch("test.xml", "xml", "deidentify_structured_text", "DeidentifyStructuredTextRequestFile")
320-
# Test else branch (unknown extension)
321-
run_branch("test.unknown", "unknown", "deidentify_file", "DeidentifyFileRequestFile")
322305

306+
for file_name, extension, api_method in test_cases:
307+
with self.subTest(file_type=extension):
308+
# Setup file mock
309+
file_obj = Mock()
310+
file_obj.read.return_value = file_content
311+
file_obj.name = file_name
312+
mock_basename.return_value = file_name
313+
314+
# Setup request
315+
req = DeidentifyFileRequest(file=file_obj)
316+
req.entities = []
317+
req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[])
318+
req.allow_regex_list = []
319+
req.restrict_regex_list = []
320+
req.transformations = None
321+
req.output_directory = "/tmp"
322+
323+
# Setup API mock
324+
files_api = Mock()
325+
files_api.with_raw_response = files_api
326+
api_method_mock = Mock()
327+
setattr(files_api, api_method, api_method_mock)
328+
self.vault_client.get_detect_file_api.return_value = files_api
329+
330+
# Setup API response
331+
api_response = Mock()
332+
api_response.data = Mock(run_id="runid123")
333+
api_method_mock.return_value = api_response
334+
335+
# Run deidentify_file with polling and response parsing mocked out
336+
result = self.detect.deidentify_file(req)
337+
self.assertIsInstance(result, DeidentifyFileResponse)
338+
self.assertEqual(result.status, "SUCCESS")
339+
self.assertEqual(result.file, "dGVzdCBjb250ZW50")
340+
self.assertEqual(result.type, "pdf")
341+
self.assertEqual(result.extension, "pdf")
323342
@patch("skyflow.vault.controller._detect.validate_deidentify_file_request")
324343
@patch("skyflow.vault.controller._detect.base64")
325344
def test_deidentify_file_exception(self, mock_base64, mock_validate):
@@ -339,38 +358,47 @@ def test_deidentify_file_exception(self, mock_base64, mock_validate):
339358
@patch("skyflow.vault.controller._detect.time.sleep", return_value=None)
340359
def test_poll_for_processed_file_success(self, mock_sleep):
341360
files_api = Mock()
361+
files_api.with_raw_response = files_api
342362
self.vault_client.get_detect_file_api.return_value = files_api
343-
# First call returns IN_PROGRESS, second call returns SUCCESS
344-
in_progress = Mock()
345-
in_progress.status = "IN_PROGRESS"
346-
in_progress.message = ""
347-
success = Mock()
348-
success.status = "SUCCESS"
349-
files_api.get_run.side_effect = [in_progress, success]
363+
364+
call_count = {"count": 0}
365+
366+
def get_run_side_effect(*args, **kwargs):
367+
if call_count["count"] < 1:
368+
call_count["count"] += 1
369+
in_progress = Mock()
370+
in_progress.status = "IN_PROGRESS"
371+
in_progress.message = ""
372+
return Mock(data=in_progress)
373+
else:
374+
success = Mock()
375+
success.status = "SUCCESS"
376+
return Mock(data=success)
377+
378+
files_api.get_run.side_effect = get_run_side_effect
379+
380+
# Use max_wait_time > 1 to allow the loop to reach the SUCCESS status
350381
result = self.detect._Detect__poll_for_processed_file("runid123", max_wait_time=2)
351382
self.assertEqual(result.status, "SUCCESS")
352383

353384
@patch("skyflow.vault.controller._detect.time.sleep", return_value=None)
354385
def test_poll_for_processed_file_failed(self, mock_sleep):
355386
files_api = Mock()
387+
files_api.with_raw_response = files_api
356388
self.vault_client.get_detect_file_api.return_value = files_api
357-
failed = Mock()
358-
failed.status = "FAILED"
359-
failed.message = "fail"
360-
files_api.get_run.return_value = failed
361-
with self.assertRaises(SkyflowError):
362-
self.detect._Detect__poll_for_processed_file("runid123", max_wait_time=1)
363389

364-
@patch("skyflow.vault.controller._detect.time.sleep", return_value=None)
365-
def test_poll_for_processed_file_unknown(self, mock_sleep):
366-
files_api = Mock()
367-
self.vault_client.get_detect_file_api.return_value = files_api
368-
unknown = Mock()
369-
unknown.status = "UNKNOWN"
370-
unknown.message = "fail"
371-
files_api.get_run.return_value = unknown
372-
with self.assertRaises(SkyflowError):
373-
self.detect._Detect__poll_for_processed_file("runid123", max_wait_time=1)
390+
# get_run reports FAILED on every call; polling should return that response rather than raise
391+
def get_run_side_effect(*args, **kwargs):
392+
failed = Mock()
393+
failed.status = "FAILED"
394+
failed.message = "fail"
395+
return Mock(data=failed)
396+
397+
files_api.get_run.side_effect = get_run_side_effect
398+
399+
result = self.detect._Detect__poll_for_processed_file("runid123", max_wait_time=1)
400+
self.assertEqual(result.status, "FAILED")
401+
self.assertEqual(result.message, "fail")
374402

375403
def test_parse_deidentify_file_response_dict_and_obj(self):
376404
# Dict input
@@ -413,6 +441,7 @@ class DummyData:
413441
obj_data = DummyData()
414442
result = self.detect._Detect__parse_deidentify_file_response(obj_data, "runid", "SUCCESS")
415443
self.assertIsInstance(result, DeidentifyFileResponse)
444+
416445
def test_get_token_format_missing_attribute(self):
417446
"""Test __get_token_format when token_format attribute is missing"""
418447
class DummyRequest:

0 commit comments

Comments
 (0)