11import asyncio
2+ import gradio as gr
23
34from tqdm .asyncio import tqdm as tqdm_async
4- import gradio as gr
55
66from graphgen .models import OpenAIModel , NetworkXStorage , TraverseStrategy , Tokenizer , JsonKVStorage
77from graphgen .templates import ANSWER_REPHRASING_PROMPT , QUESTION_GENERATION_PROMPT , MULTI_HOP_GENERATION_PROMPT
@@ -53,7 +53,6 @@ async def handle_node(node: dict) -> dict:
5353
5454async def _construct_rephrasing_prompt (_process_nodes : list ,
5555 _process_edges : list ,
56- _difficulty : str ,
5756 text_chunks_storage : JsonKVStorage ,
5857 add_context : bool = False
5958 ) -> str :
@@ -77,15 +76,15 @@ async def _construct_rephrasing_prompt(_process_nodes: list,
7776 original_text = await text_chunks_storage .get_by_ids (original_ids )
7877 original_text = "\n " .join ([f"{ index + 1 } . { text ['content' ]} " for index , text in enumerate (original_text )])
7978
80- prompt = ANSWER_REPHRASING_PROMPT [_difficulty ][ language ]['CONTEXT_TEMPLATE' ].format (
79+ prompt = ANSWER_REPHRASING_PROMPT [language ]['CONTEXT_TEMPLATE' ].format (
8180 language = language ,
8281 original_text = original_text ,
8382 entities = entities_str ,
8483 relationships = relations_str
8584 )
8685 return prompt
8786
88- prompt = ANSWER_REPHRASING_PROMPT [_difficulty ][ language ]['TEMPLATE' ].format (
87+ prompt = ANSWER_REPHRASING_PROMPT [language ]['TEMPLATE' ].format (
8988 language = language ,
9089 entities = entities_str ,
9190 relationships = relations_str
@@ -99,34 +98,6 @@ def get_loss_tercile(losses: list) -> (float, float):
9998
10099 return losses [q1_index ], losses [q2_index ]
101100
102- def assign_difficulty (subgraphs : list , difficulty_order : list , loss_strategy : str ) -> list :
103- """
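104- Assign a difficulty label to each subgraph based on its average loss.
105-
106- :param subgraphs: list of batches, each indexed as (nodes, edges); every entry is replaced by (nodes, edges, difficulty)
107- :param difficulty_order: three labels ordered easy, medium, hard
108- :param loss_strategy: strategy name passed through to get_average_loss
109- :return: the same list, with a difficulty label attached to every subgraph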
110- """
111- losses = []
112- for subgraph in subgraphs :
113- loss = get_average_loss (subgraph , loss_strategy )
114- losses .append (loss )
115- q1 , q2 = get_loss_tercile (losses )
116-
117- for i , subgraph in enumerate (subgraphs ):
118- loss = get_average_loss (subgraph , loss_strategy )
119- if loss < q1 :
120- # easy
121- subgraphs [i ] = (subgraph [0 ], subgraph [1 ], difficulty_order [0 ])
122- elif loss < q2 :
123- # medium
124- subgraphs [i ] = (subgraph [0 ], subgraph [1 ], difficulty_order [1 ])
125- else :
126- # hard
127- subgraphs [i ] = (subgraph [0 ], subgraph [1 ], difficulty_order [2 ])
128- return subgraphs
129-
130101def get_average_loss (batch : tuple , loss_strategy : str ) -> float :
131102 if loss_strategy == "only_edge" :
132103 return sum (edge [2 ]['loss' ] for edge in batch [1 ]) / len (batch [1 ])
@@ -179,7 +150,7 @@ async def traverse_graph_by_edge(
179150 :param graph_storage
180151 :param traverse_strategy
181152 :param text_chunks_storage
182- :param progress_bar: gradio progress bar
153+ :param progress_bar
183154 :param max_concurrent
184155 :return: question and answer
185156 """
@@ -189,12 +160,10 @@ async def traverse_graph_by_edge(
189160 async def _process_nodes_and_edges (
190161 _process_nodes : list ,
191162 _process_edges : list ,
192- _difficulty : str ,
193163 ) -> str :
194164 prompt = await _construct_rephrasing_prompt (
195165 _process_nodes ,
196166 _process_edges ,
197- _difficulty ,
198167 text_chunks_storage ,
199168 add_context = False
200169 )
@@ -216,68 +185,48 @@ async def _process_single_batch(
216185 context = await _process_nodes_and_edges (
217186 _process_batch [0 ],
218187 _process_batch [1 ],
219- _process_batch [2 ]
220188 )
221- # The first line is usually the Question
222- # Everything after it is the Answer
223- question = context .split ("\n " )[0 ]
224- for prefix in ["Question:" , "问题:" , "问题:" ]:
225- if question .startswith (prefix ):
226- question = question [len (prefix ):].strip ()
227- break
228- answer = "\n " .join (context .split ("\n " )[1 :]).strip ()
229- for prefix in ["Answer:" , "答案:" ,"答案:" , "回答:" , "回答:" ]:
230- if answer .startswith (prefix ):
231- answer = answer [len (prefix ):].strip ()
232- break
233- qas = [
234- {
235- "question" : question ,
236- "answer" : answer
237- }
238- ]
239189
240190 language = "Chinese" if detect_main_language (context ) == "zh" else "English"
241191 pre_length = sum (node ['length' ] for node in _process_batch [0 ]) \
242192 + sum (edge [2 ]['length' ] for edge in _process_batch [1 ])
243193
244- # if question_type == "single":
245- # question = await llm_client.generate_answer(
246- # QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
247- # answer=context
248- # )
249- # )
250- # if question.startswith("Question:"):
251- # question = question[len("Question:"):].strip()
252- # elif question.startswith("问题:"):
253- # question = question[len("问题:"):].strip()
254- #
255- # logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
256- # logger.info("Pre-length: %s", pre_length)
257- # logger.info("Question: %s", question)
258- # logger.info("Answer: %s", context)
259- #
260- # return {
261- # compute_content_hash(context): {
262- # "question": question,
263- # "answer": context,
264- # "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
265- # "difficulty": _process_batch[2],
266- # }
267- # }
268- #
269- # content = await llm_client.generate_answer(
270- # QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
271- # doc=context
272- # )
273- # )
274- # qas = _post_process_synthetic_data(content)
275- #
276- # if len(qas) == 0:
277- # print(content)
278- # logger.error("Error occurred while processing batch, question or answer is None")
279- # return {}
280- #
194+ if question_type == "single" :
195+ question = await llm_client .generate_answer (
196+ QUESTION_GENERATION_PROMPT [language ]['SINGLE_TEMPLATE' ].format (
197+ answer = context
198+ )
199+ )
200+ if question .startswith ("Question:" ):
201+ question = question [len ("Question:" ):].strip ()
202+ elif question .startswith ("问题:" ):
203+ question = question [len ("问题:" ):].strip ()
204+
205+ logger .info ("%d nodes and %d edges processed" , len (_process_batch [0 ]), len (_process_batch [1 ]))
206+ logger .info ("Pre-length: %s" , pre_length )
207+ logger .info ("Question: %s" , question )
208+ logger .info ("Answer: %s" , context )
209+
210+ return {
211+ compute_content_hash (context ): {
212+ "question" : question ,
213+ "answer" : context ,
214+ "loss" : get_average_loss (_process_batch , traverse_strategy .loss_strategy )
215+ }
216+ }
217+
218+ content = await llm_client .generate_answer (
219+ QUESTION_GENERATION_PROMPT [language ]['MULTI_TEMPLATE' ].format (
220+ doc = context
221+ )
222+ )
223+ qas = _post_process_synthetic_data (content )
224+
225+ if len (qas ) == 0 :
226+ print (content )
227+ logger .error ("Error occurred while processing batch, question or answer is None" )
228+ return {}
229+
281230 final_results = {}
282231 logger .info ("%d nodes and %d edges processed" , len (_process_batch [0 ]), len (_process_batch [1 ]))
283232 logger .info ("Pre-length: %s" , pre_length )
@@ -287,8 +236,7 @@ async def _process_single_batch(
287236 final_results [compute_content_hash (qa ['question' ])] = {
288237 "question" : qa ['question' ],
289238 "answer" : qa ['answer' ],
290- "loss" : get_average_loss (_process_batch , traverse_strategy .loss_strategy ),
291- "difficulty" : _process_batch [2 ],
239+ "loss" : get_average_loss (_process_batch , traverse_strategy .loss_strategy )
292240 }
293241 return final_results
294242
@@ -305,16 +253,17 @@ async def _process_single_batch(
305253 traverse_strategy
306254 )
307255
308- processing_batches = assign_difficulty (processing_batches , traverse_strategy .difficulty_order ,
309- traverse_strategy .loss_strategy )
310-
311256 for result in tqdm_async (asyncio .as_completed (
312257 [_process_single_batch (batch ) for batch in processing_batches ]
313- ), total = len (processing_batches ), desc = "Processing batches " ):
258+ ), total = len (processing_batches ), desc = "[4/4]Generating QAs " ):
314259 try :
260+ if progress_bar is not None :
261+ progress_bar (len (results ) / len (processing_batches ), desc = "[4/4]Generating QAs" )
315262 results .update (await result )
263+ if progress_bar is not None and len (results ) == len (processing_batches ):
264+ progress_bar (1 , desc = "[4/4]Generating QAs" )
316265 except Exception as e : # pylint: disable=broad-except
317- logger .error ("Error occurred while processing batches : %s" , e )
266+ logger .error ("Error occurred while generating QA : %s" , e )
318267
319268 return results
320269
@@ -336,7 +285,7 @@ async def traverse_graph_atomically(
336285 :param graph_storage
337286 :param traverse_strategy
338287 :param text_chunks_storage
339- :param progress_bar: gradio progress bar
288+ :param progress_bar
340289 :param max_concurrent
341290 :return: question and answer
342291 """
@@ -381,8 +330,7 @@ async def _generate_question(
381330 compute_content_hash (question ): {
382331 "question" : question ,
383332 "answer" : answer ,
384- "loss" : loss ,
385- "difficulty" : "medium"
333+ "loss" : loss
386334 }
387335 }
388336 except Exception as e : # pylint: disable=broad-except
@@ -414,12 +362,16 @@ async def _generate_question(
414362 for result in tqdm_async (
415363 asyncio .as_completed ([_generate_question (task ) for task in tasks ]),
416364 total = len (tasks ),
417- desc = "Generating questions "
365+ desc = "[4/4] Generating QAs "
418366 ):
419367 try :
368+ if progress_bar is not None :
369+ progress_bar (len (results ) / len (tasks ), desc = "[4/4]Generating QAs" )
420370 results .update (await result )
371+ if progress_bar is not None and len (results ) == len (tasks ):
372+ progress_bar (1 , desc = "[4/4]Generating QAs" )
421373 except Exception as e : # pylint: disable=broad-except
422- logger .error ("Error occurred while generating questions : %s" , e )
374+ logger .error ("Error occurred while generating QA : %s" , e )
423375 return results
424376
425377async def traverse_graph_for_multi_hop (
@@ -439,7 +391,7 @@ async def traverse_graph_for_multi_hop(
439391 :param graph_storage
440392 :param traverse_strategy
441393 :param text_chunks_storage
442- :param progress_bar: gradio progress bar
394+ :param progress_bar
443395 :param max_concurrent
444396 :return: question and answer
445397 """
@@ -460,9 +412,6 @@ async def traverse_graph_for_multi_hop(
460412 traverse_strategy
461413 )
462414
463- processing_batches = assign_difficulty (processing_batches , traverse_strategy .difficulty_order ,
464- traverse_strategy .loss_strategy )
465-
466415 async def _process_single_batch (
467416 _process_batch : tuple
468417 ) -> dict :
@@ -513,21 +462,24 @@ async def _process_single_batch(
513462 "question" : question ,
514463 "answer" : answer ,
515464 "loss" : get_average_loss (_process_batch , traverse_strategy .loss_strategy ),
516- "difficulty" : _process_batch [2 ],
517465 }
518466 }
519467
520468 except Exception as e : # pylint: disable=broad-except
521469 logger .error ("Error occurred while processing batch: %s" , e )
522470 return {}
523471
524- for result in tqdm_async (
472+ async for result in tqdm_async (
525473 asyncio .as_completed ([_process_single_batch (batch ) for batch in processing_batches ]),
526474 total = len (processing_batches ),
527- desc = "Processing batches "
475+ desc = "[4/4]Generating QAs "
528476 ):
529477 try :
478+ if progress_bar is not None :
479+ progress_bar (len (results ) / len (processing_batches ), desc = "[4/4]Generating QAs" )
530480 results .update (await result )
481+ if progress_bar is not None and len (results ) == len (processing_batches ):
482+ progress_bar (1 , desc = "[4/4]Generating QAs" )
531483 except Exception as e : # pylint: disable=broad-except
532- logger .error ("Error occurred while processing batches : %s" , e )
484+ logger .error ("Error occurred while generating QA : %s" , e )
533485 return results
0 commit comments