Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions examples/deepswe/train_deepswe_nb.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
parser.add_argument("--max_turns", type=int, default=20)
parser.add_argument("--per_turn_timeout_secs", type=int, default=300)
parser.add_argument("--max_concurrency", type=int, default=1)
parser.add_argument("--context_ratio", type=int, default=2)


# Other
Expand Down Expand Up @@ -294,10 +293,7 @@
PER_TURN_TIMEOUT_SECS = args.per_turn_timeout_secs

MAX_CONCURRENCY = args.max_concurrency
CONTEXT_RATIO = args.context_ratio # Context length can be up to 2x responselength in DeepSWE due to multi-turn interactions and long responses, so we set context ratio to 2 to accommodate this.
KV_CACHE_SIZE = MAX_PROMPT_LENGTH + (
MAX_RESPONSE_LENGTH * CONTEXT_RATIO * MAX_TURNS
)
KV_CACHE_SIZE = MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH + 128
print(f"kv_cache_size (Capped): {KV_CACHE_SIZE}")
# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = args.learning_rate
Expand All @@ -316,8 +312,6 @@
ROLLOUT_ENGINE = args.rollout_engine
CKPT_DIR = args.ckpt_dir

# Max number of sequences to be processed in parallel by vllm.
VLLM_MAX_NUM_SEQS = ROLLOUT_MICRO_BATCH_SIZE * NUM_GENERATIONS

VLLM_UTILIZATION = args.vllm_utilization

Expand All @@ -327,7 +321,7 @@
# Max number of tokens to be processed in parallel by vllm.
# Divide by 8 for on policy, 1 step off divide by 4

VLLM_MAX_BATCHED_TOKENS = (VLLM_MAX_NUM_SEQS * KV_CACHE_SIZE) // 4
VLLM_MAX_BATCHED_TOKENS = (VLLM_MAX_NUM_SEQS * KV_CACHE_SIZE) // 8
print(f"vllm_max_batched_tokens: {VLLM_MAX_BATCHED_TOKENS}")
# %%
# ==========================================
Expand Down
7 changes: 6 additions & 1 deletion tunix/rl/agentic/agentic_rl_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import itertools
import queue
import threading
from typing import Any, AsyncIterator, Callable, Dict, Generic, Iterable, Iterator, List, Sequence, Type, TypeVar
from typing import Any, AsyncIterator, Callable, Dict, Generic, Iterable, Iterator, List, Sequence, Type, TypeVar, Optional, Set

from absl import logging
import flax
Expand Down Expand Up @@ -78,6 +78,8 @@ class AgenticRLConfig(algo_config_lib.AlgorithmConfig):
num_generations: Number of samples per prompt.
num_iterations: Number of iterations per batch.
episode_timeout: Timeout for each episode in seconds.
filter_statuses: Set of trajectory statuses to filter out.
overlong_filter: Whether to filter out overlong trajectories.
"""

system_prompt: str = ""
Expand All @@ -90,6 +92,8 @@ class AgenticRLConfig(algo_config_lib.AlgorithmConfig):
num_generations: int = 1
num_iterations: int = 1
episode_timeout: float = 1800.0
filter_statuses: Optional[Set] = None
overlong_filter: bool = False


TConfig = TypeVar("TConfig", bound=AgenticRLConfig)
Expand Down Expand Up @@ -425,6 +429,7 @@ def _build_orchestrator(self) -> rollout_orchestrator.RolloutOrchestrator:
tokenizer=self.tokenizer,
chat_parser=self.chat_parser,
timeout=self.algo_config.episode_timeout,
max_response_length=self.algo_config.max_response_length,
perf_v2=self.rl_cluster.perf_v2,
)
return rollout_orchestrator.RolloutOrchestrator(
Expand Down
69 changes: 34 additions & 35 deletions tunix/rl/agentic/trajectory/trajectory_collect_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
],
model_call_kwargs: Optional[Dict[str, Any]] = None,
gamma: float = 1.0,
max_context_limit: Optional[int] = None,
max_response_length: Optional[int] = None,
timeout: float = 600.0,
tokenizer=None,
chat_parser=None,
Expand All @@ -87,7 +87,7 @@ def __init__(
float. Defaults to zero if not provided.
gamma (float): Discount factor for MC reward calculation (1.0 = no
discounting).
max_context_limit (Optional[int]): Maximum number of context tokens to
max_response_length (Optional[int]): Maximum number of context tokens to
use before forced termination.
timeout (float): Maximum episode duration in seconds before timeout
termination
Expand All @@ -107,7 +107,7 @@ def __init__(
self.model_call_kwargs = model_call_kwargs or {}
self.max_steps = getattr(self.env, "max_steps", 1)
self.gamma = gamma
self.max_context_limit = max_context_limit
self.max_response_length = max_response_length
self.timeout = timeout

# Tokenizer utilities for stepwise tokenization
Expand Down Expand Up @@ -140,11 +140,13 @@ def __init__(
), # Thread/CPU time (Actual processing time on the worker thread)
}

if self.max_context_limit and not (self.tokenizer and self.chat_parser):
if self.max_response_length is not None and not (
self.tokenizer and self.chat_parser
):
logging.warning(
"max_context_limit is set to %d, but no tokenizer or chat_parser is"
" provided. Context limits will not be enforced.",
self.max_context_limit,
"max_response_length is set to %d, but no tokenizer or chat_parser is"
" provided. response length limits will not be enforced.",
self.max_response_length,
)

async def _run_with_timing(
Expand Down Expand Up @@ -186,14 +188,6 @@ async def collect(self, mode: str = "Conversation") -> Any:
""" # fmt: skip
await self._reset()

