From b37a2830534fd51cdc28bbdd04c8849ac0809234 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 2 Sep 2018 23:02:19 -0700 Subject: [PATCH] [rllib] support local mode (#2795) --- python/ray/experimental/internal_kv.py | 11 +++++++++++ python/ray/rllib/test/test_local.py | 20 ++++++++++++++++++++ test/jenkins_tests/run_multi_node_tests.sh | 3 +++ 3 files changed, 34 insertions(+) create mode 100644 python/ray/rllib/test/test_local.py diff --git a/python/ray/experimental/internal_kv.py b/python/ray/experimental/internal_kv.py index 85476a7c3768..f2a19450c202 100644 --- a/python/ray/experimental/internal_kv.py +++ b/python/ray/experimental/internal_kv.py @@ -4,6 +4,8 @@ import ray +_local = {} # dict for local mode + def _internal_kv_initialized(): worker = ray.worker.get_global_worker() @@ -14,6 +16,9 @@ def _internal_kv_get(key): """Fetch the value of a binary key.""" worker = ray.worker.get_global_worker() + if worker.mode == ray.worker.LOCAL_MODE: + return _local.get(key) + return worker.redis_client.hget(key, "value") @@ -27,6 +32,12 @@ def _internal_kv_put(key, value, overwrite=False): """ worker = ray.worker.get_global_worker() + if worker.mode == ray.worker.LOCAL_MODE: + exists = key in _local + if not exists or overwrite: + _local[key] = value + return exists + if overwrite: updated = worker.redis_client.hset(key, "value", value) else: diff --git a/python/ray/rllib/test/test_local.py b/python/ray/rllib/test/test_local.py new file mode 100644 index 000000000000..2f76de3741dc --- /dev/null +++ b/python/ray/rllib/test/test_local.py @@ -0,0 +1,20 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +from ray.rllib.agents.ppo import PPOAgent, DEFAULT_CONFIG +import ray + + +class LocalModeTest(unittest.TestCase): + def testLocal(self): + ray.init(local_mode=True) + cf = DEFAULT_CONFIG.copy() + agent = PPOAgent(cf, "CartPole-v0") + print(agent.train()) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index 24f4830973a4..b6e6c5105b7b 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -216,6 +216,9 @@ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ sh /ray/test/jenkins_tests/multi_node_tests/test_rllib_eval.sh +docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/test/test_local.py + docker run -e "RAY_USE_XRAY=1" --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/test/test_checkpoint_restore.py