4 changes: 4 additions & 0 deletions python/ray/rllib/evaluation/sampler.py
@@ -364,6 +364,10 @@ def new_episode():
# Record the policy eval results
for policy_id, eval_data in to_eval.items():
actions, rnn_out_cols, pi_info_cols = eval_results[policy_id]
if len(rnn_in_cols[policy_id]) != len(rnn_out_cols):
raise ValueError(
"Length of RNN in did not match RNN out, got: "
"{} vs {}".format(rnn_in_cols[policy_id], rnn_out_cols))
# Add RNN state info
for f_i, column in enumerate(rnn_in_cols[policy_id]):
pi_info_cols["state_in_{}".format(f_i)] = column
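As an illustration of when the new check in sampler.py fires, here is a minimal sketch (the class name and setup are assumptions, modeled on the test added further below): a policy whose compute_actions returns a different number of state-out columns than it has state inputs would now fail fast with the ValueError above during sampling.

from ray.rllib.evaluation.policy_graph import PolicyGraph


class MismatchedRNNPolicyGraph(PolicyGraph):
    """Hypothetical policy with one state input but two state outputs."""

    def get_initial_state(self):
        # One state-in column.
        return [0]

    def compute_actions(self, obs_batch, state_batches,
                        is_training=False, episodes=None):
        # Returns two state-out columns for a single state-in column, so
        # sampling such a policy now raises
        # "Length of RNN in did not match RNN out".
        return [0] * len(obs_batch), [state_batches[0], state_batches[0]], {}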
17 changes: 12 additions & 5 deletions python/ray/rllib/evaluation/tf_policy_graph.py
@@ -95,11 +95,18 @@ def __init__(self,
self._variables = ray.experimental.TensorFlowVariables(
self._loss, self._sess)

assert len(self._state_inputs) == len(self._state_outputs) == \
len(self.get_initial_state()), \
(self._state_inputs, self._state_outputs, self.get_initial_state())
if self._state_inputs:
assert self._seq_lens is not None
if len(self._state_inputs) != len(self._state_outputs):
raise ValueError(
"Number of state input and output tensors must match, got: "
"{} vs {}".format(self._state_inputs, self._state_outputs))
if len(self.get_initial_state()) != len(self._state_inputs):
raise ValueError(
"Length of initial state must match number of state inputs, "
"got: {} vs {}".format(self.get_initial_state(),
self._state_inputs))
if self._state_inputs and self._seq_lens is None:
raise ValueError(
"seq_lens tensor must be given if state inputs are defined")

def build_compute_actions(self,
builder,
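For context, a minimal sketch of a state layout that satisfies these constructor checks; the LSTM-style naming and cell size are assumptions, not part of this PR. A recurrent policy graph must expose the same number of state-in tensors, state-out tensors, and initial-state entries, and must define a seq_lens tensor when state inputs are present.

import numpy as np

CELL_SIZE = 256  # assumed cell size, not from this PR


def get_initial_state():
    # One entry per recurrent state tensor, e.g. LSTM cell state and
    # hidden state. With two state-in and two state-out tensors in the
    # graph, this length-2 list passes the new length checks above.
    return [
        np.zeros(CELL_SIZE, dtype=np.float32),
        np.zeros(CELL_SIZE, dtype=np.float32),
    ]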
26 changes: 26 additions & 0 deletions python/ray/rllib/test/test_multi_agent_env.py
@@ -15,6 +15,7 @@
from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \
MockPolicyGraph
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -306,6 +307,31 @@ def testMultiAgentSampleRoundRobin(self):
self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10],
[4, 9, 14, 19, 24, 5, 10, 15, 20, 25])

def testCustomRNNStateValues(self):
h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}}

class StatefulPolicyGraph(PolicyGraph):
def compute_actions(self,
obs_batch,
state_batches,
is_training=False,
episodes=None):
return [0] * len(obs_batch), [[h] * len(obs_batch)], {}

def get_initial_state(self):
return [{}] # empty dict

ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=StatefulPolicyGraph,
batch_steps=5)
batch = ev.sample()
self.assertEqual(batch.count, 5)
self.assertEqual(batch["state_in_0"][0], {})
self.assertEqual(batch["state_out_0"][0], h)
self.assertEqual(batch["state_in_0"][1], h)
self.assertEqual(batch["state_out_0"][1], h)

def testReturningModelBasedRolloutsData(self):
class ModelBasedPolicyGraph(PGPolicyGraph):
def compute_actions(self,