Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 44 additions & 34 deletions rllm/workflows/single_turn_workflow.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,47 @@
from rllm.agents.agent import Episode
from rllm.workflows.workflow import TerminationEvent, TerminationReason, Workflow
from rllm.agents.agent import Action, Episode, Step, Trajectory
from rllm.engine import RolloutEngine
from rllm.rewards.reward_fn import RewardFunction
from rllm.workflows.workflow import Workflow


class SingleTurnWorkflow(Workflow):
def __init__(
self,
agent_cls,
env_cls,
agent_args=None,
env_args=None,
**kwargs,
):
super().__init__(**kwargs)

# Initialize mutable defaults
agent_args = dict(agent_args) if agent_args is not None else {}
env_args = dict(env_args) if env_args is not None else {}

self.agent = agent_cls(**agent_args)
self.register_agent(self.agent)
self.env = env_cls(**env_args)

async def run(self, task: dict, uid: str, **kwargs) -> Episode | None:
observation, info = await self.run_in_executor(self.reset, task=task, uid=uid) # returns observation and info from the environment
self.agent.update_from_env(observation, 0, False, info)

response = (await self.get_model_response(self.agent, **kwargs)).text
action = self.agent.update_from_model(response)

next_obs, reward, done, info = await self.run_in_executor(self.env.step, action)
self.agent.update_from_env(next_obs, reward, done, info)

if self._termination_buffer is not None:
raise TerminationEvent(self._termination_buffer)

raise TerminationReason.ENV_DONE
def __init__(self, rollout_engine: RolloutEngine, reward_function: RewardFunction = None, **kwargs):
super().__init__(rollout_engine, **kwargs)

self.reward_function = reward_function

async def run(self, task: dict, uid: str, **kwargs) -> Episode:
"""Execute the single agent workflow."""
# Reset components for new task
self.reset(task, uid)

messages = task["messages"]
response = await self.rollout_engine.get_model_response(messages)
reward_result = self.reward_function(task, response)

trajectory = Trajectory()
trajectory.steps.append(
Step(
model_response=response.text,
action=Action(response.text),
chat_completions=messages + [{"role": "assistant", "content": response.text}],
reward=reward_result.reward,
)
)

is_correct = reward_result.is_correct
# Create episode with trajectories as list of tuples
episode = Episode(
id=uid,
task=task,
is_correct=is_correct,
trajectories=[("single_agent", trajectory)],
metrics={},
)

return episode

def reset(self, task: dict, uid: str):
self.messages = []
self.task = task
self.uid = uid