Skip to content

Commit a7de2f0

Browse files
authored
Merge pull request #230 from skyflowapi/aadarsh-st/SK-2521-Fix-python-issue
SK-2521-Refactor file handling and error management in service accoun…
2 parents 5eb3da9 + 72d7262 commit a7de2f0

File tree

4 files changed

+137
-93
lines changed

4 files changed

+137
-93
lines changed

skyflow/service_account/_utils.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,15 @@ def is_expired(token, logger = None):
3333
def generate_bearer_token(credentials_file_path, options = None, logger = None):
3434
try:
3535
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger)
36-
credentials_file =open(credentials_file_path, 'r')
37-
except Exception:
36+
with open(credentials_file_path, 'r') as credentials_file:
37+
try:
38+
credentials = json.load(credentials_file)
39+
except Exception:
40+
log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger)
41+
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code)
42+
except OSError:
3843
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code)
39-
40-
try:
41-
credentials = json.load(credentials_file)
42-
except Exception:
43-
log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger)
44-
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code)
45-
46-
finally:
47-
credentials_file.close()
44+
4845
result = get_service_account_token(credentials, options, logger)
4946
return result
5047

@@ -144,19 +141,15 @@ def get_signed_tokens(credentials_obj, options):
144141
def generate_signed_data_tokens(credentials_file_path, options):
145142
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKENS_TRIGGERED.value)
146143
try:
147-
credentials_file =open(credentials_file_path, 'r')
148-
except Exception:
144+
with open(credentials_file_path, 'r') as credentials_file:
145+
try:
146+
credentials = json.load(credentials_file)
147+
except Exception:
148+
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path),
149+
invalid_input_error_code)
150+
except FileNotFoundError:
149151
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code)
150152

151-
try:
152-
credentials = json.load(credentials_file)
153-
except Exception:
154-
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path),
155-
invalid_input_error_code)
156-
157-
finally:
158-
credentials_file.close()
159-
160153
return get_signed_tokens(credentials, options)
161154

162155
def generate_signed_data_tokens_from_creds(credentials, options):

skyflow/vault/controller/_connections.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,17 @@ def invoke(self, request: InvokeConnectionRequest):
3030

3131
log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED.value, self.__vault_client.get_logger())
3232

33+
response = None
3334
try:
3435
response = session.send(invoke_connection_request)
35-
session.close()
36-
invoke_connection_response = parse_invoke_connection_response(response)
37-
return invoke_connection_response
36+
return parse_invoke_connection_response(response)
3837

3938
except Exception as e:
4039
log_error_log(SkyflowMessages.ErrorLogs.INVOKE_CONNECTION_REQUEST_REJECTED.value, self.__vault_client.get_logger())
4140
if isinstance(e, SkyflowError): raise e
4241
raise SkyflowError(SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value,
43-
SkyflowMessages.ErrorCodes.SERVER_ERROR.value)
42+
SkyflowMessages.ErrorCodes.SERVER_ERROR.value)
43+
finally:
44+
if response is not None:
45+
response.close()
46+
session.close()

skyflow/vault/controller/_detect.py

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,26 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64):
6363
current_wait_time = 1 # Start with 1 second
6464
try:
6565
while True:
66-
response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data
67-
status = response.status
68-
if status == DetectStatus.IN_PROGRESS:
69-
if current_wait_time >= max_wait_time:
70-
return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)
71-
else:
72-
next_wait_time = current_wait_time * 2
73-
if next_wait_time >= max_wait_time:
74-
wait_time = max_wait_time - current_wait_time
75-
current_wait_time = max_wait_time
66+
http_response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers())
67+
try:
68+
response = http_response.data
69+
status = response.status
70+
if status == DetectStatus.IN_PROGRESS:
71+
if current_wait_time >= max_wait_time:
72+
return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)
7673
else:
77-
wait_time = next_wait_time
78-
current_wait_time = next_wait_time
79-
time.sleep(wait_time)
80-
elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED:
81-
return response
74+
next_wait_time = current_wait_time * 2
75+
if next_wait_time >= max_wait_time:
76+
wait_time = max_wait_time - current_wait_time
77+
current_wait_time = max_wait_time
78+
else:
79+
wait_time = next_wait_time
80+
current_wait_time = next_wait_time
81+
time.sleep(wait_time)
82+
elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED:
83+
return response
84+
finally:
85+
http_response.close()
8286
except Exception as e:
8387
raise e
8488