# Initial Prompt Cost
current_token_count = 0
if (
hasattr(self.agent.trajectory, "prompt_tokens")
and self.agent.trajectory.prompt_tokens
):
current_token_count += len(self.agent.trajectory.prompt_tokens)

self.agent.trajectory.status = agent_types.TrajectoryStatus.RUNNING

while True:
Expand All @@ -204,22 +198,6 @@ async def collect(self, mode: str = "Conversation") -> Any:
break

done = await self._one_step()
current_step = self.agent.get_current_step()

if current_step:
if getattr(current_step, "assistant_tokens", None) is not None:
current_token_count += len(current_step.assistant_tokens)
if getattr(current_step, "env_tokens", None) is not None:
current_token_count += len(current_step.env_tokens)

if (
self.max_context_limit is not None
and current_token_count >= self.max_context_limit
):
self.agent.trajectory.status = (
agent_types.TrajectoryStatus.MAX_CONTEXT_LIMIT_REACHED
)
break

if done:
if self.agent.trajectory.status == agent_types.TrajectoryStatus.RUNNING:
Expand Down Expand Up @@ -321,8 +299,8 @@ async def collect_multiple(
*,
model_call: Callable[..., base_rollout.RolloutOutput],
gamma: float = 1.0,
max_context_limit: Optional[int] = None,
timeout: float = 30.0,
max_response_length: Optional[int] = None,
mode: str = "Trajectory",
filter_statuses: Optional[Set[agent_types.TrajectoryStatus]] = None,
overlong_filter: bool = True,
Expand All @@ -339,7 +317,7 @@ async def collect_multiple(
environment) pairs
model_call (Callable): Shared model inference function for all pairs
gamma (float): Discount factor for return calculation
max_context_limit (Optional[int]): Maximum context limit per episode
max_response_length (Optional[int]): Maximum context limit per episode
timeout (float): Per-episode timeout in seconds
mode (str): Output format. See `collect` method for options.
filter_statuses (Optional[Set[TrajectoryStatus]]): A set of statuses
Expand All @@ -348,7 +326,6 @@ async def collect_multiple(
perf_v2 (Optional[perf_tracer_v2.Tracer]): Optional performance tracer
to use for performance measurements.


Yields:
Tuple[int, Any]: `(pair_index, result)`. The type of `result`
depends on the `mode` argument. See the `collect` method for details.
Expand All @@ -361,7 +338,7 @@ async def _run_one(i: int, agent: ConversationAgentBase, env: BaseTaskEnv):
env,
model_call=model_call,
gamma=gamma,
max_context_limit=max_context_limit,
max_response_length=max_response_length,
timeout=timeout,
filter_statuses=filter_statuses,
overlong_filter=overlong_filter,
Expand Down Expand Up @@ -406,6 +383,7 @@ async def _reset(self):
self.agent.trajectory.prompt_tokens = prompt_tokens

self._start_ts = time.perf_counter()
self._response_token_count = 0

def _get_perf_tags(self) -> Dict[str, Any]:
"""Extracts performance tracing tags from the environment."""
Expand All @@ -423,6 +401,19 @@ def _get_perf_tags(self) -> Dict[str, Any]:
tags[perf_constants.STEP] = policy_version
return tags

def _check_and_set_context_limit_reached(self) -> bool:
  """Returns True and updates trajectory status if response budget is exhausted.

  Compares the running per-episode response-token counter
  (``self._response_token_count``, reset to 0 in ``_reset`` and incremented
  after each model call and env-observation tokenization) against the
  configured ``self.max_response_length`` budget. When the budget is
  reached, the owning trajectory's status is set to
  ``MAX_CONTEXT_LIMIT_REACHED`` as a side effect so the collect loop can
  terminate the episode.

  Returns:
    bool: True if a limit is configured (``max_response_length`` is not
      None) and the accumulated response token count has reached it;
      False otherwise (including when no limit is configured).
  """
  if (
      self.max_response_length is not None
      and self._response_token_count >= self.max_response_length
  ):
    # Side effect: mark the trajectory so downstream filtering
    # (filter_statuses / overlong_filter) can drop or handle it.
    self.agent.trajectory.status = (
        agent_types.TrajectoryStatus.MAX_CONTEXT_LIMIT_REACHED
    )

    return True
  return False

async def _one_step(self) -> bool:
"""Executes a single step and returns the Step object and Done status.

Expand All @@ -434,13 +425,20 @@ async def _one_step(self) -> bool:
bool: True if the episode is done (either by environment or timeout),
False otherwise.
"""
if self._check_and_set_context_limit_reached():
return True

rollout_output = await asyncio.get_event_loop().run_in_executor(
None,
self.model_call,
self.agent.chat_completions,
self.env,
**self.model_call_kwargs,
)
if rollout_output.tokens:
self._response_token_count += len(rollout_output.tokens[0])
if self._check_and_set_context_limit_reached():
return True

action = self.agent.update_from_model(rollout_output.text[0]).action

Expand Down Expand Up @@ -489,6 +487,7 @@ async def _one_step(self) -> bool:
)
cur_step.env_tokens = np.array(e_tokens)
cur_step.env_masks = np.array(e_masks)
self._response_token_count += len(e_tokens)

if time.perf_counter() - self._start_ts > self.timeout:
self.agent.trajectory.status = agent_types.TrajectoryStatus.TIMEOUT
Expand Down
Loading