|  | 
| 15 | 15 | from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \ | 
| 16 | 16 |     MockPolicyGraph | 
| 17 | 17 | from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator | 
|  | 18 | +from ray.rllib.evaluation.policy_graph import PolicyGraph | 
| 18 | 19 | from ray.rllib.evaluation.metrics import collect_metrics | 
| 19 | 20 | from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync | 
| 20 | 21 | from ray.rllib.env.multi_agent_env import MultiAgentEnv | 
| @@ -306,6 +307,31 @@ def testMultiAgentSampleRoundRobin(self): | 
| 306 | 307 |         self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10], | 
| 307 | 308 |                          [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]) | 
| 308 | 309 | 
 | 
|  | 310 | +    def testCustomRNNStateValues(self): | 
|  | 311 | +        h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}} | 
|  | 312 | + | 
|  | 313 | +        class StatefulPolicyGraph(PolicyGraph): | 
|  | 314 | +            def compute_actions(self, | 
|  | 315 | +                                obs_batch, | 
|  | 316 | +                                state_batches, | 
|  | 317 | +                                is_training=False, | 
|  | 318 | +                                episodes=None): | 
|  | 319 | +                return [0] * len(obs_batch), [[h] * len(obs_batch)], {} | 
|  | 320 | + | 
|  | 321 | +            def get_initial_state(self): | 
|  | 322 | +                return [{}]  # empty dict | 
|  | 323 | + | 
|  | 324 | +        ev = PolicyEvaluator( | 
|  | 325 | +            env_creator=lambda _: gym.make("CartPole-v0"), | 
|  | 326 | +            policy_graph=StatefulPolicyGraph, | 
|  | 327 | +            batch_steps=5) | 
|  | 328 | +        batch = ev.sample() | 
|  | 329 | +        self.assertEqual(batch.count, 5) | 
|  | 330 | +        self.assertEqual(batch["state_in_0"][0], {}) | 
|  | 331 | +        self.assertEqual(batch["state_out_0"][0], h) | 
|  | 332 | +        self.assertEqual(batch["state_in_0"][1], h) | 
|  | 333 | +        self.assertEqual(batch["state_out_0"][1], h) | 
|  | 334 | + | 
| 309 | 335 |     def testReturningModelBasedRolloutsData(self): | 
| 310 | 336 |         class ModelBasedPolicyGraph(PGPolicyGraph): | 
| 311 | 337 |             def compute_actions(self, | 
|  | 
0 commit comments