Skip to content

Commit

Permalink
Optional AnswerSettings.max_answer_attempts to allow a new unsure branch (#673)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbraza authored Nov 8, 2024
1 parent 79ee951 commit 405f885
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 18 deletions.
17 changes: 15 additions & 2 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,22 @@ async def reset(self) -> tuple[list[Message], list[Tool]]:
def export_frame(self) -> Frame:
    """Snapshot this environment's state and originating query as a Frame."""
    frame_info = {"query": self._query}
    return Frame(state=self.state, info=frame_info)

def _has_excess_answer_failures(self) -> bool:
    """Check if gen_answer has been invoked more times than the configured cap.

    Returns False when no cap is configured
    (``settings.answer.max_answer_attempts`` is None); otherwise True once the
    total count of gen_answer tool calls across all recorded steps strictly
    exceeds the (exclusive) cap.
    """
    attempt_cap = self._query.settings.answer.max_answer_attempts
    if attempt_cap is None:
        return False  # no limit configured, so never an excess
    gen_answer_name = GenerateAnswer.gen_answer.__name__
    total_attempts = sum(
        step_tool_names.count(gen_answer_name)
        for step_tool_names in self.state.tool_history
    )
    return total_attempts > attempt_cap

async def step(
self, action: ToolRequestMessage
) -> tuple[list[Message], float, bool, bool]:
self.state.session.add_tokens(action) # Add usage for action if present
self.state.record_action(action)

# If the action has empty tool_calls, the agent can later take that into account
msgs = cast(
Expand All @@ -175,7 +187,8 @@ async def step(
and msg.name == GenerateAnswer.gen_answer.__name__
and GenerateAnswer.did_not_fail_to_answer(msg.content)
for msg in msgs
),
)
or self._has_excess_answer_failures(),
False,
)

Expand Down
36 changes: 21 additions & 15 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,30 @@ async def step(
messages, reward, done, truncated = await super().step(action)
if not done or not self._evaluation_from_answer:
return messages, reward, done, truncated
# Filter out non-answer messages (in case parallel tool calls)
answer_tool_messages = [
m
for m in messages
if isinstance(m, ToolResponseMessage)
and m.name == GenerateAnswer.gen_answer.__name__
]
if not answer_tool_messages: # No answer, so no positive reward
valid_answers, failed_answer_messages = [], []
for m in messages:
if (
not isinstance(m, ToolResponseMessage)
or m.name != GenerateAnswer.gen_answer.__name__
):
continue # Filter out non-answer messages (in case parallel tool calls)
if answer := GenerateAnswer.extract_answer_from_message(content=m.content):
valid_answers.append(answer)
else:
failed_answer_messages.append(m)
if not valid_answers: # No answer, so no positive reward
return messages, reward, done, truncated
if len(answer_tool_messages) != 1:
if len(valid_answers) != 1:
raise NotImplementedError(
f"Expected just one answer message, got {messages}."
f"Expected just one answer message, got more than one in {messages}."
)
answer = GenerateAnswer.extract_answer_from_message(
content=answer_tool_messages[0].content
)
if not answer:
return messages, reward, done, truncated
answer = valid_answers[0]
if failed_answer_messages:
logger.warning(
"More than one answer detected, discarding failed answer messages"
f" {failed_answer_messages}, continuing with answer {answer}."
)
# Okay, so we have one answer that was not a failed answer. Let's evaluate it
evaluation = await self._evaluation_from_answer(answer)
if evaluation_callback := self._evaluation_callback:
await evaluation_callback(evaluation)
Expand Down
13 changes: 13 additions & 0 deletions paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
from typing import ClassVar, cast

from aviary.core import ToolRequestMessage
from pydantic import BaseModel, ConfigDict, Field, computed_field

from paperqa.docs import Docs
Expand Down Expand Up @@ -36,6 +37,14 @@ class EnvironmentState(BaseModel):

docs: Docs
session: PQASession = Field(..., alias="answer")
tool_history: list[list[str]] = Field(
default_factory=list,
description=(
"History of tool names input to each Environment.step (regardless of being"
" a typo or not), where the outer list is steps, and the inner list matches"
" the order of tool calls at each step."
),
)

# SEE: https://regex101.com/r/RmuVdC/1
STATUS_SEARCH_REGEX_PATTERN: ClassVar[str] = (
Expand Down Expand Up @@ -65,6 +74,10 @@ def status(self) -> str:
cost=self.session.cost,
)

def record_action(self, action: ToolRequestMessage) -> None:
    """Record an incoming action: accumulate its token usage and append its
    tool-call names as a new step in the tool history."""
    self.session.add_tokens(action)
    called_tool_names: list[str] = []
    for tool_call in action.tool_calls:
        called_tool_names.append(tool_call.function.name)
    self.tool_history.append(called_tool_names)


class NamedTool(BaseModel):
"""Base class to make looking up tools easier."""
Expand Down
7 changes: 7 additions & 0 deletions paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ class AnswerSettings(BaseModel):
answer_max_sources: int = Field(
default=5, description="Max number of sources to use for an answer"
)
max_answer_attempts: int | None = Field(
default=None,
description=(
"Optional (exclusive) max number (default is no max) of attempts to"
" generate an answer before declaring a failure."
),
)
answer_length: str = Field(
"about 200 words, but can be longer", description="Length of final answer"
)
Expand Down
5 changes: 4 additions & 1 deletion tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,21 @@ async def test_evaluation(
base_query_request.settings.agent.tool_names = {
GenerateAnswer.gen_answer.__name__
}
base_query_request.settings.answer.max_answer_attempts = 2
base_query_request.settings.answer.get_evidence_if_no_contexts = False
dataset = LitQAv2TaskDataset(base_query=base_query_request)
dataset.data = dataset.data[:2] # Save the world: just use two questions
storage_callback = StoreTrajectoriesCallback()
evaluator = Evaluator(
config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=2),
config=EvaluatorConfig(batch_size=len(dataset), max_rollout_steps=4),
agent=SimpleAgent(),
dataset=dataset,
callbacks=[storage_callback],
)
await evaluator.evaluate()
for traj in storage_callback.eval_trajectories:
assert not traj.failed
assert traj.done
for step in traj.steps:
assert all(
tc.function.name == GenerateAnswer.gen_answer.__name__
Expand Down

0 comments on commit 405f885

Please sign in to comment.