Skip to content

Commit

Permalink
Shard Redis. (ray-project#539)
Browse files Browse the repository at this point in the history
* Implement sharding in the Ray core

* Single node Python modifications to do sharding

* Do the sharding in redis.cc

* Pipe num_redis_shards through start_ray.py and worker.py.

* Use multiple redis shards in multinode tests.

* first steps for sharding ray.global_state

* Fix problem in multinode docker test.

* fix runtest.py

* fix some tests

* fix redis shard startup

* fix redis sharding

* fix

* fix bug introduced by the map-iterator being consumed

* fix sharding bug

* shard event table

* update number of Redis clients to be 64K

* Fix object table tests by flushing shards in between unit tests

* Fix local scheduler tests

* Documentation

* Register shard locations in the primary shard

* Add plasma unit tests back to build

* lint

* lint and fix build

* Fix

* Address Robert's comments

* Refactor start_ray_processes to start Redis shard

* lint

* Fix global scheduler python tests

* Fix redis module test

* Fix plasma test

* Fix component failure test

* Fix local scheduler test

* Fix runtest.py

* Fix global scheduler test for python3

* Fix task_table_test_and_update bug, from actor task table submission race

* Fix jenkins tests.

* Retry Redis shard connections

* Fix test cases

* Convert database clients to DBClient struct

* Fix race condition when subscribing to db client table

* Remove unused lines, add APITest for sharded Ray

* Fix

* Fix memory leak

* Suppress ReconstructionTests output

* Suppress output for APITestSharded

* Reissue task table add/update commands if initial command does not publish to any subscribers.

* fix

* Fix linting.

* fix tests

* fix linting

* fix python test

* fix linting
  • Loading branch information
stephanie-wang authored and pcmoritz committed May 19, 2017
1 parent 0a43047 commit ee08c82
Show file tree
Hide file tree
Showing 39 changed files with 1,336 additions and 651 deletions.
58 changes: 41 additions & 17 deletions python/ray/common/redis_module/runtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_next_message(pubsub_client, timeout_seconds=10):
class TestGlobalStateStore(unittest.TestCase):

def setUp(self):
redis_port, _ = ray.services.start_redis()
redis_port, _ = ray.services.start_redis_instance()
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)

def tearDown(self):
Expand Down Expand Up @@ -308,6 +308,10 @@ def testTaskTableAddAndLookup(self):
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4

# Subscribe so that at least one client receives the task table notification
# (the Redis module checks that the publish reached a subscriber).
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))

def check_task_reply(message, task_args, updated=False):
task_status, local_scheduler_id, task_spec = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
Expand Down Expand Up @@ -388,33 +392,53 @@ def check_task_reply(message, task_args, updated=False):
self.assertNotEqual(get_response, old_response)
check_task_reply(get_response, task_args[1:])

def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
    """Add a task to the task table and check the resulting notification.

    Args:
        p: A Redis pubsub client that is already subscribed to a task table
            channel pattern matching the task added here.
        scheduling_state: The scheduling state to add the task with.
        local_scheduler_id: The local scheduler ID, as a string that can be
            encoded as ASCII.
    """
    # Encode once; the same bytes are sent to Redis and compared against the
    # notification below.
    local_scheduler_id_binary = local_scheduler_id.encode("ascii")
    task_args = [b"task_id", scheduling_state, local_scheduler_id_binary,
                 b"task_spec"]
    self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
    # Receive the data.
    message = get_next_message(p)["data"]
    # Check that the notification object is correct.
    notification_object = TaskReply.GetRootAsTaskReply(message, 0)
    self.assertEqual(notification_object.TaskId(), b"task_id")
    self.assertEqual(notification_object.State(), scheduling_state)
    self.assertEqual(notification_object.LocalSchedulerId(),
                     local_scheduler_id_binary)
    self.assertEqual(notification_object.TaskSpec(), b"task_spec")

def testTaskTableSubscribe(self):
scheduling_state = 1
local_scheduler_id = "local_scheduler_id"
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
# Unsubscribe so that there is only one subscriber at any given time.
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)

p.psubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)

p.psubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the acknowledgement message.
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.assertEqual(get_next_message(p)["data"], 2)
self.assertEqual(get_next_message(p)["data"], 3)
# Receive the actual data.
for i in range(3):
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(),
local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)


if __name__ == "__main__":
Expand Down
111 changes: 85 additions & 26 deletions python/ray/experimental/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pickle
import redis

import ray
from ray.utils import (decode, binary_to_object_id, binary_to_hex,
hex_to_binary)

Expand All @@ -25,14 +26,21 @@

