connection creation in extract and CancelledError handling for SSE #584

Merged · 1 commit · Jul 18, 2024
66 changes: 35 additions & 31 deletions backend/score.py
@@ -165,35 +165,36 @@ async def extract_knowledge_graph_from_file(
     Nodes and Relations created in Neo4j database for the PDF file
"""
try:
graph = create_graph_database_connection(uri, userName, password, database)
graphDb_data_Access = graphDBdataAccess(graph)
if source_type == 'local file':
merged_file_path = os.path.join(MERGED_DIR,file_name)
logging.info(f'File path:{merged_file_path}')
result = await asyncio.to_thread(
extract_graph_from_file_local_file, graph, model, merged_file_path, file_name, allowedNodes, allowedRelationship, uri)
extract_graph_from_file_local_file, uri, userName, password, database, model, merged_file_path, file_name, allowedNodes, allowedRelationship)

elif source_type == 's3 bucket' and source_url:
result = await asyncio.to_thread(
extract_graph_from_file_s3, graph, model, source_url, aws_access_key_id, aws_secret_access_key, allowedNodes, allowedRelationship)
extract_graph_from_file_s3, uri, userName, password, database, model, source_url, aws_access_key_id, aws_secret_access_key, allowedNodes, allowedRelationship)

elif source_type == 'web-url':
result = await asyncio.to_thread(
extract_graph_from_web_page, graph, model, source_url, allowedNodes, allowedRelationship)
extract_graph_from_web_page, uri, userName, password, database, model, source_url, allowedNodes, allowedRelationship)

elif source_type == 'youtube' and source_url:
result = await asyncio.to_thread(
extract_graph_from_file_youtube, graph, model, source_url, allowedNodes, allowedRelationship)
extract_graph_from_file_youtube, uri, userName, password, database, model, source_url, allowedNodes, allowedRelationship)

elif source_type == 'Wikipedia' and wiki_query:
result = await asyncio.to_thread(
extract_graph_from_file_Wikipedia, graph, model, wiki_query, max_sources, language, allowedNodes, allowedRelationship)
extract_graph_from_file_Wikipedia, uri, userName, password, database, model, wiki_query, max_sources, language, allowedNodes, allowedRelationship)

elif source_type == 'gcs bucket' and gcs_bucket_name:
result = await asyncio.to_thread(
extract_graph_from_file_gcs, graph, model, gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, access_token, allowedNodes, allowedRelationship)
extract_graph_from_file_gcs, uri, userName, password, database, model, gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, access_token, allowedNodes, allowedRelationship)
else:
return create_api_response('Failed',message='source_type is other than accepted source')

graph = create_graph_database_connection(uri, userName, password, database)
graphDb_data_Access = graphDBdataAccess(graph)
if result is not None:
result['db_url'] = uri
result['api_name'] = 'extract'
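
Review note: the hunk above is the heart of the change. The extract helpers no longer receive a `graph` object created on the event-loop thread; they receive the raw connection parameters and open their own connection inside the `asyncio.to_thread` worker, and the endpoint rebuilds a short-lived connection afterwards only for status bookkeeping. A minimal sketch of that pattern, assuming `create_graph_database_connection` wraps langchain's `Neo4jGraph` (the worker and endpoint names here are illustrative, not the project's exact API):

```python
import asyncio

from langchain_community.graphs import Neo4jGraph

def create_graph_database_connection(uri, userName, password, database):
    # Assumed shape of the project's helper: one Neo4jGraph per call.
    return Neo4jGraph(url=uri, username=userName, password=password, database=database)

def extract_in_worker(uri, userName, password, database, model, payload):
    # Runs on the worker thread spawned by asyncio.to_thread, so the
    # Neo4j connection is created on the thread that actually uses it
    # rather than being built on the event loop and handed across the
    # thread boundary.
    graph = create_graph_database_connection(uri, userName, password, database)
    ...  # chunking + LLM extraction against `graph` would go here

async def extract_endpoint(uri, userName, password, database, model, payload):
    # Only plain credential strings cross into the worker thread.
    return await asyncio.to_thread(
        extract_in_worker, uri, userName, password, database, model, payload)
```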
@@ -445,29 +446,32 @@ async def generate():
if " " in url:
uri= url.replace(" ","+")
while True:
if await request.is_disconnected():
logging.info("Request disconnected")
break
#get the current status of document node
graph = create_graph_database_connection(uri, userName, decoded_password, database)
graphDb_data_Access = graphDBdataAccess(graph)
result = graphDb_data_Access.get_current_status_document_node(file_name)
if result is not None:
status = json.dumps({'fileName':file_name,
'status':result[0]['Status'],
'processingTime':result[0]['processingTime'],
'nodeCount':result[0]['nodeCount'],
'relationshipCount':result[0]['relationshipCount'],
'model':result[0]['model'],
'total_chunks':result[0]['total_chunks'],
'total_pages':result[0]['total_pages'],
'fileSize':result[0]['fileSize'],
'processed_chunk':result[0]['processed_chunk'],
'fileSource':result[0]['fileSource']
})
else:
status = json.dumps({'fileName':file_name, 'status':'Failed'})
yield status
try:
if await request.is_disconnected():
logging.info(" SSE Client disconnected")
break
# get the current status of document node
graph = create_graph_database_connection(uri, userName, decoded_password, database)
graphDb_data_Access = graphDBdataAccess(graph)
result = graphDb_data_Access.get_current_status_document_node(file_name)
if result is not None:
status = json.dumps({'fileName':file_name,
'status':result[0]['Status'],
'processingTime':result[0]['processingTime'],
'nodeCount':result[0]['nodeCount'],
'relationshipCount':result[0]['relationshipCount'],
'model':result[0]['model'],
'total_chunks':result[0]['total_chunks'],
'total_pages':result[0]['total_pages'],
'fileSize':result[0]['fileSize'],
'processed_chunk':result[0]['processed_chunk'],
'fileSource':result[0]['fileSource']
})
else:
status = json.dumps({'fileName':file_name, 'status':'Failed'})
yield status
except asyncio.CancelledError:
logging.info("SSE Connection cancelled")

return EventSourceResponse(generate(),ping=60)
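
Review note: when an SSE client drops the connection while the generator is suspended at an `await`, `asyncio.CancelledError` is raised inside the generator; the `try/except` above turns that into a log line instead of an unhandled exception in the stream. A trimmed, self-contained sketch of the same shape — the poll `sleep` and the re-`raise` are my additions, not in the diff (re-raising after logging is the usual recommendation so cancellation still unwinds):

```python
import asyncio
import json
import logging

from sse_starlette.sse import EventSourceResponse

async def sse_endpoint(request, file_name):
    async def generate():
        while True:
            try:
                if await request.is_disconnected():
                    logging.info("SSE client disconnected")
                    break
                status = json.dumps({'fileName': file_name, 'status': 'Processing'})
                yield status
                await asyncio.sleep(2)  # illustrative poll interval, not in the diff
            except asyncio.CancelledError:
                # Raised at the await points above when the client drops
                # mid-request; log it, then re-raise so the cancellation
                # still propagates (the diff itself only logs).
                logging.info("SSE connection cancelled")
                raise

    return EventSourceResponse(generate(), ping=60)
```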

37 changes: 22 additions & 15 deletions backend/src/main.py
@@ -182,7 +182,7 @@ def create_source_node_graph_url_wikipedia(graph, model, wiki_query, source_type
         lst_file_name.append({'fileName':obj_source_node.file_name,'fileSize':obj_source_node.file_size,'url':obj_source_node.url, 'language':obj_source_node.language, 'status':'Success'})
     return lst_file_name,success_count,failed_count

-def extract_graph_from_file_local_file(graph, model, merged_file_path, fileName, allowedNodes, allowedRelationship,uri):
+def extract_graph_from_file_local_file(uri, userName, password, database, model, merged_file_path, fileName, allowedNodes, allowedRelationship):

     logging.info(f'Process file name :{fileName}')
     gcs_file_cache = os.environ.get('GCS_FILE_CACHE')
@@ -194,9 +194,9 @@ def extract_graph_from_file_local_file(graph, model, merged_file_path, fileName,
     if pages==None or len(pages)==0:
         raise Exception(f'File content is not available for file : {file_name}')

-    return processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship, True, merged_file_path, uri)
+    return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship, True, merged_file_path)

-def extract_graph_from_file_s3(graph, model, source_url, aws_access_key_id, aws_secret_access_key, allowedNodes, allowedRelationship):
+def extract_graph_from_file_s3(uri, userName, password, database, model, source_url, aws_access_key_id, aws_secret_access_key, allowedNodes, allowedRelationship):

     if(aws_access_key_id==None or aws_secret_access_key==None):
         raise Exception('Please provide AWS access and secret keys')
@@ -207,43 +207,43 @@ def extract_graph_from_file_s3(graph, model, source_url, aws_access_key_id, aws_
     if pages==None or len(pages)==0:
         raise Exception(f'File content is not available for file : {file_name}')

-    return processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship)
+    return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)

-def extract_graph_from_web_page(graph, model, source_url, allowedNodes, allowedRelationship):
+def extract_graph_from_web_page(uri, userName, password, database, model, source_url, allowedNodes, allowedRelationship):

     file_name, pages = get_documents_from_web_page(source_url)

     if pages==None or len(pages)==0:
         raise Exception(f'Content is not available for given URL : {file_name}')

-    return processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship)
+    return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)

-def extract_graph_from_file_youtube(graph, model, source_url, allowedNodes, allowedRelationship):
+def extract_graph_from_file_youtube(uri, userName, password, database, model, source_url, allowedNodes, allowedRelationship):

     file_name, pages = get_documents_from_youtube(source_url)

     if pages==None or len(pages)==0:
         raise Exception(f'Youtube transcript is not available for file : {file_name}')

-    return processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship)
+    return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)

-def extract_graph_from_file_Wikipedia(graph, model, wiki_query, max_sources, language, allowedNodes, allowedRelationship):
+def extract_graph_from_file_Wikipedia(uri, userName, password, database, model, wiki_query, max_sources, language, allowedNodes, allowedRelationship):

     file_name, pages = get_documents_from_Wikipedia(wiki_query, language)
     if pages==None or len(pages)==0:
         raise Exception(f'Wikipedia page is not available for file : {file_name}')

-    return processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship)
+    return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)

-def extract_graph_from_file_gcs(graph, model, gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, access_token, allowedNodes, allowedRelationship):
+def extract_graph_from_file_gcs(uri, userName, password, database, model, gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, access_token, allowedNodes, allowedRelationship):

     file_name, pages = get_documents_from_gcs(gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, access_token)
     if pages==None or len(pages)==0:
         raise Exception(f'File content is not available for file : {file_name}')

-    return processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship)
+    return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)

-def processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship, is_uploaded_from_local=None, merged_file_path=None, uri=None):
+def processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship, is_uploaded_from_local=None, merged_file_path=None):
     """
     Extracts a Neo4jGraph from a PDF file based on the model.

@@ -260,6 +260,7 @@ def processing_source(graph, model, file_name, pages, allowedNodes, allowedRelat
     status and model as attributes.
     """
     start_time = datetime.now()
+    graph = create_graph_database_connection(uri, userName, password, database)
     graphDb_data_Access = graphDBdataAccess(graph)

     result = graphDb_data_Access.get_current_status_document_node(file_name)
@@ -309,7 +310,7 @@ def processing_source(graph, model, file_name, pages, allowedNodes, allowedRelat
                 logging.info('Exit from running loop of processing file')
                 exit
             else:
-                node_count,rel_count = processing_chunks(selected_chunks,graph,file_name,model,allowedNodes,allowedRelationship,node_count, rel_count)
+                node_count,rel_count = processing_chunks(selected_chunks,graph,uri, userName, password, database,file_name,model,allowedNodes,allowedRelationship,node_count, rel_count)
                 end_time = datetime.now()
                 processed_time = end_time - start_time

@@ -362,8 +363,14 @@ def processing_source(graph, model, file_name, pages, allowedNodes, allowedRelat
     else:
         logging.info('File does not process because it\'s already in Processing status')

-def processing_chunks(chunkId_chunkDoc_list,graph,file_name,model,allowedNodes,allowedRelationship, node_count, rel_count):
+def processing_chunks(chunkId_chunkDoc_list,graph,uri, userName, password, database,file_name,model,allowedNodes,allowedRelationship, node_count, rel_count):
     #create vector index and update chunk node with embedding
+    if graph is not None:
+        if graph._driver._closed:
+            graph = create_graph_database_connection(uri, userName, password, database)
+    else:
+        graph = create_graph_database_connection(uri, userName, password, database)
+
     update_embedding_create_vector_index( graph, chunkId_chunkDoc_list, file_name)
     logging.info("Get graph document list from models")
     graph_documents = generate_graphDocuments(model, graph, chunkId_chunkDoc_list, allowedNodes, allowedRelationship)
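
Review note: the new guard in `processing_chunks` reaches into `graph._driver._closed`, a private attribute of the underlying Neo4j Python driver that could change between releases. A sketch of the same reconnect-if-stale idea using an explicit liveness probe instead — this is a hypothetical helper, and it assumes the `create_graph_database_connection` wrapper from the earlier sketch:

```python
def ensure_open_connection(graph, uri, userName, password, database):
    """Hypothetical helper: return a usable graph connection, rebuilding
    it when the existing one can no longer serve queries."""
    if graph is not None:
        try:
            # Cheap liveness probe instead of reading the private
            # graph._driver._closed flag the diff inspects.
            graph.query("RETURN 1")
            return graph
        except Exception:
            pass  # stale or closed; fall through and reconnect
    return create_graph_database_connection(uri, userName, password, database)
```

With a helper like this, the six added lines at the top of `processing_chunks` would collapse to `graph = ensure_open_connection(graph, uri, userName, password, database)`.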