Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 123 additions & 6 deletions tests/rl/agentic/agentic_grpo_learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,27 +562,125 @@ def __call__(self, inputs, positions, cache, attention_mask):

np.testing.assert_allclose(loss, expected_loss, rtol=1e-6, atol=1e-6)

def test_process_results_extracts_assistant_text(self):
  """Verifies _process_results pulls the assistant message out of each trajectory.

  Builds two mock trajectories whose conversations contain system, user and
  assistant messages, runs them through the learner with the reward function
  patched out, and checks that exactly the assistant "content" strings were
  handed to the reward computation.
  """

  class MockTraj:
    # Minimal stand-in for a collected trajectory: only exposes the .traj
    # dict that _process_results reads.

    def __init__(self, index):
      self.traj = {
          # Three-role conversation; only the assistant content should be
          # extracted as the completion text.
          "conversation_text": [
              {"role": "system", "content": "system prompt"},
              {"role": "user", "content": "user query"},
              {"role": "assistant", "content": f"msg {index}"},
          ],
          "conversation_tokens": np.array([1, 2, 3]),
          "conversation_masks": np.array([1, 1, 1]),
          "old_logprobs": None,
          "policy_version": 0,
          "trajectory_reward": 1.0,
          "prompt_tokens": np.array([4, 5]),
          "original_input": {"prompts": "hello"},
          "group_id": "group1",
      }

  trajectories = [MockTraj(0), MockTraj(1)]

  # Captures whatever completion strings the learner passes to the reward fn.
  extracted_completions = []

  def mock_compute_rewards(prompts, completions, **kwargs):
    extracted_completions.extend(completions)
    # One reward per completion, value irrelevant to this test.
    return jnp.ones(len(completions), dtype=jnp.float32)

  # Build a minimal real learner (toy models, single-host mesh) so that
  # _process_results runs its genuine code path.
  vocab = _mock_vocab()
  tokenizer = tokenizer_adapter.TokenizerAdapter(vocab)
  model = test_common.ToyTransformer(
      config=test_common.ModelConfig(vocab_size=vocab.GetPieceSize()),
      rngs=nnx.Rngs(0),
  )
  ref_model = test_common.ToyTransformer(
      config=test_common.ModelConfig(vocab_size=vocab.GetPieceSize()),
      rngs=nnx.Rngs(0),
  )
  mesh = pxla.thread_resources.env.physical_mesh
  cluster_config = rl_cluster_lib.ClusterConfig(
      role_to_mesh={
          rl_cluster_lib.Role.ACTOR: mesh,
          rl_cluster_lib.Role.REFERENCE: mesh,
          rl_cluster_lib.Role.ROLLOUT: mesh,
      },
      rollout_engine="vanilla",
      training_config=rl_cluster_lib.RLTrainingConfig(
          actor_optimizer=optax.sgd(1e-3),
          eval_every_n_steps=100,
      ),
      rollout_config=base_rollout.RolloutConfig(
          max_prompt_length=32,
          max_tokens_to_generate=10,
          return_logprobs=True,
      ),
  )
  rl_cluster = rl_cluster_lib.RLCluster(
      actor=model,
      reference=ref_model,
      tokenizer=tokenizer,
      cluster_config=cluster_config,
  )
  grpo_config = agentic_grpo_learner.GRPOConfig(
      beta=0.1,
      epsilon=0.2,
      num_generations=2,
      loss_algo="grpo",
      max_response_length=10,
  )

  learner = agentic_grpo_learner.GRPOLearner(
      rl_cluster=rl_cluster,
      reward_fns=None,
      algo_config=grpo_config,
      chat_parser=MockChatParser(),
  )

  # Patch out reward computation (to capture completions) and the reference
  # log-prob call (to avoid running the reference model).
  with mock.patch.object(learner, "_compute_rewards", side_effect=mock_compute_rewards):
    with mock.patch.object(
        learner.rl_cluster,
        "get_ref_per_token_logps",
        return_value=jnp.zeros((2, 10)),
        autospec=True,
    ):
      learner._process_results(trajectories)

  # Only the assistant contents — not system/user text — must be extracted.
  self.assertEqual(extracted_completions, ["msg 0", "msg 1"])

@parameterized.named_parameters(
dict(testcase_name="masking_disabled", masking=False),
dict(testcase_name="masking_enabled", masking=True),
)
def test_process_results_masks_zero_advantage_group(self, masking):
class MockTraj:

def __init__(self, index, group_id, reward):
def __init__(
self,
index,
group_id,
reward,
has_assistant_message=True,
old_logprobs=None,
):
self.traj = {
"conversation_text": [
{"role": "assistant", "content": f"msg {index}"}
],
"conversation_text": [],
"conversation_tokens": np.array([1, 2, 3]),
"conversation_masks": np.array([1, 1, 1]),
"old_logprobs": None,
"old_logprobs": old_logprobs,
"policy_version": 0,
"trajectory_reward": reward,
"prompt_tokens": np.array([4, 5]),
"original_input": {"prompts": "hello"},
"group_id": group_id,
}
self.traj["conversation_text"].append(
{"role": "user", "content": "user message"}
)
if has_assistant_message:
self.traj["conversation_text"].append(
{"role": "assistant", "content": f"msg {index}"}
)

# Group 1: non-degenerate (different rewards)
group1 = [MockTraj(0, "group1", -1.0), MockTraj(1, "group1", 1.0)]
Expand Down Expand Up @@ -645,7 +743,7 @@ def __init__(self, index, group_id, reward):
"get_ref_per_token_logps",
return_value=jnp.zeros((2, 10)),
autospec=True,
):
) as mock_get_ref:
[res_group1] = learner._process_results(group1)
[res_group2] = learner._process_results(group2)

