
Commit c883cf2

Merge pull request #39 from open-sciencelab/revert-35-cot
2 parents 621b151 + c12bfde commit c883cf2

2 files changed (+89 -276 lines)

graphgen/operators/traverse_graph.py

Lines changed: 63 additions & 111 deletions
@@ -1,7 +1,7 @@
 import asyncio
+import gradio as gr

 from tqdm.asyncio import tqdm as tqdm_async
-import gradio as gr

 from graphgen.models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer, JsonKVStorage
 from graphgen.templates import ANSWER_REPHRASING_PROMPT, QUESTION_GENERATION_PROMPT, MULTI_HOP_GENERATION_PROMPT
@@ -53,7 +53,6 @@ async def handle_node(node: dict) -> dict:

 async def _construct_rephrasing_prompt(_process_nodes: list,
                                        _process_edges: list,
-                                       _difficulty: str,
                                        text_chunks_storage: JsonKVStorage,
                                        add_context: bool = False
                                        ) -> str:
@@ -77,15 +76,15 @@ async def _construct_rephrasing_prompt(_process_nodes: list,
         original_text = await text_chunks_storage.get_by_ids(original_ids)
         original_text = "\n".join([f"{index + 1}. {text['content']}" for index, text in enumerate(original_text)])

-        prompt = ANSWER_REPHRASING_PROMPT[_difficulty][language]['CONTEXT_TEMPLATE'].format(
+        prompt = ANSWER_REPHRASING_PROMPT[language]['CONTEXT_TEMPLATE'].format(
             language=language,
             original_text=original_text,
             entities=entities_str,
             relationships=relations_str
         )
         return prompt

-    prompt = ANSWER_REPHRASING_PROMPT[_difficulty][language]['TEMPLATE'].format(
+    prompt = ANSWER_REPHRASING_PROMPT[language]['TEMPLATE'].format(
         language=language,
         entities=entities_str,
         relationships=relations_str
@@ -99,34 +98,6 @@ def get_loss_tercile(losses: list) -> (float, float):

     return losses[q1_index], losses[q2_index]

-def assign_difficulty(subgraphs: list, difficulty_order: list, loss_strategy: str) -> list:
-    """
-    Assign difficulty to subgraphs based on the loss.
-
-    :param subgraphs
-    :param difficulty_order
-    :param loss_strategy
-    :return
-    """
-    losses = []
-    for subgraph in subgraphs:
-        loss = get_average_loss(subgraph, loss_strategy)
-        losses.append(loss)
-    q1, q2 = get_loss_tercile(losses)
-
-    for i, subgraph in enumerate(subgraphs):
-        loss = get_average_loss(subgraph, loss_strategy)
-        if loss < q1:
-            # easy
-            subgraphs[i] = (subgraph[0], subgraph[1], difficulty_order[0])
-        elif loss < q2:
-            # medium
-            subgraphs[i] = (subgraph[0], subgraph[1], difficulty_order[1])
-        else:
-            # hard
-            subgraphs[i] = (subgraph[0], subgraph[1], difficulty_order[2])
-    return subgraphs
-
 def get_average_loss(batch: tuple, loss_strategy: str) -> float:
     if loss_strategy == "only_edge":
         return sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
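
For reference, a minimal standalone sketch of the tercile-based bucketing idea behind the removed assign_difficulty / get_loss_tercile pair, assuming the cutoffs are the values one third and two thirds of the way through the sorted loss list (bucket_by_terciles is an illustrative name, not repository code):

# Illustrative sketch, not repository code: split losses into easy/medium/hard
# buckets by tercile cutoffs, as the removed assign_difficulty did.
def bucket_by_terciles(losses: list) -> list:
    ordered = sorted(losses)
    q1 = ordered[len(ordered) // 3]          # assumed lower cutoff
    q2 = ordered[2 * len(ordered) // 3]      # assumed upper cutoff
    return ["easy" if loss < q1 else "medium" if loss < q2 else "hard" for loss in losses]

# Example: bucket_by_terciles([0.1, 0.5, 0.9, 0.2, 0.7, 0.4])
# -> ['easy', 'medium', 'hard', 'easy', 'hard', 'medium']
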
@@ -179,7 +150,7 @@ async def traverse_graph_by_edge(
     :param graph_storage
     :param traverse_strategy
     :param text_chunks_storage
-    :param progress_bar: gradio progress bar
+    :param progress_bar
     :param max_concurrent
     :return: question and answer
     """
@@ -189,12 +160,10 @@ async def traverse_graph_by_edge(
     async def _process_nodes_and_edges(
             _process_nodes: list,
             _process_edges: list,
-            _difficulty: str,
     ) -> str:
         prompt = await _construct_rephrasing_prompt(
             _process_nodes,
             _process_edges,
-            _difficulty,
             text_chunks_storage,
             add_context = False
         )
@@ -216,68 +185,48 @@ async def _process_single_batch(
         context = await _process_nodes_and_edges(
             _process_batch[0],
             _process_batch[1],
-            _process_batch[2]
         )
-        # Usually the first line is the Question
-        # and everything after it is the Answer
-        question = context.split("\n")[0]
-        for prefix in ["Question:", "问题:", "问题:"]:
-            if question.startswith(prefix):
-                question = question[len(prefix):].strip()
-                break
-        answer = "\n".join(context.split("\n")[1:]).strip()
-        for prefix in ["Answer:", "答案:","答案:", "回答:", "回答:"]:
-            if answer.startswith(prefix):
-                answer = answer[len(prefix):].strip()
-                break
-        qas = [
-            {
-                "question": question,
-                "answer": answer
-            }
-        ]

         language = "Chinese" if detect_main_language(context) == "zh" else "English"
         pre_length = sum(node['length'] for node in _process_batch[0]) \
             + sum(edge[2]['length'] for edge in _process_batch[1])

-        # if question_type == "single":
-        #     question = await llm_client.generate_answer(
-        #         QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
-        #             answer=context
-        #         )
-        #     )
-        #     if question.startswith("Question:"):
-        #         question = question[len("Question:"):].strip()
-        #     elif question.startswith("问题:"):
-        #         question = question[len("问题:"):].strip()
-        #
-        #     logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
-        #     logger.info("Pre-length: %s", pre_length)
-        #     logger.info("Question: %s", question)
-        #     logger.info("Answer: %s", context)
-        #
-        #     return {
-        #         compute_content_hash(context): {
-        #             "question": question,
-        #             "answer": context,
-        #             "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
-        #             "difficulty": _process_batch[2],
-        #         }
-        #     }
-        #
-        # content = await llm_client.generate_answer(
-        #     QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
-        #         doc=context
-        #     )
-        # )
-        # qas = _post_process_synthetic_data(content)
-        #
-        # if len(qas) == 0:
-        #     print(content)
-        #     logger.error("Error occurred while processing batch, question or answer is None")
-        #     return {}
-        #
+        if question_type == "single":
+            question = await llm_client.generate_answer(
+                QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
+                    answer=context
+                )
+            )
+            if question.startswith("Question:"):
+                question = question[len("Question:"):].strip()
+            elif question.startswith("问题:"):
+                question = question[len("问题:"):].strip()
+
+            logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
+            logger.info("Pre-length: %s", pre_length)
+            logger.info("Question: %s", question)
+            logger.info("Answer: %s", context)
+
+            return {
+                compute_content_hash(context): {
+                    "question": question,
+                    "answer": context,
+                    "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
+                }
+            }
+
+        content = await llm_client.generate_answer(
+            QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
+                doc=context
+            )
+        )
+        qas = _post_process_synthetic_data(content)
+
+        if len(qas) == 0:
+            print(content)
+            logger.error("Error occurred while processing batch, question or answer is None")
+            return {}
+
         final_results = {}
         logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
         logger.info("Pre-length: %s", pre_length)
@@ -287,8 +236,7 @@ async def _process_single_batch(
             final_results[compute_content_hash(qa['question'])] = {
                 "question": qa['question'],
                 "answer": qa['answer'],
-                "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
-                "difficulty": _process_batch[2],
+                "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
             }
         return final_results

@@ -305,16 +253,17 @@ async def _process_single_batch(
         traverse_strategy
     )

-    processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order,
-                                           traverse_strategy.loss_strategy)
-
     for result in tqdm_async(asyncio.as_completed(
         [_process_single_batch(batch) for batch in processing_batches]
-    ), total=len(processing_batches), desc="Processing batches"):
+    ), total=len(processing_batches), desc="[4/4]Generating QAs"):
         try:
+            if progress_bar is not None:
+                progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
             results.update(await result)
+            if progress_bar is not None and len(results) == len(processing_batches):
+                progress_bar(1, desc="[4/4]Generating QAs")
         except Exception as e: # pylint: disable=broad-except
-            logger.error("Error occurred while processing batches: %s", e)
+            logger.error("Error occurred while generating QA: %s", e)

     return results
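
The loop above reports progress to a Gradio-style callback while draining asyncio.as_completed. A self-contained sketch of that pattern, assuming progress_bar behaves like a gr.Progress instance, i.e. a callable accepting (fraction, desc=...); run_batches and worker are illustrative names, not repository code:

import asyncio
from tqdm.asyncio import tqdm as tqdm_async

async def run_batches(batches: list, worker, progress_bar=None) -> dict:
    # worker: coroutine function taking one batch and returning a dict of QA pairs.
    # progress_bar: optional gr.Progress-like callable, progress_bar(fraction, desc=...).
    results = {}
    for future in tqdm_async(
        asyncio.as_completed([worker(batch) for batch in batches]),
        total=len(batches),
        desc="[4/4]Generating QAs",
    ):
        if progress_bar is not None:
            progress_bar(len(results) / len(batches), desc="[4/4]Generating QAs")
        results.update(await future)
    if progress_bar is not None:
        progress_bar(1, desc="[4/4]Generating QAs")
    return results
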

@@ -336,7 +285,7 @@ async def traverse_graph_atomically(
     :param graph_storage
     :param traverse_strategy
     :param text_chunks_storage
-    :param progress_bar: gradio progress bar
+    :param progress_bar
     :param max_concurrent
     :return: question and answer
     """
@@ -381,8 +330,7 @@ async def _generate_question(
                 compute_content_hash(question): {
                     "question": question,
                     "answer": answer,
-                    "loss": loss,
-                    "difficulty": "medium"
+                    "loss": loss
                 }
             }
         except Exception as e: # pylint: disable=broad-except
@@ -414,12 +362,16 @@ async def _generate_question(
     for result in tqdm_async(
         asyncio.as_completed([_generate_question(task) for task in tasks]),
         total=len(tasks),
-        desc="Generating questions"
+        desc="[4/4]Generating QAs"
     ):
         try:
+            if progress_bar is not None:
+                progress_bar(len(results) / len(tasks), desc="[4/4]Generating QAs")
             results.update(await result)
+            if progress_bar is not None and len(results) == len(tasks):
+                progress_bar(1, desc="[4/4]Generating QAs")
         except Exception as e: # pylint: disable=broad-except
-            logger.error("Error occurred while generating questions: %s", e)
+            logger.error("Error occurred while generating QA: %s", e)
     return results

 async def traverse_graph_for_multi_hop(
@@ -439,7 +391,7 @@ async def traverse_graph_for_multi_hop(
     :param graph_storage
     :param traverse_strategy
     :param text_chunks_storage
-    :param progress_bar: gradio progress bar
+    :param progress_bar
     :param max_concurrent
     :return: question and answer
     """
@@ -460,9 +412,6 @@ async def traverse_graph_for_multi_hop(
         traverse_strategy
     )

-    processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order,
-                                           traverse_strategy.loss_strategy)
-
     async def _process_single_batch(
         _process_batch: tuple
     ) -> dict:
@@ -513,21 +462,24 @@ async def _process_single_batch(
                     "question": question,
                     "answer": answer,
                     "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
-                    "difficulty": _process_batch[2],
                 }
             }

         except Exception as e: # pylint: disable=broad-except
             logger.error("Error occurred while processing batch: %s", e)
             return {}

-    for result in tqdm_async(
+    async for result in tqdm_async(
         asyncio.as_completed([_process_single_batch(batch) for batch in processing_batches]),
         total=len(processing_batches),
-        desc="Processing batches"
+        desc="[4/4]Generating QAs"
     ):
         try:
+            if progress_bar is not None:
+                progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
             results.update(await result)
+            if progress_bar is not None and len(results) == len(processing_batches):
+                progress_bar(1, desc="[4/4]Generating QAs")
         except Exception as e: # pylint: disable=broad-except
-            logger.error("Error occurred while processing batches: %s", e)
+            logger.error("Error occurred while generating QA: %s", e)
     return results
