Skip to content
37 changes: 15 additions & 22 deletions skyflow/service_account/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,15 @@ def is_expired(token, logger = None):
def generate_bearer_token(credentials_file_path, options = None, logger = None):
try:
log_info(SkyflowMessages.Info.GET_BEARER_TOKEN_TRIGGERED.value, logger)
credentials_file =open(credentials_file_path, 'r')
except Exception:
with open(credentials_file_path, 'r') as credentials_file:
try:
credentials = json.load(credentials_file)
except Exception:
log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger)
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code)
except OSError:
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code)

try:
credentials = json.load(credentials_file)
except Exception:
log_error_log(SkyflowMessages.ErrorLogs.INVALID_CREDENTIALS_FILE.value, logger = logger)
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path), invalid_input_error_code)

finally:
credentials_file.close()

result = get_service_account_token(credentials, options, logger)
return result

Expand Down Expand Up @@ -144,19 +141,15 @@ def get_signed_tokens(credentials_obj, options):
def generate_signed_data_tokens(credentials_file_path, options):
log_info(SkyflowMessages.Info.GET_SIGNED_DATA_TOKENS_TRIGGERED.value)
try:
credentials_file =open(credentials_file_path, 'r')
except Exception:
with open(credentials_file_path, 'r') as credentials_file:
try:
credentials = json.load(credentials_file)
except Exception:
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path),
invalid_input_error_code)
except FileNotFoundError:
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value, invalid_input_error_code)

try:
credentials = json.load(credentials_file)
except Exception:
raise SkyflowError(SkyflowMessages.Error.FILE_INVALID_JSON.value.format(credentials_file_path),
invalid_input_error_code)

finally:
credentials_file.close()

return get_signed_tokens(credentials, options)

def generate_signed_data_tokens_from_creds(credentials, options):
Expand Down
11 changes: 7 additions & 4 deletions skyflow/vault/controller/_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@ def invoke(self, request: InvokeConnectionRequest):

log_info(SkyflowMessages.Info.INVOKE_CONNECTION_TRIGGERED.value, self.__vault_client.get_logger())

response = None
try:
response = session.send(invoke_connection_request)
session.close()
invoke_connection_response = parse_invoke_connection_response(response)
return invoke_connection_response
return parse_invoke_connection_response(response)

except Exception as e:
log_error_log(SkyflowMessages.ErrorLogs.INVOKE_CONNECTION_REQUEST_REJECTED.value, self.__vault_client.get_logger())
if isinstance(e, SkyflowError): raise e
raise SkyflowError(SkyflowMessages.Error.INVOKE_CONNECTION_FAILED.value,
SkyflowMessages.ErrorCodes.SERVER_ERROR.value)
SkyflowMessages.ErrorCodes.SERVER_ERROR.value)
finally:
if response is not None:
response.close()
session.close()
102 changes: 63 additions & 39 deletions skyflow/vault/controller/_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,26 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64):
current_wait_time = 1 # Start with 1 second
try:
while True:
response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers()).data
status = response.status
if status == DetectStatus.IN_PROGRESS:
if current_wait_time >= max_wait_time:
return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)
else:
next_wait_time = current_wait_time * 2
if next_wait_time >= max_wait_time:
wait_time = max_wait_time - current_wait_time
current_wait_time = max_wait_time
http_response = files_api.get_run(run_id, vault_id=self.__vault_client.get_vault_id(), request_options=self.__get_headers())
try:
response = http_response.data
status = response.status
if status == DetectStatus.IN_PROGRESS:
if current_wait_time >= max_wait_time:
return DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS)
else:
wait_time = next_wait_time
current_wait_time = next_wait_time
time.sleep(wait_time)
elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED:
return response
next_wait_time = current_wait_time * 2
if next_wait_time >= max_wait_time:
wait_time = max_wait_time - current_wait_time
current_wait_time = max_wait_time
else:
wait_time = next_wait_time
current_wait_time = next_wait_time
time.sleep(wait_time)
elif status == DetectStatus.SUCCESS or status == DetectStatus.FAILED:
return response
finally:
http_response.close()
except Exception as e:
raise e

