add benchmarks with text_grad and dspy
liyin2015 committed Jan 22, 2025
1 parent 123cf80 commit 8026d52
Showing 9 changed files with 644 additions and 108 deletions.
2 changes: 1 addition & 1 deletion adalflow/adalflow/optim/trainer/trainer.py
@@ -1582,7 +1582,7 @@ def _fit_demos_random(
trainer_results,
last_val_score,
test_score=trainer_results.test_scores[-1],
prompts=trainer_results.prompts[-1],
prompts=trainer_results.step_results[-1].prompt,
step=step,
attempted_val_score=val_score,
)
8 changes: 8 additions & 0 deletions benchmarks/BHH_object_count/text_grad/text_grad_train.py
@@ -1,3 +1,11 @@
"""
TextGrad's object count implementation:
self._task_description = "You will answer a reasoning question. Think step by step. The last line of your response should be of the following format: 'Answer: $VALUE' where VALUE is a numerical value."
We use the same task description. The only difference is that DSPy sends two messages, [system_prompt, user_message], while we pack both into a single system message: {"system": <system_prompt> <user_message>}.
"""

import logging


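To make the payload difference described in the docstring above concrete, here is a minimal sketch, assuming OpenAI-style chat messages (the variable names and the example question are illustrative, not from the repository):

# Illustrative sketch: DSPy-style two-message payload vs. our single-system-message payload.
system_prompt = (
    "You will answer a reasoning question. Think step by step. "
    "The last line of your response should be of the following format: "
    "'Answer: $VALUE' where VALUE is a numerical value."
)
user_message = "I counted 3 apples, 2 bananas, and 4 oranges. How many fruits do I have?"

# DSPy sends the system prompt and the user message as separate messages.
dspy_messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_message},
]

# We send one system message that concatenates both parts.
ours_messages = [
    {"role": "system", "content": f"{system_prompt}\n{user_message}"},
]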
6 changes: 6 additions & 0 deletions benchmarks/BHH_object_count/text_grad/trec_6_train.py
@@ -365,3 +365,9 @@ def multi_run_train(num_runs=4):
# eval_fn=eval_fn,
# )
multi_run_train(num_runs=4)

# Average Training Time: 1383.4382199645042
# Average Test Score: 0.8488023952095808
# Average Val Score: 0.8373493975903614
# Std Test Score: 0.011691990532794386
# Std Val Score: 0.01127005236980105
100 changes: 19 additions & 81 deletions benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py
@@ -76,21 +76,21 @@ class QueriesOutput(adal.DataClass):
)


class DeduplicateList(adal.GradComponent):
def __init__(self):
super().__init__()
# class DeduplicateList(adal.GradComponent):
# def __init__(self):
# super().__init__()

def call(
self, exisiting_list: List[str], new_list: List[str], id: str = None
) -> List[str]:
# def call(
# self, exisiting_list: List[str], new_list: List[str], id: str = None
# ) -> List[str]:

seen = set()
return [x for x in exisiting_list + new_list if not (x in seen or seen.add(x))]
# seen = set()
# return [x for x in exisiting_list + new_list if not (x in seen or seen.add(x))]

def backward(self, *args, **kwargs):
# def backward(self, *args, **kwargs):

printc(f"DeduplicateList backward: {args}", "yellow")
return super().backward(*args, **kwargs)
# printc(f"DeduplicateList backward: {args}", "yellow")
# return super().backward(*args, **kwargs)


class CombineList(GradComponent2):
@@ -215,7 +215,6 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
)

self.retriever = DspyRetriever(top_k=passages_per_hop)
self.deduplicater = DeduplicateList()
self.combine_list = CombineList()

@staticmethod
@@ -379,14 +378,14 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
model_client=model_client,
model_kwargs=model_kwargs,
prompt_kwargs={
# "few_shot_demos": Parameter(
# name=f"few_shot_demos_{i}",
# # data=few_shot_demos[i],
# data=None,
# role_desc="To provide few shot demos to the language model",
# requires_opt=True,
# param_type=ParameterType.DEMOS,
# ),
"few_shot_demos": Parameter(
name=f"few_shot_demos_{i}",
# data=few_shot_demos[i],
data=None,
role_desc="To provide few shot demos to the language model",
requires_opt=True,
param_type=ParameterType.DEMOS,
),
"task_desc_str": Parameter(
name="task_desc_str",
data=task_desc_str,
@@ -411,7 +410,6 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
)
)
self.retrievers.append(DspyRetriever(top_k=passages_per_hop))
self.deduplicaters.append(DeduplicateList())

self.combine_list = CombineList()
self.combine_queries = CombineQueries()
@@ -456,36 +454,7 @@ def call(self, *, input: str, id: str = None) -> adal.RetrieverOutput:
)
return out

def call2(self, *, input: str, id: str = None) -> str:
context = []
queries: List[str] = []
last_query = None
for i in range(self.max_hops):
gen_out = self.query_generators[i](
prompt_kwargs={
"context": context,
"question": input,
"last_query": last_query,
},
id=id,
)

query = gen_out.data.query if gen_out.data and gen_out.data.query else input

retrieve_out = self.retrievers[i](input=query, id=id)

passages = retrieve_out.documents
context = self.deduplicate(context + passages)
queries.append(query)
last_query = query
out = ", ".join(queries)
query_output = QueriesOutput(data=out, id=id)
return query_output

def forward(self, *, input: str, id: str = None) -> adal.Parameter:
# assemble the fundamental building blocks
# printc(f"question: {input}", "yellow")
# context = []

queries: List[str] = []

@@ -530,15 +499,6 @@ def forward(self, *, input: str, id: str = None) -> adal.Parameter:
def retrieve_out_map_fn(x: adal.Parameter):
return x.data.documents if x.data and x.data.documents else []

# print(f"retrieve_out: {retrieve_out}")

# retrieve_out.add_successor_map_fn(
# successor=self.deduplicaters[i], map_fn=retrieve_out_map_fn
# )

# context = self.deduplicaters[i].forward(
# exisiting_list=context, new_list=retrieve_out
# )
retrieve_out.data_in_prompt = lambda x: {
"query": x.data.query,
"documents": x.data.documents,
@@ -550,13 +510,6 @@ def retrieve_out_map_fn(x: adal.Parameter):
)
last_query = success_map_fn(gen_out)
contexts.append(retrieve_out)
# if i + 1 < self.max_hops:
# retrieve_out.add_successor_map_fn(
# successor=self.query_generators[i + 1], map_fn=retrieve_out_map_fn
# )

# last_query = success_map_fn(gen_out)
# printc(f"retrieve_out, last_query: {last_query}")

contexts[0].add_successor_map_fn(
successor=self.combine_list, map_fn=lambda x: x.data
@@ -645,22 +598,7 @@ def retrieve_out_map_fn(x: adal.Parameter):
)

last_query = success_map_fn(gen_out)
# printc(f"retrieve_out, last_query: {last_query}")

# contexts[0].add_successor_map_fn(
# successor=self.combine_list, map_fn=lambda x: x.data
# )
# contexts[1].add_successor_map_fn(
# successor=self.combine_list, map_fn=lambda x: x.data
# )
# contexts_sum = self.combine_list.forward(
# context_1=contexts[0], context_2=contexts[1]
# )
# contexts_sum.data_in_prompt = lambda x: {
# "query": x.data.query,
# "documents": x.data.documents,
# }
# setattr(contexts_sum, "queries", [q.data.data.query for q in queries])
queries[0].add_successor_map_fn(
successor=self.combine_queries, map_fn=lambda x: x.data.data.query
)
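As a side note on the DeduplicateList component this commit comments out: it relied on a compact order-preserving deduplication idiom. A minimal standalone sketch of that pattern (the function and variable names here are illustrative, not part of the repository):

from typing import List

def dedup_preserve_order(existing: List[str], new: List[str]) -> List[str]:
    # Illustrative sketch of the pattern used by the (now commented-out) DeduplicateList.
    # set.add() returns None (falsy), so each unseen item is recorded exactly once
    # while the original ordering is preserved.
    seen = set()
    return [x for x in existing + new if not (x in seen or seen.add(x))]

# Example: overlapping passages are merged without reordering.
print(dedup_preserve_order(["a", "b"], ["b", "c"]))  # ['a', 'b', 'c']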
14 changes: 7 additions & 7 deletions benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py
@@ -169,13 +169,13 @@ def __init__(self, passages_per_hop=3, model_client=None, model_kwargs=None):
instruction_to_optimizer="ou need find the best way(where does the right answer come from the context) to extract the RIGHT answer from the context.",
# + "Given existing context, ensure the task instructions can maximize the performance.",
),
# "few_shot_demos": adal.Parameter(
# # data=demo_str,
# data=None,
# requires_opt=True,
# role_desc="To provide few shot demos to the language model",
# param_type=adal.ParameterType.DEMOS,
# ),
"few_shot_demos": adal.Parameter(
# data=demo_str,
data=None,
requires_opt=True,
role_desc="To provide few shot demos to the language model",
param_type=adal.ParameterType.DEMOS,
),
"output_format_str": self.llm_parser.get_output_format_str(),
# "output_format_str": adal.Parameter(
# data=self.llm_parser.get_output_format_str(),
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag.py
@@ -206,8 +206,8 @@ def train(
# train_diagnose(**gpt_3_model)

ckpt = train(
debug=True,
max_steps=24,
debug=False,
max_steps=12,
seed=2025, # pass the numpy seed
tg=use_tg,
strategy=set_strategy,
74 changes: 57 additions & 17 deletions benchmarks/hotpot_qa/dspy_multi_hop_rag.py
@@ -45,6 +45,12 @@ class GenerateAnswer(dspy.Signature):
answer = dspy.OutputField(desc="answer to the question")


task_desc_str = """
You will receive an original question, last search query, and the retrieved context from the last search query.
Write the next search query to help retrieve all relevant context to answer the original question.
Think step by step."""


class GenerateSearchQuery(dspy.Signature):
"""Write a simple search query that will help answer a complex question."""

@@ -53,6 +59,17 @@ class GenerateSearchQuery(dspy.Signature):
query = dspy.OutputField()


# class GenerateSearchQuery(dspy.Signature):
# """You will receive an original question, last search query, and the retrieved context from the last search query.
# Write the next search query to help retrieve all relevant context to answer the original question.
# """

# context = dspy.InputField(desc="may contain relevant facts")
# question = dspy.InputField()
# last_search_query = dspy.InputField(desc="The last search query")
# query = dspy.OutputField()


from dsp.utils import deduplicate


@@ -69,9 +86,14 @@ def __init__(self, passages_per_hop=2, max_hops=2):

def forward(self, question):
context = []
# last_query = None

for hop in range(self.max_hops):
query = self.generate_query[hop](context=context, question=question).query
query = self.generate_query[hop](
context=context,
question=question, # last_search_query=last_query
).query
# last_query = query
passages = self.retrieve(query).passages
context = deduplicate(context + passages)

@@ -100,16 +122,17 @@ def train_MIPROv2(trainset, valset, save_path, filename):
metric=validate_answer,
prompt_model=gpt_4,
task_model=turbo,
num_candidates=30,
num_candidates=12,
init_temperature=1.0,
log_dir=save_path,
)
compiled_task = tp.compile(
DsPyMultiHopRAG(),
trainset=trainset,
valset=valset,
max_bootstrapped_demos=5,
max_labeled_demos=2,
num_batches=12, # MINIBATCH_SIZE = 25,
max_bootstrapped_demos=0, # 2,
max_labeled_demos=0, # 2,
num_batches=12, # MINIBATCH_SIZE = 25, (eval on trainset)
seed=2025,
requires_permission_to_run=False,
)
@@ -136,7 +159,7 @@ def train_MIPROv2(trainset, valset, save_path, filename):
# return compiled_baleen


def validate(devset, compiled_baleen, uncompiled_baleen):
def validate(devset, compiled_baleen=None, uncompiled_baleen=None):
from dspy.evaluate.evaluate import Evaluate

# Define metric to check if we retrieved the correct documents
@@ -155,10 +178,13 @@ def validate(devset, compiled_baleen, uncompiled_baleen):
display_table=5,
# metric=validate_answer,
)
uncompiled_baleen_answer_score = evaluate_on_hotpotqa(
uncompiled_baleen, metric=validate_answer, display_progress=True
)
print(f"## Answer Score for uncompiled Baleen: {uncompiled_baleen_answer_score}")
if uncompiled_baleen is not None:
uncompiled_baleen_answer_score = evaluate_on_hotpotqa(
uncompiled_baleen, metric=validate_answer, display_progress=True
)
print(
f"## Answer Score for uncompiled Baleen: {uncompiled_baleen_answer_score}"
)

if compiled_baleen is None:
return
@@ -167,12 +193,13 @@ def validate(devset, compiled_baleen, uncompiled_baleen):
compiled_baleen, metric=validate_answer, display_progress=True
)
print(f"## Answer Score for compiled Baleen: {compiled_baleen_answer_score}")
return compiled_baleen_answer_score


def train_and_validate():
import os

save_path = "benchmarks/hotpot_qa/dspy_multi_hop_rag"
save_path = "benchmarks/hotpot_qa/dspy_multi_hop_rag_zero_shot"
if not os.path.exists(save_path):
os.makedirs(save_path)

@@ -184,6 +211,8 @@ def train_and_validate():
val_accs = []
test_accs = []
training_times = []
max_val_score = 0
max_test_score = 0

num_runs = 4

@@ -194,8 +223,8 @@ def train_and_validate():
compiled_count = train_MIPROv2(
dspy_trainset, dspy_valset, save_path, output_file
)
val_acc = validate(dspy_valset, compiled_count)
test_acc = validate(dspy_testset, compiled_count)
val_acc = validate(dspy_valset, compiled_count) # 46
test_acc = validate(dspy_testset, compiled_count) # 52

val_accs.append(val_acc)
test_accs.append(test_acc)
@@ -208,11 +237,14 @@ def train_and_validate():
val_accs = np.array(val_accs)
test_accs = np.array(test_accs)
training_times = np.array(training_times)

max_val_score = val_accs.max()
max_test_score = test_accs.max()
print("Validation accuracy:", val_accs.mean(), val_accs.std())
print("Test accuracy:", test_accs.mean(), test_accs.std())

print("Training time:", training_times.mean())
print("Max val score: ", max_val_score)
print("Max test score: ", max_test_score)


if __name__ == "__main__":
@@ -246,6 +278,14 @@ def train_and_validate():
# compiled_baleen = train(trainset, dspy_save_path, "hotpotqa.json")
# validate(devset, compiled_baleen, uncompiled_baleen)

# dspy 16 raw shots, 4 demos
# dspy supports multiple generators, in this case 3: two query generators and one answer generator; they all choose the same examples.
# accuracy 62.0
# with demos (2, 2)
# Validation accuracy: 47.25 3.031088913245535
# Test accuracy: 50.625 3.0898017735770686
# Training time: 2465.3250265717506

# zero shot
# Validation accuracy: 35.5 4.330127018922194
# Test accuracy: 37.875 5.140221298738022
# Training time: 182.31551551818848
# Max val score: 42.0
# Max test score: 46.5