@@ -231,9 +235,12 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo
231235
transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS],
232236
request_options=self.__get_headers()
233237
)
234-
deidentify_text_response = parse_deidentify_text_response(api_response)
235-
log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
236-
return deidentify_text_response
238+
try:
239+
deidentify_text_response = parse_deidentify_text_response(api_response)
240+
log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
241+
return deidentify_text_response
242+
finally:
243+
api_response.close()
237244

238245
except Exception as e:
239246
log_error_log(SkyflowMessages.ErrorLogs.DEIDENTIFY_TEXT_REQUEST_REJECTED.value, self.__vault_client.get_logger())
@@ -255,9 +262,12 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo
255262
format=reidentify_text_body[DeidentifyField.FORMAT],
256263
request_options=self.__get_headers()
257264
)
258-
reidentify_text_response = parse_reidentify_text_response(api_response)
259-
log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
260-
return reidentify_text_response
265+
try:
266+
reidentify_text_response = parse_reidentify_text_response(api_response)
267+
log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
268+
return reidentify_text_response
269+
finally:
270+
api_response.close()
261271

262272
except Exception as e:
263273
log_error_log(SkyflowMessages.ErrorLogs.REIDENTIFY_TEXT_REQUEST_REJECTED.value, self.__vault_client.get_logger())
@@ -272,7 +282,7 @@ def __get_file_from_request(self, request: DeidentifyFileRequest):
272282

273283
# Check for file_path if file is not provided
274284
if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None:
275-
return open(file_input.file_path, 'rb')
285+
return open(file_input.file_path, 'rb')
276286

277287
def deidentify_file(self, request: DeidentifyFileRequest):
278288
log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger())
@@ -282,8 +292,16 @@ def deidentify_file(self, request: DeidentifyFileRequest):
282292
file_obj = self.__get_file_from_request(request)
283293
file_name = getattr(file_obj, FileUploadField.NAME, None)
284294
file_extension = self._get_file_extension(file_name) if file_name else None
285-
file_content = file_obj.read()
286-
base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8)
295+
296+
# Track if we need to close the file (only if it was opened from file_path)
297+
file_needs_closing = hasattr(request.file, 'file_path') and request.file.file_path is not None
298+
299+
try:
300+
file_content = file_obj.read()
301+
base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8)
302+
finally:
303+
if file_needs_closing and hasattr(file_obj, 'close'):
304+
file_obj.close()
287305

288306
try:
289307
if file_extension == FileExtension.TXT:
@@ -421,16 +439,19 @@ def deidentify_file(self, request: DeidentifyFileRequest):
421439
log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
422440
api_response = api_call(**api_kwargs)
423441

424-
run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None)
442+
try:
443+
run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None)
425444

426-
processed_response = self.__poll_for_processed_file(run_id, request.wait_time)
427-
if request.output_directory and processed_response.status == DetectStatus.SUCCESS:
428-
name_without_ext, _ = os.path.splitext(file_name)
429-
self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext)
445+
processed_response = self.__poll_for_processed_file(run_id, request.wait_time)
446+
if request.output_directory and processed_response.status == DetectStatus.SUCCESS:
447+
name_without_ext, _ = os.path.splitext(file_name)
448+
self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext)
430449

431-
parsed_response = self.__parse_deidentify_file_response(processed_response, run_id)
432-
log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger())
433-
return parsed_response
450+
parsed_response = self.__parse_deidentify_file_response(processed_response, run_id)
451+
log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger())
452+
return parsed_response
453+
finally:
454+
api_response.close()
434455

435456
except Exception as e:
436457
log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value,
@@ -446,17 +467,20 @@ def get_detect_run(self, request: GetDetectRunRequest):
446467
files_api = self.__vault_client.get_detect_file_api().with_raw_response
447468
run_id = request.run_id
448469
try:
449-
response = files_api.get_run(
470+
http_response = files_api.get_run(
450471
run_id,
451472
vault_id=self.__vault_client.get_vault_id(),
452473
request_options=self.__get_headers()
453474
)
454-
if response.data.status == DetectStatus.IN_PROGRESS:
455-
parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS))
456-
else:
457-
parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status)
458-
log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger())
459-
return parsed_response
475+
try:
476+
if http_response.data.status == DetectStatus.IN_PROGRESS:
477+
parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS))
478+
else:
479+
parsed_response = self.__parse_deidentify_file_response(http_response.data, run_id, http_response.data.status)
480+
log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger())
481+
return parsed_response
482+
finally:
483+
http_response.close()
460484
except Exception as e:
461485
log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value,
462486
self.__vault_client.get_logger())

0 commit comments

Comments
 (0)