-
Notifications
You must be signed in to change notification settings - Fork 6k
/
Copy pathtest_action.py
127 lines (95 loc) · 4.01 KB
/
test_action.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gym
import numpy as np
import unittest
from ray.rllib.connectors.action.clip import ClipActionsConnector
from ray.rllib.connectors.action.lambdas import (
ConvertToNumpyConnector,
UnbatchActionsConnector,
)
from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
from ray.rllib.connectors.connector import (
ConnectorContext,
get_connector,
)
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ActionConnectorDataType
# Obtain the torch module through the helper imported from
# ray.rllib.utils.framework; the second returned value is not used in this
# file and is discarded.
# NOTE(review): presumably `torch` is None when PyTorch is not installed --
# confirm against try_import_torch before relying on it.
torch, _ = try_import_torch()
class TestActionConnector(unittest.TestCase):
    """Serialization round-trip and behavior tests for action connectors.

    Every test builds a connector, serializes it with ``to_config()``,
    restores it through ``get_connector()`` and checks the restored type.
    Most tests then call the original connector on a hand-built
    ``ActionConnectorDataType`` and check the first element of its output.
    """

    def test_connector_pipeline(self):
        """A pipeline and its first connector survive a config round trip."""
        ctx = ConnectorContext()
        connectors = [ConvertToNumpyConnector(ctx)]
        pipeline = ActionConnectorPipeline(ctx, connectors)

        name, params = pipeline.to_config()
        restored = get_connector(ctx, name, params)

        self.assertIsInstance(restored, ActionConnectorPipeline)
        self.assertIsInstance(restored.connectors[0], ConvertToNumpyConnector)

    def test_convert_to_numpy_connector(self):
        """Torch tensors in the action and state slots become ndarrays."""
        ctx = ConnectorContext()
        c = ConvertToNumpyConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "ConvertToNumpyConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ConvertToNumpyConnector)

        if torch is None:
            # try_import_torch() yields None for the module when PyTorch is
            # unavailable; skip explicitly rather than crash on None.Tensor.
            # The round-trip checks above have already run at this point.
            self.skipTest("PyTorch is not installed")

        action = torch.Tensor([8, 9])
        states = torch.Tensor([[1, 1, 1], [2, 2, 2]])
        ac_data = ActionConnectorDataType(0, 1, (action, states, {}))

        converted = c(ac_data)
        self.assertIsInstance(converted.output[0], np.ndarray)
        self.assertIsInstance(converted.output[1], np.ndarray)

    def test_unbatch_action_connector(self):
        """A batched nested action struct is split into per-item structs."""
        ctx = ConnectorContext()
        c = UnbatchActionsConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "UnbatchActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, UnbatchActionsConnector)

        # A batch of three actions: "a" holds one value per item and "b"
        # holds a pair of per-item values.
        ac_data = ActionConnectorDataType(
            0,
            1,
            (
                {
                    "a": np.array([1, 2, 3]),
                    "b": (np.array([4, 5, 6]), np.array([7, 8, 9])),
                },
                [],
                {},
            ),
        )

        unbatched = c(ac_data)
        actions, _, _ = unbatched.output

        # Expected (a, b) content of each unbatched action, in batch order.
        expected = [(1, (4, 7)), (2, (5, 8)), (3, (6, 9))]
        self.assertEqual(len(actions), len(expected))
        for action, (expected_a, expected_b) in zip(actions, expected):
            self.assertEqual(action["a"], expected_a)
            np.testing.assert_array_equal(action["b"], np.array(expected_b))

    def test_normalize_action_connector(self):
        """Input 0.5 comes out as 4.5 for a ``Box(0.0, 6.0)`` action space."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
        )
        c = NormalizeActionsConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "NormalizeActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, NormalizeActionsConnector)

        ac_data = ActionConnectorDataType(0, 1, (0.5, [], {}))

        normalized = c(ac_data)
        self.assertEqual(normalized.output[0], 4.5)

    def test_clip_action_connector(self):
        """Input 8.8 is clipped to the upper bound 6.0 of the action space."""
        ctx = ConnectorContext(
            action_space=gym.spaces.Box(low=0.0, high=6.0, shape=[1])
        )
        c = ClipActionsConnector(ctx)

        name, params = c.to_config()
        self.assertEqual(name, "ClipActionsConnector")

        restored = get_connector(ctx, name, params)
        self.assertIsInstance(restored, ClipActionsConnector)

        ac_data = ActionConnectorDataType(0, 1, (8.8, [], {}))

        clipped = c(ac_data)
        self.assertEqual(clipped.output[0], 6.0)
if __name__ == "__main__":
    # When the file is executed directly, hand it to pytest in verbose mode
    # and use pytest's return value as the process exit status.
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))