# File: __init__.py, forked from ray-project/ray (84 lines, 67 loc, 2.61 KB).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import sys
# Note: do not introduce unnecessary library dependencies here, e.g. gym.
# This file is imported from the tune module in order to register RLlib agents.
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.tune.registry import register_trainable
def _setup_logger():
logger = logging.getLogger("ray.rllib")
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter(
"%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
))
logger.addHandler(handler)
logger.propagate = False
if sys.version_info[0] < 3:
logger.warn(
"RLlib Python 2 support is deprecated, and will be removed "
"in a future release.")
def _register_all():
    """Register all RLlib algorithms with Tune's trainable registry.

    The first pass registers, in this order and via ``get_agent_class``:
      * every key of ``ray.rllib.agents.registry.ALGORITHMS``,
      * every key of ``ray.rllib.contrib.registry.CONTRIBUTED_ALGORITHMS``,
      * the extra names "__fake", "__sigmoid_fake_data" and
        "__parameter_tuning".

    The second pass handles each ``contrib/<name>`` key. It registers the
    bare ``<name>`` as a placeholder trainable whose ``_setup`` raises
    ``NameError`` and tells the user to run ``contrib/<name>`` instead.
    """
    # These imports run only when the function is called.
    # NOTE(review): presumably kept local to avoid heavy imports or import
    # cycles with tune at module load -- confirm.
    from ray.rllib.agents.trainer import Trainer, with_common_config
    from ray.rllib.agents.registry import ALGORITHMS, get_agent_class
    from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

    extra_names = ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
    names_to_register = (
        list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys()) +
        extra_names)
    for algo_name in names_to_register:
        register_trainable(algo_name, get_agent_class(algo_name))

    def _see_contrib(name):
        """Build a stand-in Trainer class that redirects users to contrib/."""

        class _SeeContrib(Trainer):
            _name = "SeeContrib"
            _default_config = with_common_config({})

            def _setup(self, config):
                hint = "Please run `contrib/{}` instead.".format(name)
                raise NameError(hint)

        return _SeeContrib

    # Register each contributed algorithm under its short alias as well
    # (the key minus the "contrib/" prefix). Running the bare alias then
    # fails with a helpful message instead of an unknown-trainable error.
    contrib_keys = list(CONTRIBUTED_ALGORITHMS.keys())
    for contrib_key in contrib_keys:
        assert contrib_key.startswith("contrib/")
        _, _, short_name = contrib_key.partition("/")
        register_trainable(short_name, _see_contrib(short_name))
# Import-time side effects. First configure the "ray.rllib" logger, then
# register every RLlib algorithm with Tune's trainable registry. This module
# is imported by tune for exactly that registration.
_setup_logger()
_register_all()
# Public names exported by ``from ray.rllib import *``. Every entry must be
# bound in this module; an unbound name makes the star-import raise
# AttributeError.
__all__ = [
    "Policy",
    "PolicyGraph",
    "TFPolicy",
    "TFPolicyGraph",
    "RolloutWorker",
    "SampleBatch",
    "BaseEnv",
    "MultiAgentEnv",
    "VectorEnv",
    "ExternalEnv",
]