Skip to content

Commit b221244

Browse files
committed
SK-1758 Add detect support in Python SDK
- Add unit tests for deidentify and reidentify text
1 parent 45795d3 commit b221244

File tree

5 files changed

+261
-10
lines changed

5 files changed

+261
-10
lines changed

skyflow/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
from ._skyflow_messages import SkyflowMessages
33
from ._version import SDK_VERSION
44
from ._helpers import get_base_url, format_scope
5-
from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values
5+
from ._utils import get_credentials, get_vault_url, construct_invoke_connection_request, get_metrics, parse_insert_response, handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_get_response, parse_invoke_connection_response, validate_api_key, encode_column_values, parse_deidentify_text_response, parse_reidentify_text_response, convert_to_entity_type, convert_detected_entity_to_entity_info

skyflow/utils/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def convert_to_entity_type(detect_entities):
9393
entity_types.append(entity.value)
9494
return entity_types
9595

96-
def _convert_detected_entity_to_entity_info(detected_entity):
96+
def convert_detected_entity_to_entity_info(detected_entity):
9797
return EntityInfo(
9898
token=detected_entity.token,
9999
value=detected_entity.value,
@@ -392,7 +392,7 @@ def parse_invoke_connection_response(api_response: requests.Response):
392392
raise SkyflowError(message, status_code)
393393

394394
def parse_deidentify_text_response(api_response: DeidentifyStringResponse):
395-
entities = [_convert_detected_entity_to_entity_info(entity) for entity in api_response.entities]
395+
entities = [convert_detected_entity_to_entity_info(entity) for entity in api_response.entities]
396396
return DeidentifyTextResponse(
397397
processed_text=api_response.processed_text,
398398
entities=entities,

skyflow/vault/controller/_detect.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@ def ___build_deidentify_text_body(self, request: DeidentifyTextRequest) -> Dict[
2828
deidentify_text_body = {}
2929
parsed_entity_types = convert_to_entity_type(request.entities)
3030

31-
parsed_token_type = TokenType(
32-
default = request.token_format.default,
33-
vault_token = convert_to_entity_type(request.token_format.vault_token),
34-
entity_unq_counter = convert_to_entity_type(request.token_format.entity_unique_counter),
35-
entity_only = convert_to_entity_type(request.token_format.entity_only)
36-
)
31+
parsed_token_type = None
32+
if request.token_format is not None:
33+
parsed_token_type = TokenType(
34+
default = request.token_format.default,
35+
vault_token = convert_to_entity_type(request.token_format.vault_token),
36+
entity_unq_counter = convert_to_entity_type(request.token_format.entity_unique_counter),
37+
entity_only = convert_to_entity_type(request.token_format.entity_only)
38+
)
3739
parsed_transformations = None
3840
if request.transformations is not None:
3941
parsed_transformations = Transformations(

tests/utils/test__utils.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from skyflow.utils import get_credentials, SkyflowMessages, get_vault_url, construct_invoke_connection_request, \
1111
parse_insert_response, parse_update_record_response, parse_delete_response, parse_get_response, \
1212
parse_detokenize_response, parse_tokenize_response, parse_query_response, parse_invoke_connection_response, \
13-
handle_exception, validate_api_key, encode_column_values
13+
handle_exception, validate_api_key, encode_column_values, parse_deidentify_text_response, \
14+
parse_reidentify_text_response, convert_to_entity_type, convert_detected_entity_to_entity_info
1415
from skyflow.utils._utils import parse_path_params, to_lowercase_keys, get_metrics
1516
from skyflow.utils.enums import EnvUrls, Env, ContentType
1617
from skyflow.vault.connection import InvokeConnectionResponse
@@ -418,3 +419,133 @@ def test_encode_column_values(self):
418419

419420
result = encode_column_values(get_request)
420421
self.assertEqual(result, expected_encoded_values)
422+
423+
def test_parse_deidentify_text_response(self):
424+
"""Test parsing deidentify text response with multiple entities."""
425+
mock_entity = Mock()
426+
mock_entity.token = "token123"
427+
mock_entity.value = "sensitive_value"
428+
mock_entity.entity_type = "EMAIL"
429+
mock_entity.entity_scores = {"EMAIL": 0.95}
430+
mock_entity.location = Mock(
431+
start_index=10,
432+
end_index=20,
433+
start_index_processed=15,
434+
end_index_processed=25
435+
)
436+
437+
mock_api_response = Mock()
438+
mock_api_response.processed_text = "Sample processed text"
439+
mock_api_response.entities = [mock_entity]
440+
mock_api_response.word_count = 3
441+
mock_api_response.character_count = 20
442+
443+
result = parse_deidentify_text_response(mock_api_response)
444+
445+
self.assertEqual(result.processed_text, "Sample processed text")
446+
self.assertEqual(result.word_count, 3)
447+
self.assertEqual(result.char_count, 20)
448+
self.assertEqual(len(result.entities), 1)
449+
450+
entity = result.entities[0]
451+
self.assertEqual(entity.token, "token123")
452+
self.assertEqual(entity.value, "sensitive_value")
453+
self.assertEqual(entity.entity, "EMAIL")
454+
self.assertEqual(entity.scores, {"EMAIL": 0.95})
455+
self.assertEqual(entity.text_index.start, 10)
456+
self.assertEqual(entity.text_index.end, 20)
457+
self.assertEqual(entity.processed_index.start, 15)
458+
self.assertEqual(entity.processed_index.end, 25)
459+
460+
def test_parse_deidentify_text_response_no_entities(self):
461+
"""Test parsing deidentify text response with no entities."""
462+
mock_api_response = Mock()
463+
mock_api_response.processed_text = "Sample processed text"
464+
mock_api_response.entities = []
465+
mock_api_response.word_count = 3
466+
mock_api_response.character_count = 20
467+
468+
result = parse_deidentify_text_response(mock_api_response)
469+
470+
self.assertEqual(result.processed_text, "Sample processed text")
471+
self.assertEqual(result.word_count, 3)
472+
self.assertEqual(result.char_count, 20)
473+
self.assertEqual(len(result.entities), 0)
474+
475+
def test_parse_reidentify_text_response(self):
476+
"""Test parsing reidentify text response."""
477+
mock_api_response = Mock()
478+
mock_api_response.processed_text = "Reidentified text with actual values"
479+
480+
result = parse_reidentify_text_response(mock_api_response)
481+
482+
self.assertEqual(result.processed_text, "Reidentified text with actual values")
483+
484+
def test_convert_to_entity_type_with_valid_entities(self):
485+
"""Test converting entity types with valid input."""
486+
from skyflow.utils.enums import DetectEntities
487+
488+
detect_entities = [DetectEntities.EMAIL_ADDRESS, DetectEntities.PHONE_NUMBER]
489+
result = convert_to_entity_type(detect_entities)
490+
491+
self.assertEqual(result, ["email_address", "phone_number"])
492+
493+
def test_convert_to_entity_type_with_empty_list(self):
494+
"""Test converting entity types with empty list."""
495+
result = convert_to_entity_type([])
496+
self.assertIsNone(result)
497+
498+
def test_convert_to_entity_type_with_none(self):
499+
"""Test converting entity types with None input."""
500+
result = convert_to_entity_type(None)
501+
self.assertIsNone(result)
502+
503+
def test__convert_detected_entity_to_entity_info(self):
504+
"""Test converting detected entity to EntityInfo object."""
505+
mock_detected_entity = Mock()
506+
mock_detected_entity.token = "token123"
507+
mock_detected_entity.value = "sensitive_value"
508+
mock_detected_entity.entity_type = "EMAIL"
509+
mock_detected_entity.entity_scores = {"EMAIL": 0.95}
510+
mock_detected_entity.location = Mock(
511+
start_index=10,
512+
end_index=20,
513+
start_index_processed=15,
514+
end_index_processed=25
515+
)
516+
517+
result = convert_detected_entity_to_entity_info(mock_detected_entity)
518+
519+
self.assertEqual(result.token, "token123")
520+
self.assertEqual(result.value, "sensitive_value")
521+
self.assertEqual(result.entity, "EMAIL")
522+
self.assertEqual(result.scores, {"EMAIL": 0.95})
523+
self.assertEqual(result.text_index.start, 10)
524+
self.assertEqual(result.text_index.end, 20)
525+
self.assertEqual(result.processed_index.start, 15)
526+
self.assertEqual(result.processed_index.end, 25)
527+
528+
def test__convert_detected_entity_to_entity_info_with_minimal_data(self):
529+
"""Test converting detected entity with minimal required data."""
530+
mock_detected_entity = Mock()
531+
mock_detected_entity.token = "token123"
532+
mock_detected_entity.value = None
533+
mock_detected_entity.entity_type = "UNKNOWN"
534+
mock_detected_entity.entity_scores = {}
535+
mock_detected_entity.location = Mock(
536+
start_index=0,
537+
end_index=0,
538+
start_index_processed=0,
539+
end_index_processed=0
540+
)
541+
542+
result = convert_detected_entity_to_entity_info(mock_detected_entity)
543+
544+
self.assertEqual(result.token, "token123")
545+
self.assertIsNone(result.value)
546+
self.assertEqual(result.entity, "UNKNOWN")
547+
self.assertEqual(result.scores, {})
548+
self.assertEqual(result.text_index.start, 0)
549+
self.assertEqual(result.text_index.end, 0)
550+
self.assertEqual(result.processed_index.start, 0)
551+
self.assertEqual(result.processed_index.end, 0)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import unittest
2+
from unittest.mock import Mock, patch
3+
from skyflow.error import SkyflowError
4+
from skyflow.vault.controller import Detect
5+
from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, \
6+
TokenFormat, DateTransformation, Transformations
7+
from skyflow.utils.enums import DetectEntities, TokenType
8+
9+
VAULT_ID = "test_vault_id"
10+
11+
class TestDetect(unittest.TestCase):
12+
def setUp(self):
13+
# Mock vault client
14+
self.vault_client = Mock()
15+
self.vault_client.get_vault_id.return_value = VAULT_ID
16+
self.vault_client.get_logger.return_value = Mock()
17+
18+
# Create a Detect instance with the mock client
19+
self.detect = Detect(self.vault_client)
20+
21+
@patch("skyflow.vault.controller._detect.validate_deidentify_text_request")
22+
@patch("skyflow.vault.controller._detect.parse_deidentify_text_response")
23+
def test_deidentify_text_success(self, mock_parse_response, mock_validate):
24+
# Mock API response
25+
mock_api_response = Mock()
26+
mock_api_response.data = {
27+
'text': '[TOKEN_1] lives in [TOKEN_2]',
28+
'entities': [
29+
{
30+
'token': 'Token1',
31+
'value': 'John',
32+
'text_index': {'start': 0, 'end': 4},
33+
'processed_index': {'start': 0, 'end': 8},
34+
'entity': 'NAME',
35+
'scores': {'confidence': 0.9}
36+
}
37+
],
38+
'word_count': 4,
39+
'char_count': 20
40+
}
41+
42+
# Create request
43+
request = DeidentifyTextRequest(
44+
text="John lives in NYC",
45+
entities=[DetectEntities.NAME],
46+
token_format=TokenFormat(default=TokenType.ENTITY_ONLY)
47+
)
48+
49+
# Mock detect API
50+
detect_api = self.vault_client.get_detect_text_api.return_value
51+
detect_api.deidentify_string.return_value = mock_api_response
52+
53+
# Call deidentify_text
54+
response = self.detect.deidentify_text(request)
55+
56+
# Assertions
57+
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
58+
mock_parse_response.assert_called_once_with(mock_api_response)
59+
detect_api.deidentify_string.assert_called_once()
60+
61+
@patch("skyflow.vault.controller._detect.validate_reidentify_text_request")
62+
@patch("skyflow.vault.controller._detect.parse_reidentify_text_response")
63+
def test_reidentify_text_success(self, mock_parse_response, mock_validate):
64+
# Mock API response
65+
mock_api_response = Mock()
66+
mock_api_response.data = {
67+
'text': 'John lives in NYC'
68+
}
69+
70+
# Create request
71+
request = ReidentifyTextRequest(
72+
text="Token1 lives in Token2",
73+
redacted_entities=[DetectEntities.NAME]
74+
)
75+
76+
# Mock detect API
77+
detect_api = self.vault_client.get_detect_text_api.return_value
78+
detect_api.reidentify_string.return_value = mock_api_response
79+
80+
# Call reidentify_text
81+
response = self.detect.reidentify_text(request)
82+
83+
# Assertions
84+
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
85+
mock_parse_response.assert_called_once_with(mock_api_response)
86+
detect_api.reidentify_string.assert_called_once()
87+
88+
@patch("skyflow.vault.controller._detect.validate_deidentify_text_request")
89+
def test_deidentify_text_handles_generic_error(self, mock_validate):
90+
request = DeidentifyTextRequest(
91+
text="John lives in NYC",
92+
entities=[DetectEntities.NAME]
93+
)
94+
detect_api = self.vault_client.get_detect_text_api.return_value
95+
detect_api.deidentify_string.side_effect = Exception("Generic Error")
96+
97+
with self.assertRaises(Exception):
98+
self.detect.deidentify_text(request)
99+
100+
detect_api.deidentify_string.assert_called_once()
101+
102+
@patch("skyflow.vault.controller._detect.validate_reidentify_text_request")
103+
def test_reidentify_text_handles_generic_error(self, mock_validate):
104+
request = ReidentifyTextRequest(
105+
text="Token1 lives in Token2",
106+
redacted_entities=[DetectEntities.NAME]
107+
)
108+
detect_api = self.vault_client.get_detect_text_api.return_value
109+
detect_api.reidentify_string.side_effect = Exception("Generic Error")
110+
111+
with self.assertRaises(Exception):
112+
self.detect.reidentify_text(request)
113+
114+
detect_api.reidentify_string.assert_called_once()
115+
116+
117+
# if __name__ == '__main__':
118+
# unittest.main()

0 commit comments

Comments
 (0)