# This mapping from integer to task state string must be kept up-to-date with
# the scheduling_state enum in task.h.
task_state_mapping = {
1: "WAITING",
2: "SCHEDULED",
4: "QUEUED",
8: "RUNNING",
16: "DONE",
32: "LOST",
64: "RECONSTRUCTING"
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
TASK_STATUS_RUNNING = 8
TASK_STATUS_DONE = 16
TASK_STATUS_LOST = 32
TASK_STATUS_RECONSTRUCTING = 64
TASK_STATUS_MAPPING = {
TASK_STATUS_WAITING: "WAITING",
TASK_STATUS_SCHEDULED: "SCHEDULED",
TASK_STATUS_QUEUED: "QUEUED",
TASK_STATUS_RUNNING: "RUNNING",
TASK_STATUS_DONE: "DONE",
TASK_STATUS_LOST: "LOST",
TASK_STATUS_RECONSTRUCTING: "RECONSTRUCTING",
}


Expand Down Expand Up @@ -66,8 +74,54 @@ def _initialize_global_state(self, redis_ip_address, redis_port):
"""
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_clients = []
num_redis_shards = self.redis_client.get("NumRedisShards")
if num_redis_shards is None:
raise Exception("No entry found for NumRedisShards")
num_redis_shards = int(num_redis_shards)
if (num_redis_shards < 1):
raise Exception("Expected at least one Redis shard, found "
"{}.".format(num_redis_shards))

ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1)
if len(ip_address_ports) != num_redis_shards:
raise Exception("Expected {} Redis shard addresses, found "
"{}".format(num_redis_shards, len(ip_address_ports)))

for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(redis.StrictRedis(host=shard_address,
port=shard_port))

def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.
def _object_table(self, object_id_binary):
Args:
key: The object ID or the task ID that the query is about.
args: The command to run.
Returns:
The value returned by the Redis command.
"""
client = self.redis_clients[key.redis_shard_hash() %
len(self.redis_clients)]
return client.execute_command(*args)

def _keys(self, pattern):
    """Execute the KEYS command on all Redis shards.

    Args:
        pattern: The KEYS pattern to query.

    Returns:
        The concatenated list of results from all shards, in the order of
            self.redis_clients.
    """
    result = []
    # Every shard client is queried and the per-shard results are merged
    # into one flat list.
    for client in self.redis_clients:
        result.extend(client.keys(pattern))
    return result

def _object_table(self, object_id):
"""Fetch and parse the object table information for a single object ID.
Args:
Expand All @@ -78,16 +132,18 @@ def _object_table(self, object_id_binary):
A dictionary with information about the object ID in question.
"""
# Return information about a single object ID.
object_locations = self.redis_client.execute_command(
"RAY.OBJECT_TABLE_LOOKUP", object_id_binary)
object_locations = self._execute_command(object_id,
"RAY.OBJECT_TABLE_LOOKUP",
object_id.id())
if object_locations is not None:
manager_ids = [binary_to_hex(manager_id)
for manager_id in object_locations]
else:
manager_ids = None

result_table_response = self.redis_client.execute_command(
"RAY.RESULT_TABLE_LOOKUP", object_id_binary)
result_table_response = self._execute_command(object_id,
"RAY.RESULT_TABLE_LOOKUP",
object_id.id())
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)

Expand All @@ -111,22 +167,21 @@ def object_table(self, object_id=None):
self._check_connected()
if object_id is not None:
# Return information about a single object ID.
return self._object_table(object_id.id())
return self._object_table(object_id)
else:
# Return the entire object table.
object_info_keys = self.redis_client.keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self.redis_client.keys(
OBJECT_LOCATION_PREFIX + "*")
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set(
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
[key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys])
results = {}
for object_id_binary in object_ids_binary:
results[binary_to_object_id(object_id_binary)] = self._object_table(
object_id_binary)
binary_to_object_id(object_id_binary))
return results

def _task_table(self, task_id_binary):
def _task_table(self, task_id):
"""Fetch and parse the task table information for a single task ID.
Args:
Expand All @@ -135,12 +190,15 @@ def _task_table(self, task_id_binary):
Returns:
A dictionary with information about the task ID in question.
TASK_STATUS_MAPPING should be used to parse the "State" field into a
human-readable string.
"""
task_table_response = self.redis_client.execute_command(
"RAY.TASK_TABLE_GET", task_id_binary)
task_table_response = self._execute_command(task_id,
"RAY.TASK_TABLE_GET",
task_id.id())
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task table."
.format(binary_to_hex(task_id_binary)))
.format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
Expand All @@ -167,7 +225,7 @@ def _task_table(self, task_id_binary):
for i in range(task_spec_message.ReturnsLength())],
"RequiredResources": required_resources}

return {"State": task_state_mapping[task_table_message.State()],
return {"State": task_table_message.State(),
"LocalSchedulerID": binary_to_hex(
task_table_message.LocalSchedulerId()),
"TaskSpec": task_spec_info}
Expand All @@ -185,14 +243,15 @@ def task_table(self, task_id=None):
"""
self._check_connected()
if task_id is not None:
return self._task_table(hex_to_binary(task_id))
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
return self._task_table(task_id)
else:
task_table_keys = self.redis_client.keys(TASK_PREFIX + "*")
task_table_keys = self._keys(TASK_PREFIX + "*")
results = {}
for key in task_table_keys:
task_id_binary = key[len(TASK_PREFIX):]
results[binary_to_hex(task_id_binary)] = self._task_table(
task_id_binary)
ray.local_scheduler.ObjectID(task_id_binary))
return results

def function_table(self, function_id=None):
Expand Down
6 changes: 3 additions & 3 deletions python/ray/global_scheduler/global_scheduler_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import time


def start_global_scheduler(redis_address, node_ip_address, use_valgrind=False,
use_profiler=False, stdout_file=None,
stderr_file=None):
def start_global_scheduler(redis_address, node_ip_address,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
"""Start a global scheduler process.
Args:
Expand Down
Loading

0 comments on commit ee08c82

Please sign in to comment.