# test_connector.py
import unittest
from ray.rllib.connectors.connector import Connector, ConnectorPipeline
class TestConnectorPipeline(unittest.TestCase):
    """Checks the list-manipulation helpers of ``ConnectorPipeline``.

    The nested ``Tom``, ``Bob`` and ``Mary`` connectors are minimal stand-ins;
    the test looks connectors up by their class name.
    """

    class Tom(Connector):
        """Stand-in connector addressed by the name "Tom"."""

        def to_config(self):
            """Return the placeholder config value for this connector."""
            return "tom"

    class Bob(Connector):
        """Stand-in connector addressed by the name "Bob"."""

        def to_config(self):
            """Return the placeholder config value for this connector."""
            return "bob"

    class Mary(Connector):
        """Stand-in connector addressed by the name "Mary"."""

        def to_config(self):
            """Return the placeholder config value for this connector."""
            return "mary"

    class MockConnectorPipeline(ConnectorPipeline):
        """Pipeline stub that only stores the connector list it is given."""

        def __init__(self, ctx, connectors):
            """Keep ``connectors`` as the pipeline's list; ``ctx`` is unused.

            The base-class initializer is not called here.
            NOTE(review): confirm ``ConnectorPipeline.__init__`` performs no
            additional setup that the helpers under test rely on.
            """
            # Real connector pipelines should keep a list of Connectors.
            # This mock stores the given list as-is to keep the test simple.
            self.connectors = connectors

    def test_sanity_check(self):
        """Exercise insert_before/insert_after/prepend/append/remove.

        Each scenario starts from a fresh [Tom, Bob] pipeline and asserts the
        resulting length and the position of the affected connector.
        """
        ctx = {}

        # Insert Mary before Bob -> Mary ends up at index 1.
        m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
        m.insert_before("Bob", self.Mary(ctx))
        self.assertEqual(len(m.connectors), 3)
        self.assertEqual(m.connectors[1].__class__.__name__, "Mary")

        # Insert Mary after Tom -> Mary ends up at index 1.
        m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
        m.insert_after("Tom", self.Mary(ctx))
        self.assertEqual(len(m.connectors), 3)
        self.assertEqual(m.connectors[1].__class__.__name__, "Mary")

        # Prepend Mary -> Mary becomes the first connector.
        m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
        m.prepend(self.Mary(ctx))
        self.assertEqual(len(m.connectors), 3)
        self.assertEqual(m.connectors[0].__class__.__name__, "Mary")

        # Append Mary -> Mary becomes the last connector.
        m = self.MockConnectorPipeline(ctx, [self.Tom(ctx), self.Bob(ctx)])
        m.append(self.Mary(ctx))
        self.assertEqual(len(m.connectors), 3)
        self.assertEqual(m.connectors[2].__class__.__name__, "Mary")

        # Remove Bob from the pipeline built just above -> [Tom, Mary] remain.
        m.remove("Bob")
        self.assertEqual(len(m.connectors), 2)
        self.assertEqual(m.connectors[0].__class__.__name__, "Tom")
        self.assertEqual(m.connectors[1].__class__.__name__, "Mary")
if __name__ == "__main__":
    # Run this file's tests verbosely through pytest when executed as a
    # script, and hand pytest's result to the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)