Expand All @@ -660,6 +758,25 @@ def __init__(self, index, group_id, reward):
# Masking disabled: degenerate group should remain intact
self.assertTrue(jnp.any(res_group2.completion_mask > 0))

# Test group with missing assistant message
group3 = [
MockTraj(4, "group3", 0.0, has_assistant_message=False),
MockTraj(5, "group3", 0.0),
]
[res_group3] = learner._process_results(group3)
if masking:
self.assertFalse(jnp.any(res_group3.completion_mask > 0))
else:
self.assertTrue(jnp.any(res_group3.completion_mask > 0))

# Test group with partially missing old_logprobs
group4 = [
MockTraj(6, "group4", 0.0, old_logprobs=np.array([-0.1, -0.2, -0.3])),
MockTraj(7, "group4", 0.0, old_logprobs=None),
]
[res_group4] = learner._process_results(group4)
self.assertIsNone(res_group4.old_per_token_logps)

def test_checkpointing(self):
ckpt_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, ckpt_dir)
Expand Down
24 changes: 24 additions & 0 deletions tests/rl/agentic/trajectory/trajectory_collect_engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,30 @@ def test_collect_with_tokenization(self, mock_convert):
mock_convert.call_args_list[2].kwargs['contains_generation_msg']
)

@mock.patch.object(utils, 'tokenize_and_generate_masks')
def test_collect_token_mode_empty_steps(self, mock_convert):
  """Token-mode collection with zero environment steps yields empty arrays.

  With max_steps=0 the agent records no steps, so the collected token data
  must contain empty conversation tokens/masks and no old logprobs rather
  than raising on an empty concatenation.
  """
  mock_convert.side_effect = [
      ([101], [1]),  # prompt tokens
  ]
  self.mock_env.max_steps = 0  # No steps will be taken
  engine = trajectory_collect_engine.TrajectoryCollectEngine(
      agent=self.mock_agent,
      env=self.mock_env,
      model_call=self.mock_model_call,
      tokenizer=self.mock_tokenizer,
      chat_parser=self.mock_chat_parser,
      max_context_limit=1024,
  )
  token_data = asyncio.run(self._run_collect(engine, mode='Token'))
  # No environment interaction happened, so the trajectory has no steps.
  self.assertEmpty(self.mock_agent.trajectory.steps)
  # Empty-step collection must produce empty int32 arrays, not an error.
  np.testing.assert_array_equal(
      token_data['conversation_tokens'], np.array([], dtype=np.int32)
  )
  np.testing.assert_array_equal(
      token_data['conversation_masks'], np.array([], dtype=np.int32)
  )
  # With no generation steps there are no per-token logprobs to report.
  self.assertIsNone(token_data['old_logprobs'])

@mock.patch.object(utils, 'tokenize_and_generate_masks')
def test_collect_with_incomplete_tokenizer_config_skips_tokenization(
self, mock_tokenize
Expand Down
13 changes: 9 additions & 4 deletions tunix/rl/agentic/agentic_grpo_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,12 @@ def _process_results(
trajectories_to_log.append(item.traj)
conversation = item.traj.get("conversation_text") or []
assistant_text = next(
message["content"]
for message in conversation
if message["role"] == "assistant"
(
message["content"]
for message in conversation
if message["role"] == "assistant"
),
"",
)

completion_texts.append(assistant_text)
Expand Down Expand Up @@ -383,7 +386,9 @@ def _process_results(
completion_ids.shape,
)

if padded_old_logprobs:
if padded_old_logprobs and len(padded_old_logprobs) == len(
completion_tokens_list
):
old_per_token_logps = jnp.asarray(padded_old_logprobs)
else:
old_per_token_logps = None
Expand Down
33 changes: 28 additions & 5 deletions tunix/rl/agentic/trajectory/trajectory_collect_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,30 @@ async def collect(self, mode: str = "Conversation") -> Any:
logprobs.append(step.logprobs)
if getattr(step, "env_tokens", None) is not None:
logprobs.append(np.zeros(len(step.env_tokens)))
conversation_masks = np.concatenate(conversation_masks, axis=0)

conversation_tokens = [
np.asarray(tokens)
for tokens in conversation_tokens
if len(tokens) > 0
]
conversation_masks = [
np.asarray(masks) for masks in conversation_masks if len(masks) > 0
]
logprobs = [
np.asarray(step_logprobs)
for step_logprobs in logprobs
if len(step_logprobs) > 0
]
conversation_masks = (
np.concatenate(conversation_masks, axis=0)
if conversation_masks
else np.array([], dtype=np.int32)
)
conversation_tokens = (
np.concatenate(conversation_tokens, axis=0)
if conversation_tokens
else np.array([], dtype=np.int32)
)
final_masks = (
np.zeros_like(conversation_masks)
if masked_out
Expand All @@ -298,7 +321,7 @@ async def collect(self, mode: str = "Conversation") -> Any:
return {
"conversation_text": self.agent.chat_completions,
"prompt_tokens": prompt_tokens,
"conversation_tokens": np.concatenate(conversation_tokens, axis=0),
"conversation_tokens": conversation_tokens,
"conversation_masks": final_masks,
"status": self.agent.trajectory.status.name,
"trajectory_reward": self.agent.trajectory.reward,
Expand Down Expand Up @@ -386,9 +409,9 @@ async def _reset(self):
self.env_time["reset_latency"] += wall_time
self.env_time["reset_cpu_time"] += cpu_time
self.final_reward_fn = (
self.env.final_reward_fn
if hasattr(self.env, "final_reward_fn")
else None
self.env.final_reward_fn
if hasattr(self.env, "final_reward_fn")
else None
)
self.agent.reset()
self.agent.update_from_env(observation=obs, reward=0.0, done=False, info={})
Expand Down
Loading