Expand Down Expand Up @@ -231,9 +235,12 @@ def deidentify_text(self, request: DeidentifyTextRequest) -> DeidentifyTextRespo
transformations=deidentify_text_body[DeidentifyField.TRANSFORMATIONS],
request_options=self.__get_headers()
)
deidentify_text_response = parse_deidentify_text_response(api_response)
log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
return deidentify_text_response
try:
deidentify_text_response = parse_deidentify_text_response(api_response)
log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
return deidentify_text_response
finally:
api_response.close()

except Exception as e:
log_error_log(SkyflowMessages.ErrorLogs.DEIDENTIFY_TEXT_REQUEST_REJECTED.value, self.__vault_client.get_logger())
Expand All @@ -255,9 +262,12 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo
format=reidentify_text_body[DeidentifyField.FORMAT],
request_options=self.__get_headers()
)
reidentify_text_response = parse_reidentify_text_response(api_response)
log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
return reidentify_text_response
try:
reidentify_text_response = parse_reidentify_text_response(api_response)
log_info(SkyflowMessages.Info.REIDENTIFY_TEXT_SUCCESS.value, self.__vault_client.get_logger())
return reidentify_text_response
finally:
api_response.close()

except Exception as e:
log_error_log(SkyflowMessages.ErrorLogs.REIDENTIFY_TEXT_REQUEST_REJECTED.value, self.__vault_client.get_logger())
Expand All @@ -272,7 +282,7 @@ def __get_file_from_request(self, request: DeidentifyFileRequest):

# Check for file_path if file is not provided
if hasattr(file_input, FileUploadField.FILE_PATH) and file_input.file_path is not None:
return open(file_input.file_path, 'rb')
return open(file_input.file_path, 'rb')

def deidentify_file(self, request: DeidentifyFileRequest):
log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger())
Expand All @@ -282,8 +292,16 @@ def deidentify_file(self, request: DeidentifyFileRequest):
file_obj = self.__get_file_from_request(request)
file_name = getattr(file_obj, FileUploadField.NAME, None)
file_extension = self._get_file_extension(file_name) if file_name else None
file_content = file_obj.read()
base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8)

# Track if we need to close the file (only if it was opened from file_path)
file_needs_closing = hasattr(request.file, 'file_path') and request.file.file_path is not None

try:
file_content = file_obj.read()
base64_string = base64.b64encode(file_content).decode(EncodingType.UTF_8)
finally:
if file_needs_closing and hasattr(file_obj, 'close'):
file_obj.close()

try:
if file_extension == FileExtension.TXT:
Expand Down Expand Up @@ -421,16 +439,19 @@ def deidentify_file(self, request: DeidentifyFileRequest):
log_info(SkyflowMessages.Info.DETECT_FILE_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
api_response = api_call(**api_kwargs)

run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None)
try:
run_id = getattr(api_response.data, DeidentifyField.RUN_ID, None)

processed_response = self.__poll_for_processed_file(run_id, request.wait_time)
if request.output_directory and processed_response.status == DetectStatus.SUCCESS:
name_without_ext, _ = os.path.splitext(file_name)
self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext)
processed_response = self.__poll_for_processed_file(run_id, request.wait_time)
if request.output_directory and processed_response.status == DetectStatus.SUCCESS:
name_without_ext, _ = os.path.splitext(file_name)
self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext)

parsed_response = self.__parse_deidentify_file_response(processed_response, run_id)
log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger())
return parsed_response
parsed_response = self.__parse_deidentify_file_response(processed_response, run_id)
log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger())
return parsed_response
finally:
api_response.close()

except Exception as e:
log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value,
Expand All @@ -446,17 +467,20 @@ def get_detect_run(self, request: GetDetectRunRequest):
files_api = self.__vault_client.get_detect_file_api().with_raw_response
run_id = request.run_id
try:
response = files_api.get_run(
http_response = files_api.get_run(
run_id,
vault_id=self.__vault_client.get_vault_id(),
request_options=self.__get_headers()
)
if response.data.status == DetectStatus.IN_PROGRESS:
parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS))
else:
parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status)
log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger())
return parsed_response
try:
if http_response.data.status == DetectStatus.IN_PROGRESS:
parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status=DetectStatus.IN_PROGRESS))
else:
parsed_response = self.__parse_deidentify_file_response(http_response.data, run_id, http_response.data.status)
log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger())
return parsed_response
finally:
http_response.close()
except Exception as e:
log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value,
self.__vault_client.get_logger())
Expand Down
Loading
Loading