Skip to content

Commit 6e01253

Browse files
committed
SK-2521-Refactor file handling and error management in service account and vault controllers
1 parent 49f6b8b commit 6e01253

File tree

4 files changed

+101
-73
lines changed

4 files changed

+101
-73
lines changed

skyflow/service_account/_utils.py

Lines changed: 15 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -32,18 +32,15 @@ def is_expired(token, logger = None):
3232
def generate_bearer_token(credentials_file_path, options = None, logger = None):
3333
try:
3434
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger)
35-
credentials_file =open(credentials_file_path, 'r')
36-
except Exception:
35+
with open(credentials_file_path, 'r') as credentials_file:
36+
try:
37+
credentials = json.load(credentials_file)
38+
except Exception:
39+
log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger)
40+
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code)
41+
except FileNotFoundError:
3742
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code)
38-
39-
try:
40-
credentials = json.load(credentials_file)
41-
except Exception:
42-
log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger)
43-
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code)
44-
45-
finally:
46-
credentials_file.close()
43+
4744
result = get_service_account_token(credentials, options, logger)
4845
return result
4946

@@ -143,19 +140,15 @@ def get_signed_tokens(credentials_obj, options):
143140
def generate_signed_data_tokens(credentials_file_path, options):
144141
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKENS_TRIGGERED.value)
145142
try:
146-
credentials_file =open(credentials_file_path, 'r')
147-
except Exception:
143+
with open(credentials_file_path, 'r') as credentials_file:
144+
try:
145+
credentials = json.load(credentials_file)
146+
except Exception:
147+
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path),
148+
invalid_input_error_code)
149+
except FileNotFoundError:
148150
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code)
149151

150-
try:
151-
credentials = json.load(credentials_file)
152-
except Exception:
153-
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path),
154-
invalid_input_error_code)
155-
156-
finally:
157-
credentials_file.close()
158-
159152
return get_signed_tokens(credentials, options)
160153

161154
def generate_signed_data_tokens_from_creds(credentials, options):

skyflow/vault/controller/_connections.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -31,12 +31,15 @@ def invoke(self, request: InvokeConnectionRequest):
3131

3232
try:
3333
response = session.send(invoke_connection_request)
34-
session.close()
35-
invoke_connection_response = parse_invoke_connection_response(response)
36-
return invoke_connection_response
37-
34+
try:
35+
invoke_connection_response = parse_invoke_connection_response(response)
36+
return invoke_connection_response
37+
finally:
38+
response.close()
3839
except Exception as e:
3940
log_error_log(SkyflowMessages.ErrorLogs.INVOKE_CONNECTION_REQUEST_REJECTED.value, self.__vault_client.get_logger())
4041
if isinstance(e, SkyflowError): raise e
4142
raise SkyflowError(SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value,
42-
SkyflowMessages.ErrorCodes.SERVER_ERROR.value)
43+
SkyflowMessages.ErrorCodes.SERVER_ERROR.value)
44+
finally:
45+
session.close()

skyflow/vault/controller/_detect.py

Lines changed: 56 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -62,22 +62,28 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64):
6262
current_wait_time = 1 # Start with 1 second
6363
try:
6464
while True:
65-
response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data
66-
status = response.status
67-
if status == 'IN_PROGRESS':
68-
if current_wait_time >= max_wait_time:
69-
return DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')
70-
else:
71-
next_wait_time = current_wait_time * 2
72-
if next_wait_time >= max_wait_time:
73-
wait_time = max_wait_time - current_wait_time
74-
current_wait_time = max_wait_time
65+
http_response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers())
66+
try:
67+
response = http_response.data
68+
status = response.status
69+
if status == 'IN_PROGRESS':
70+
if current_wait_time >= max_wait_time:
71+
return DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')
7572
else:
76-
wait_time = next_wait_time
77-
current_wait_time = next_wait_time
78-
time.sleep(wait_time)
79-
elif status == 'SUCCESS' or status == 'FAILED':
80-
return response
73+
next_wait_time = current_wait_time * 2
74+
if next_wait_time >= max_wait_time:
75+
wait_time = max_wait_time - current_wait_time
76+
current_wait_time = max_wait_time
77+
else:
78+
wait_time = next_wait_time
79+
current_wait_time = next_wait_time
80+
time.sleep(wait_time)
81+
elif status == 'SUCCESS' or status == 'FAILED':
82+
# Return the response data object itself (plain assignment, no copy is made); the finally block closes http_response afterwards
83+
result = response
84+
return result
85+
finally:
86+
http_response.close()
8187
except Exception as e:
8288
raise e
8389

@@ -271,7 +277,7 @@ def __get_file_from_request(self, request: DeidentifyFileRequest):
271277

272278
# Check for file_path if file is not provided
273279
if hasattr(file_input, 'file_path') and file_input.file_path is not None:
274-
return open(file_input.file_path, 'rb')
280+
return open(file_input.file_path, 'rb')
275281

276282
def deidentify_file(self, request: DeidentifyFileRequest):
277283
log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger())
@@ -281,8 +287,19 @@ def deidentify_file(self, request: DeidentifyFileRequest):
281287
file_obj = self.__get_file_from_request(request)
282288
file_name = getattr(file_obj, 'name', None)
283289
file_extension = self._get_file_extension(file_name) if file_name else None
284-
file_content = file_obj.read()
285-
base64_string = base64.b64encode(file_content).decode('utf-8')
290+
291+
# Track if we need to close the file (only if it was opened from file_path)
292+
file_needs_closing = False
293+
file_input = request.file
294+
if hasattr(file_input, 'file_path') and file_input.file_path is not None:
295+
file_needs_closing = True
296+
297+
try:
298+
file_content = file_obj.read()
299+
base64_string = base64.b64encode(file_content).decode('utf-8')
300+
finally:
301+
if file_needs_closing and hasattr(file_obj, 'close'):
302+
file_obj.close()
286303

