Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Commit

Permalink
Add available resources to global state (ray-project#2501)
Browse files Browse the repository at this point in the history
  • Loading branch information
pschafhalter authored and robertnishihara committed Sep 10, 2018
1 parent 611259b commit 5da6e78
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ matrix:
# - python -m pytest -v python/ray/local_scheduler/test/test.py
# - python -m pytest -v python/ray/global_scheduler/test/test.py

- python -m pytest -v python/ray/test/test_global_state.py
- python -m pytest -v python/ray/test/test_queue.py
- python -m pytest -v test/xray_test.py

Expand Down Expand Up @@ -204,6 +205,7 @@ script:
- python -m pytest -v python/ray/local_scheduler/test/test.py
- python -m pytest -v python/ray/global_scheduler/test/test.py

- python -m pytest -v python/ray/test/test_global_state.py
- python -m pytest -v python/ray/test/test_queue.py
- python -m pytest -v test/xray_test.py

Expand Down
114 changes: 113 additions & 1 deletion python/ray/experimental/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import defaultdict
import heapq
import json
import numbers
import os
import redis
import sys
Expand Down Expand Up @@ -1277,7 +1278,7 @@ def cluster_resources(self):
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
resources = defaultdict(lambda: 0)
resources = defaultdict(int)
if not self.use_raylet:
local_schedulers = self.local_schedulers()

Expand All @@ -1297,6 +1298,117 @@ def cluster_resources(self):

return dict(resources)

def available_resources(self):
    """Get the current available cluster resources.

    Subscribes to the heartbeat channel and polls it until a resource
    report has been received from every local scheduler (non-raylet code
    path) or every client (raylet code path) currently listed for the
    cluster. The per-node reports are then summed per resource name.

    Note that this information can grow stale as tasks start and finish.

    Returns:
        A dictionary mapping resource name to the total quantity of that
            resource currently available in the cluster.
    """
    available_resources_by_id = {}
    # Every pubsub object created below is tracked here so that it can be
    # closed once the resources have been collected.
    pubsub_clients = []

    try:
        if not self.use_raylet:
            subscribe_client = self.redis_client.pubsub()
            pubsub_clients.append(subscribe_client)
            subscribe_client.subscribe(
                ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL)

            local_scheduler_ids = {
                local_scheduler["DBClientID"]
                for local_scheduler in self.local_schedulers()
            }

            while set(available_resources_by_id) != local_scheduler_ids:
                raw_message = subscribe_client.get_message()
                if raw_message is None:
                    continue
                data = raw_message["data"]
                # Ignore subscription success message from Redis
                # This is a long in python 2 and an int in python 3
                if isinstance(data, numbers.Number):
                    continue
                message = (ray.gcs_utils.LocalSchedulerInfoMessage.
                           GetRootAsLocalSchedulerInfoMessage(data, 0))
                num_resources = message.DynamicResourcesLength()
                dynamic_resources = {}
                for i in range(num_resources):
                    dyn = message.DynamicResources(i)
                    resource_id = decode(dyn.Key())
                    dynamic_resources[resource_id] = dyn.Value()

                # Update available resources for this local scheduler
                client_id = binary_to_hex(message.DbClientId())
                available_resources_by_id[client_id] = dynamic_resources

                # Update local schedulers in cluster
                local_scheduler_ids = {
                    local_scheduler["DBClientID"]
                    for local_scheduler in self.local_schedulers()
                }

                # Remove disconnected local schedulers. Iterate over a
                # snapshot of the keys: deleting from a dict while
                # iterating over it raises RuntimeError on Python 3.
                for local_scheduler_id in list(available_resources_by_id):
                    if local_scheduler_id not in local_scheduler_ids:
                        del available_resources_by_id[local_scheduler_id]
        else:
            # Assumes the number of Redis clients does not change
            for redis_client in self.redis_clients:
                pubsub_clients.append(
                    redis_client.pubsub(ignore_subscribe_messages=True))
            for subscribe_client in pubsub_clients:
                subscribe_client.subscribe(
                    ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL)

            client_ids = {
                client["ClientID"]
                for client in self.client_table()
            }

            while set(available_resources_by_id) != client_ids:
                for subscribe_client in pubsub_clients:
                    # Parse client message
                    raw_message = subscribe_client.get_message()
                    if (raw_message is None or raw_message["channel"] !=
                            ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL):
                        continue
                    data = raw_message["data"]
                    gcs_entries = (
                        ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(
                            data, 0))
                    heartbeat_data = gcs_entries.Entries(0)
                    message = (ray.gcs_utils.HeartbeatTableData.
                               GetRootAsHeartbeatTableData(
                                   heartbeat_data, 0))
                    # Calculate available resources for this client
                    num_resources = message.ResourcesAvailableLabelLength()
                    dynamic_resources = {}
                    for i in range(num_resources):
                        resource_id = decode(
                            message.ResourcesAvailableLabel(i))
                        dynamic_resources[resource_id] = (
                            message.ResourcesAvailableCapacity(i))

                    # Update available resources for this client
                    client_id = ray.utils.binary_to_hex(message.ClientId())
                    available_resources_by_id[client_id] = dynamic_resources

                # Update clients in cluster
                client_ids = {
                    client["ClientID"]
                    for client in self.client_table()
                }

                # Remove disconnected clients. Iterate over a snapshot of
                # the keys: deleting from a dict while iterating over it
                # raises RuntimeError on Python 3.
                for client_id in list(available_resources_by_id):
                    if client_id not in client_ids:
                        del available_resources_by_id[client_id]
    finally:
        # Drop the subscriptions so their Redis connections are released
        # even if parsing a message raised.
        for pubsub_client in pubsub_clients:
            pubsub_client.close()

    # Calculate total available resources
    total_available_resources = defaultdict(int)
    for node_resources in available_resources_by_id.values():
        for resource_id, num_available in node_resources.items():
            total_available_resources[resource_id] += num_available

    return dict(total_available_resources)

def _error_messages(self, job_id):
"""Get the error messages for a specific job.
Expand Down
12 changes: 12 additions & 0 deletions python/ray/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@
OBJECT_LOCATION_PREFIX = "OL:"
FUNCTION_PREFIX = "RemoteFunction:"

# Names of the Redis pubsub channels that Ray components publish on. These
# channel names must be kept up-to-date with the definitions in
# common/state/redis.cc
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"

# xray heartbeats. The channel name is the ASCII-encoded string form of the
# TablePubsub.HEARTBEAT value.
XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii")

# xray driver updates. The channel name is the ASCII-encoded string form of
# the TablePubsub.DRIVER value.
XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii")

# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs.
# TODO(rkn): We should use scoped enums, in which case we should be able to
# just access the flatbuffer generated values.
Expand Down
32 changes: 10 additions & 22 deletions python/ray/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,6 @@
# common/task.h
TASK_STATUS_LOST = 32

# common/state/redis.cc
LOCAL_SCHEDULER_INFO_CHANNEL = b"local_schedulers"
PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers"
DRIVER_DEATH_CHANNEL = b"driver_deaths"

# xray heartbeats
XRAY_HEARTBEAT_CHANNEL = str(
ray.gcs_utils.TablePubsub.HEARTBEAT).encode("ascii")

# xray driver updates
XRAY_DRIVER_CHANNEL = str(ray.gcs_utils.TablePubsub.DRIVER).encode("ascii")

# common/redis_module/ray_redis_module.cc
OBJECT_INFO_PREFIX = b"OI:"
OBJECT_LOCATION_PREFIX = b"OL:"
Expand Down Expand Up @@ -607,23 +595,23 @@ def process_messages(self, max_messages=10000):

# Determine the appropriate message handler.
message_handler = None
if channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
if channel == ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL:
# The message was a heartbeat from a plasma manager.
message_handler = self.plasma_manager_heartbeat_handler
elif channel == LOCAL_SCHEDULER_INFO_CHANNEL:
elif channel == ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL:
# The message was a heartbeat from a local scheduler
message_handler = self.local_scheduler_info_handler
elif channel == DB_CLIENT_TABLE_NAME:
# The message was a notification from the db_client table.
message_handler = self.db_client_notification_handler
elif channel == DRIVER_DEATH_CHANNEL:
elif channel == ray.gcs_utils.DRIVER_DEATH_CHANNEL:
# The message was a notification that a driver was removed.
logger.info("message-handler: driver_removed_handler")
message_handler = self.driver_removed_handler
elif channel == XRAY_HEARTBEAT_CHANNEL:
elif channel == ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL:
# Similar functionality as local scheduler info channel
message_handler = self.xray_heartbeat_handler
elif channel == XRAY_DRIVER_CHANNEL:
elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
# Handles driver death.
message_handler = self.xray_driver_removed_handler
else:
Expand Down Expand Up @@ -686,11 +674,11 @@ def run(self):
"""
# Initialize the subscription channel.
self.subscribe(DB_CLIENT_TABLE_NAME)
self.subscribe(LOCAL_SCHEDULER_INFO_CHANNEL)
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
self.subscribe(XRAY_HEARTBEAT_CHANNEL, primary=False)
self.subscribe(XRAY_DRIVER_CHANNEL)
self.subscribe(ray.gcs_utils.LOCAL_SCHEDULER_INFO_CHANNEL)
self.subscribe(ray.gcs_utils.PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(ray.gcs_utils.DRIVER_DEATH_CHANNEL)
self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL, primary=False)
self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL)

# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel.
Expand Down
58 changes: 58 additions & 0 deletions python/ray/test/test_global_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import ray


def setup_module():
    """Connect to Ray if needed and run one warm-up task for this module.

    Ray is only started (with a single CPU) when the global worker is not
    already connected.
    """
    already_connected = ray.worker.global_worker.connected
    if not already_connected:
        ray.init(num_cpus=1)

    # Finish initializing Ray. Otherwise available_resources() does not
    # reflect resource use of submitted tasks
    warmup_result = cpu_task.remote(0)
    ray.get(warmup_result)


# Remote task that requests one CPU and holds it by sleeping.
@ray.remote(num_cpus=1)
def cpu_task(seconds):
    """Sleep for `seconds` seconds.

    Args:
        seconds: How long to sleep, passed directly to time.sleep().
    """
    time.sleep(seconds)


class TestAvailableResources(object):
    """Tests for ray.global_state.available_resources().

    Available resources are reported asynchronously, so each test polls
    until the expected state is observed or `timeout` seconds elapse.
    """

    # Maximum number of seconds to poll for the expected resource state.
    timeout = 10

    def _wait_until(self, condition):
        """Poll condition() until it is true or the timeout elapses.

        Args:
            condition: Zero-argument callable returning a truthy value once
                the expected state has been reached.

        Returns:
            True if the condition was met in time, False otherwise.
        """
        start = time.time()
        while time.time() - start < self.timeout:
            if condition():
                return True
        return False

    def test_no_tasks(self):
        """With no tasks running, all cluster resources are available."""
        cluster_resources = ray.global_state.cluster_resources()

        # Compare against available_resources(); comparing
        # cluster_resources() with itself would always pass.
        assert self._wait_until(
            lambda: cluster_resources ==
            ray.global_state.available_resources())

    def test_replenish_resources(self):
        """Resources are returned to the pool once a task has finished."""
        cluster_resources = ray.global_state.cluster_resources()

        ray.get(cpu_task.remote(0))

        assert self._wait_until(
            lambda: cluster_resources ==
            ray.global_state.available_resources())

    def test_uses_resources(self):
        """A running one-CPU task lowers the available CPU count by one."""
        cluster_resources = ray.global_state.cluster_resources()
        task_id = cpu_task.remote(1)

        try:
            assert self._wait_until(
                lambda: ray.global_state.available_resources()["CPU"] ==
                cluster_resources["CPU"] - 1)
        finally:
            # Clean up to reset resources, even if the assertion failed,
            # so a failure here does not leak into other tests.
            ray.get(task_id)
7 changes: 1 addition & 6 deletions python/ray/test/test_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ray.experimental.queue import Queue, Empty, Full


def start_ray():
def setup_module():
if not ray.worker.global_worker.connected:
ray.init()

Expand All @@ -28,7 +28,6 @@ def put_async(queue, item, block, timeout, sleep):


def test_simple_use():
start_ray()
q = Queue()

items = list(range(10))
Expand All @@ -41,7 +40,6 @@ def test_simple_use():


def test_async():
start_ray()
q = Queue()

items = set(range(10))
Expand All @@ -56,7 +54,6 @@ def test_async():


def test_put():
start_ray()
q = Queue(1)

item = 0
Expand Down Expand Up @@ -87,7 +84,6 @@ def test_put():


def test_get():
start_ray()
q = Queue()

item = 0
Expand All @@ -113,7 +109,6 @@ def test_get():


def test_qsize():
start_ray()
q = Queue()

items = list(range(10))
Expand Down

0 comments on commit 5da6e78

Please sign in to comment.