
Commit 473ee4e

[rllib] Add unit test and some better error messages for custom policy states (#3032)
1 parent 87639b9 commit 473ee4e

3 files changed, +42 -5 lines changed

python/ray/rllib/evaluation/sampler.py (4 additions & 0 deletions)

@@ -364,6 +364,10 @@ def new_episode():
         # Record the policy eval results
         for policy_id, eval_data in to_eval.items():
             actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
+            if len(rnn_in_cols[policy_id]) != len(rnn_out_cols):
+                raise ValueError(
+                    "Length of RNN in did not match RNN out, got: "
+                    "{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols))
             # Add RNN state info
             for f_i, column in enumerate(rnn_in_cols[policy_id]):
                 pi_info_cols["state_in_{}".format(f_i)] = column

python/ray/rllib/evaluation/tf_policy_graph.py (12 additions & 5 deletions)

@@ -95,11 +95,18 @@ def __init__(self,
         self._variables = ray.experimental.TensorFlowVariables(
             self._loss, self._sess)

-        assert len(self._state_inputs) == len(self._state_outputs) == \
-            len(self.get_initial_state()), \
-            (self._state_inputs, self._state_outputs, self.get_initial_state())
-        if self._state_inputs:
-            assert self._seq_lens is not None
+        if len(self._state_inputs) != len(self._state_outputs):
+            raise ValueError(
+                "Number of state input and output tensors must match, got: "
+                "{} vs {}".format(self._state_inputs, self._state_outputs))
+        if len(self.get_initial_state()) != len(self._state_inputs):
+            raise ValueError(
+                "Length of initial state must match number of state inputs, "
+                "got: {} vs {}".format(self.get_initial_state(),
+                                       self._state_inputs))
+        if self._state_inputs and self._seq_lens is None:
+            raise ValueError(
+                "seq_lens tensor must be given if state inputs are defined")

     def build_compute_actions(self,
                               builder,
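To make the invariant behind these checks concrete, here is a minimal sketch of a consistently wired LSTM-style state spec: one state-in tensor per state-out tensor, one initial-state entry per state-in tensor, and a seq_lens tensor whenever state inputs exist. The placeholder names and cell size are illustrative, and passing them into TFPolicyGraph is only outlined in comments rather than shown in full.

import numpy as np
import tensorflow as tf

cell_size = 256

# Two recurrent state tensors (e.g. LSTM cell state and hidden state).
state_in_c = tf.placeholder(tf.float32, [None, cell_size], name="c_in")
state_in_h = tf.placeholder(tf.float32, [None, cell_size], name="h_in")
state_inputs = [state_in_c, state_in_h]

# state_outputs would be the two tensors the RNN produces each step; it
# must have the same length as state_inputs to pass the first check.

# get_initial_state() must return one value per state input (second check).
def get_initial_state():
    return [np.zeros([cell_size], np.float32),
            np.zeros([cell_size], np.float32)]

# Because state inputs are defined, a seq_lens tensor is required (third check).
seq_lens = tf.placeholder(tf.int32, [None], name="seq_lens")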

python/ray/rllib/test/test_multi_agent_env.py (26 additions & 0 deletions)

@@ -15,6 +15,7 @@
 from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \
     MockPolicyGraph
 from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
+from ray.rllib.evaluation.policy_graph import PolicyGraph
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -306,6 +307,31 @@ def testMultiAgentSampleRoundRobin(self):
         self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10],
                          [4, 9, 14, 19, 24, 5, 10, 15, 20, 25])

+    def testCustomRNNStateValues(self):
+        h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}}
+
+        class StatefulPolicyGraph(PolicyGraph):
+            def compute_actions(self,
+                                obs_batch,
+                                state_batches,
+                                is_training=False,
+                                episodes=None):
+                return [0] * len(obs_batch), [[h] * len(obs_batch)], {}
+
+            def get_initial_state(self):
+                return [{}]  # empty dict
+
+        ev = PolicyEvaluator(
+            env_creator=lambda _: gym.make("CartPole-v0"),
+            policy_graph=StatefulPolicyGraph,
+            batch_steps=5)
+        batch = ev.sample()
+        self.assertEqual(batch.count, 5)
+        self.assertEqual(batch["state_in_0"][0], {})
+        self.assertEqual(batch["state_out_0"][0], h)
+        self.assertEqual(batch["state_in_0"][1], h)
+        self.assertEqual(batch["state_out_0"][1], h)
+
     def testReturningModelBasedRolloutsData(self):
         class ModelBasedPolicyGraph(PGPolicyGraph):
             def compute_actions(self,
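The new test exercises two things: arbitrary Python objects (not just arrays) can be used as RNN state values, and each step's state_out column becomes the next step's state_in column. The trace below restates what the assertions check for the first two of the five sampled steps.

# Timeline asserted by testCustomRNNStateValues:
#   t=0: state_in_0 = {}  (from get_initial_state)    ->  state_out_0 = h
#   t=1: state_in_0 = h   (previous step's state_out)  ->  state_out_0 = h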