287304
try:
288305
if file_extension == 'txt':
@@ -420,16 +437,19 @@ def deidentify_file(self, request: DeidentifyFileRequest):
420437
log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
421438
api_response = api_call(**api_kwargs)
422439

423-
run_id = getattr(api_response.data, 'run_id', None)
440+
try:
441+
run_id = getattr(api_response.data, 'run_id', None)
424442

425-
processed_response = self.__poll_for_processed_file(run_id, request.wait_time)
426-
if request.output_directory and processed_response.status == 'SUCCESS':
427-
name_without_ext, _ = os.path.splitext(file_name)
428-
self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext)
443+
processed_response = self.__poll_for_processed_file(run_id, request.wait_time)
444+
if request.output_directory and processed_response.status == 'SUCCESS':
445+
name_without_ext, _ = os.path.splitext(file_name)
446+
self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext)
429447

430-
parsed_response = self.__parse_deidentify_file_response(processed_response, run_id)
431-
log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger())
432-
return parsed_response
448+
parsed_response = self.__parse_deidentify_file_response(processed_response, run_id)
449+
log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger())
450+
return parsed_response
451+
finally:
452+
api_response.close()
433453

434454
except Exception as e:
435455
log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value,
@@ -445,17 +465,20 @@ def get_detect_run(self, request: GetDetectRunRequest):
445465
files_api = self.__vault_client.get_detect_file_api().with_raw_response
446466
run_id = request.run_id
447467
try:
448-
response = files_api.get_run(
468+
http_response = files_api.get_run(
449469
run_id,
450470
vault_id=self.__vault_client.get_vault_id(),
451471
request_options=self.__get_headers()
452472
)
453-
if response.data.status == 'IN_PROGRESS':
454-
parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS'))
455-
else:
456-
parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status)
457-
log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger())
458-
return parsed_response
473+
try:
474+
if http_response.data.status == 'IN_PROGRESS':
475+
parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS'))
476+
else:
477+
parsed_response = self.__parse_deidentify_file_response(http_response.data, run_id, http_response.data.status)
478+
log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger())
479+
return parsed_response
480+
finally:
481+
http_response.close()
459482
except Exception as e:
460483
log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value,
461484
self.__vault_client.get_logger())

skyflow/vault/controller/_vault.py

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -112,9 +112,12 @@ def insert(self, request: InsertRequest):
112112
api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(),
113113
request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers())
114114

115-
insert_response = parse_insert_response(api_response, request.continue_on_error)
116-
log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger())
117-
return insert_response
115+
try:
116+
insert_response = parse_insert_response(api_response, request.continue_on_error)
117+
log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger())
118+
return insert_response
119+
finally:
120+
api_response.close()
118121

119122
except Exception as e:
120123
log_error_log(SkyflowMessages.ErrorLogs.INSERT_RECORDS_REJECTED.value, self.__vault_client.get_logger())
@@ -239,9 +242,12 @@ def detokenize(self, request: DetokenizeRequest):
239242
continue_on_error = request.continue_on_error,
240243
request_options=self.__get_headers()
241244
)
242-
log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger())
243-
detokenize_response = parse_detokenize_response(api_response)
244-
return detokenize_response
245+
try:
246+
log_info(SkyflowMessages.Info.DETOKENIZE_SUCCESS.value, self.__vault_client.get_logger())
247+
detokenize_response = parse_detokenize_response(api_response)
248+
return detokenize_response
249+
finally:
250+
api_response.close()
245251
except Exception as e:
246252
log_error_log(SkyflowMessages.ErrorLogs.DETOKENIZE_REQUEST_REJECTED.value, logger = self.__vault_client.get_logger())
247253
handle_exception(e, self.__vault_client.get_logger())
@@ -287,13 +293,16 @@ def upload_file(self, request: FileUploadRequest):
287293
return_file_metadata= False,
288294
request_options=self.__get_headers()
289295
)
290-
log_info(SkyflowMessages.Info.FILE_UPLOAD_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
291-
log_info(SkyflowMessages.Info.FILE_UPLOAD_SUCCESS.value, self.__vault_client.get_logger())
292-
upload_response = FileUploadResponse(
293-
skyflow_id=api_response.data.skyflow_id,
294-
errors=None
295-
)
296-
return upload_response
296+
try:
297+
log_info(SkyflowMessages.Info.FILE_UPLOAD_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
298+
log_info(SkyflowMessages.Info.FILE_UPLOAD_SUCCESS.value, self.__vault_client.get_logger())
299+
upload_response = FileUploadResponse(
300+
skyflow_id=api_response.data.skyflow_id,
301+
errors=None
302+
)
303+
return upload_response
304+
finally:
305+
api_response.close()
297306
except Exception as e:
298307
log_error_log(SkyflowMessages.ErrorLogs.FILE_UPLOAD_REQUEST_REJECTED.value, logger = self.__vault_client.get_logger())
299308
handle_exception(e, self.__vault_client.get_logger())

0 commit comments

Comments
 (0)