1+ import io
12import json
23import os
34from skyflow .error import SkyflowError
2021from skyflow .vault .detect import DeidentifyTextRequest , DeidentifyTextResponse , ReidentifyTextRequest , \
2122 ReidentifyTextResponse , DeidentifyFileRequest , DeidentifyFileResponse , GetDetectRunRequest
2223
24+
2325class Detect :
2426 def __init__ (self , vault_client ):
2527 self .__vault_client = vault_client
@@ -124,10 +126,22 @@ def output_to_dict_list(output):
124126 word_count = getattr (word_character_count , "word_count" , None )
125127 char_count = getattr (word_character_count , "character_count" , None )
126128
129+ base64_string = first_output .get ("file" , None )
130+ extension = first_output .get ("extension" , None )
131+
132+ file_obj = None
133+ if base64_string is not None :
134+ file_bytes = base64 .b64decode (base64_string )
135+ file_obj = io .BytesIO (file_bytes )
136+ file_obj .name = f"deidentified.{ extension } " if extension else "processed_file"
137+ else :
138+ file_obj = None
139+
127140 return DeidentifyFileResponse (
128- file = first_output .get ("file" , None ),
141+ file_base64 = base64_string ,
142+ file = file_obj , # File class will be instantiated in DeidentifyFileResponse
129143 type = first_output .get ("type" , None ),
130- extension = first_output . get ( " extension" , None ) ,
144+ extension = extension ,
131145 word_count = word_count ,
132146 char_count = char_count ,
133147 size_in_kb = size ,
@@ -137,7 +151,7 @@ def output_to_dict_list(output):
137151 entities = entities ,
138152 run_id = run_id_val ,
139153 status = status_val ,
140- errors = []
154+ errors = None
141155 )
142156
143157 def __get_token_format (self , request ):
@@ -216,16 +230,26 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo
216230 log_error_log (SkyflowMessages .ErrorLogs .REIDENTIFY_TEXT_REQUEST_REJECTED .value , self .__vault_client .get_logger ())
217231 handle_exception (e , self .__vault_client .get_logger ())
218232
def __get_file_from_request(self, request: DeidentifyFileRequest):
    """Resolve the binary file object carried by a deidentify-file request.

    ``request.file`` is inspected for two optional attributes, in order:
    ``file`` (returned as-is when it is not ``None``) and ``file_path``
    (opened here in binary read mode when it is not ``None``).

    Returns:
        The file object, or ``None`` when neither attribute is set. A handle
        opened from ``file_path`` is not closed by this method.
    """
    source = request.file

    # A directly supplied file object takes precedence over a path.
    provided = getattr(source, 'file', None)
    if provided is not None:
        return provided

    # Fall back to opening the path when no file object was supplied.
    path = getattr(source, 'file_path', None)
    if path is not None:
        return open(path, 'rb')

    return None
243+
219244 def deidentify_file (self , request : DeidentifyFileRequest ):
220245 log_info (SkyflowMessages .Info .DETECT_FILE_TRIGGERED .value , self .__vault_client .get_logger ())
221246 validate_deidentify_file_request (self .__vault_client .get_logger (), request )
222247 self .__initialize ()
223248 files_api = self .__vault_client .get_detect_file_api ().with_raw_response
224- file_obj = request . file
249+ file_obj = self . __get_file_from_request ( request )
225250 file_name = getattr (file_obj , 'name' , None )
226251 file_extension = self ._get_file_extension (file_name ) if file_name else None
227252 file_content = file_obj .read ()
228-
229253 base64_string = base64 .b64encode (file_content ).decode ('utf-8' )
230254
231255 try :
@@ -375,7 +399,7 @@ def deidentify_file(self, request: DeidentifyFileRequest):
375399 file_name_only = 'processed-' + os .path .basename (file_name )
376400 output_file_path = f"{ request .output_directory } /{ file_name_only } "
377401 with open (output_file_path , 'wb' ) as output_file :
378- output_file .write (base64 .b64decode (parsed_response .file ))
402+ output_file .write (base64 .b64decode (parsed_response .file_base64 ))
379403 log_info (SkyflowMessages .Info .DETECT_FILE_SUCCESS .value , self .__vault_client .get_logger ())
380404 return parsed_response
381405
0 commit comments