[serve] Detect node updates #9828

Merged · 6 commits · Aug 4, 2020
4 changes: 3 additions & 1 deletion python/ray/serve/api.py
@@ -367,8 +367,10 @@ def get_handle(endpoint_name,
if not missing_ok:
assert endpoint_name in ray.get(controller.get_all_endpoints.remote())

# TODO(edoakes): we should choose the router on the same node.
routers = ray.get(controller.get_routers.remote())
return RayServeHandle(
ray.get(controller.get_router.remote())[0],
list(routers.values())[0],
endpoint_name,
relative_slo_ms,
absolute_slo_ms,
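In `api.py`, `get_handle` now asks the controller for the full node-to-router mapping and, for the moment, simply takes an arbitrary entry; the TODO notes that a router on the caller's node would be preferable. A minimal sketch of what that selection could look like — the `choose_router` helper and its `current_node_id` argument are illustrative, not part of this PR:

```python
# Hypothetical helper (not in this PR): given the node_id -> actor handle
# mapping returned by `controller.get_routers.remote()`, prefer a router on
# the caller's node and otherwise fall back to an arbitrary one, which is
# what `get_handle` currently does.
def choose_router(routers, current_node_id=None):
    if current_node_id is not None and current_node_id in routers:
        return routers[current_node_id]
    # Arbitrary choice; matches `list(routers.values())[0]` above.
    return next(iter(routers.values()))
```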
171 changes: 108 additions & 63 deletions python/ray/serve/controller.py
@@ -1,6 +1,5 @@
import asyncio
from collections import defaultdict, namedtuple
from itertools import groupby
import os
import random
import time
@@ -15,7 +14,7 @@
from ray.serve.metric.exporter import MetricExporterActor
from ray.serve.exceptions import RayServeException
from ray.serve.utils import (format_actor_name, get_random_letters, logger,
try_schedule_resources_on_nodes)
try_schedule_resources_on_nodes, get_all_node_ids)

import numpy as np

@@ -28,6 +27,9 @@
# error if the desired replicas exceed current resource availability.
_RESOURCE_CHECK_ENABLED = True

# How often to call the control loop on the controller.
CONTROL_LOOP_PERIOD_S = 1.0


class TrafficPolicy:
def __init__(self, traffic_dict):
@@ -88,7 +90,7 @@ class ServeController:
requires all implementations here to be idempotent.
"""

async def __init__(self, instance_name, http_proxy_host, http_proxy_port,
async def __init__(self, instance_name, http_host, http_port,
metric_exporter_class):
# Unique name of the serve instance managed by this actor. Used to
# namespace child actors and checkpoints.
@@ -122,13 +124,17 @@ async def __init__(self, instance_name, http_proxy_host, http_proxy_port,
self.write_lock = asyncio.Lock()

# Cached handles to actors in the system.
self.routers = []
# node_id -> actor_handle
self.routers = dict()
self.metric_exporter = None

self.http_host = http_host
self.http_port = http_port

# If starting the actor for the first time, starts up the other system
# components. If recovering, fetches their actor handles.
self._get_or_start_metric_exporter(metric_exporter_class)
self._get_or_start_routers(http_proxy_host, http_proxy_port)
self._start_metric_exporter(metric_exporter_class)
self._start_routers_if_needed()

# NOTE(edoakes): unfortunately, we can't completely recover from a
# checkpoint in the constructor because we block while waiting for
@@ -150,46 +156,69 @@ async def __init__(self, instance_name, http_proxy_host, http_proxy_port,
asyncio.get_event_loop().create_task(
self._recover_from_checkpoint(checkpoint))

def _get_or_start_routers(self, host, port):
"""Get the HTTP proxy belonging to this serve instance.
asyncio.get_event_loop().create_task(self.run_control_loop())

def _start_routers_if_needed(self):
"""Start a router on every node if it doesn't already exist."""
for node_id, node_resource in get_all_node_ids():
if node_id in self.routers:
continue

If the HTTP proxy does not already exist, it will be started.
router_name = format_actor_name(SERVE_PROXY_NAME,
self.instance_name, node_id)
try:
router = ray.get_actor(router_name)
except ValueError:
logger.info("Starting router with name '{}' on node '{}' "
"listening on '{}:{}'".format(
router_name, node_id, self.http_host,
self.http_port))
router = HTTPProxyActor.options(
name=router_name,
max_concurrency=ASYNC_CONCURRENCY,
max_restarts=-1,
max_task_retries=-1,
resources={
node_resource: 0.01
},
).remote(
self.http_host,
self.http_port,
instance_name=self.instance_name)

self.routers[node_id] = router
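`_start_routers_if_needed` above pins one `HTTPProxyActor` to each node by requesting a small slice of that node's own custom resource; `get_all_node_ids` from `ray.serve.utils` supplies `(node_id, node_resource)` pairs for the alive nodes. A standalone sketch of the same pinning idea, assuming a running Ray cluster and using a made-up `EchoActor` in place of the proxy (the per-node resource key, typically of the form `node:<ip>`, is a Ray implementation detail):

```python
import ray

ray.init(address="auto")  # assumes an existing Ray cluster


@ray.remote
class EchoActor:
    def ping(self):
        return "pong"


# One actor per node: request a tiny fraction of the node's own custom
# resource so the scheduler pins the actor to that node.
actors = {}
for node in ray.nodes():
    if not node["Alive"]:
        continue
    node_resource = "node:{}".format(node["NodeManagerAddress"])
    actors[node["NodeID"]] = EchoActor.options(
        resources={node_resource: 0.01}).remote()
```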

def _stop_routers_if_needed(self):
"""Removes router actors from any nodes that no longer exist.

Returns whether or not any actors were removed (a checkpoint should
be taken).
"""
# TODO(simon): We don't handle nodes being added/removed. To do that,
# we should implement some sort of control loop in master actor.
for _, node_id_group in groupby(sorted(ray.state.node_ids())):
for index, node_id in enumerate(node_id_group):
proxy_name = format_actor_name(SERVE_PROXY_NAME,
self.instance_name)
proxy_name += "-{}-{}".format(node_id, index)
try:
router = ray.get_actor(proxy_name)
except ValueError:
logger.info(
"Starting HTTP proxy with name '{}' on node '{}' "
"listening on port {}".format(proxy_name, node_id,
port))
router = HTTPProxyActor.options(
name=proxy_name,
max_concurrency=ASYNC_CONCURRENCY,
max_restarts=-1,
max_task_retries=-1,
resources={
node_id: 0.01
},
).remote(
host, port, instance_name=self.instance_name)
self.routers.append(router)

def get_router(self):
"""Returns a handle to the HTTP proxy managed by this actor."""
checkpoint_required = False
all_node_ids = {node_id for node_id, _ in get_all_node_ids()}
to_stop = []
for node_id in self.routers:
if node_id not in all_node_ids:
logger.info(
"Removing router on removed node '{}'.".format(node_id))
to_stop.append(node_id)

for node_id in to_stop:
router_handle = self.routers.pop(node_id)
ray.kill(router_handle, no_restart=True)
checkpoint_required = True

return checkpoint_required

def get_routers(self):
"""Returns a dictionary of node ID to router actor handles."""
return self.routers

def get_router_config(self):
"""Called by the HTTP proxy on startup to fetch required state."""
"""Called by the router on startup to fetch required state."""
return self.routes

def _get_or_start_metric_exporter(self, metric_exporter_class):
def _start_metric_exporter(self, metric_exporter_class):
"""Get the metric exporter belonging to this serve instance.

If the metric exporter does not already exist, it will be started.
@@ -210,11 +239,13 @@ def get_metric_exporter(self):

def _checkpoint(self):
"""Checkpoint internal state and write it to the KV store."""
assert self.write_lock.locked()
logger.debug("Writing checkpoint")
start = time.time()
checkpoint = pickle.dumps(
(self.routes, self.backends, self.traffic_policies, self.replicas,
self.replicas_to_start, self.replicas_to_stop,
(self.routes, list(
self.routers.keys()), self.backends, self.traffic_policies,
self.replicas, self.replicas_to_start, self.replicas_to_stop,
self.backends_to_remove, self.endpoints_to_remove))

self.kv_store.put(CHECKPOINT_KEY, checkpoint)
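`_checkpoint` now serializes the router *node IDs* rather than the actor handles themselves; the handles can always be re-resolved by name after a restart. A small sketch of that round trip, with `routes`, `routers`, and `backends` standing in for the controller's fields:

```python
import pickle


# Only the router node IDs go into the checkpoint; actor handles are not
# durable and are re-resolved by name during recovery (see below).
def make_checkpoint(routes, routers, backends):
    return pickle.dumps((routes, list(routers.keys()), backends))


def load_checkpoint(checkpoint_bytes):
    routes, router_node_ids, backends = pickle.loads(checkpoint_bytes)
    return routes, router_node_ids, backends
```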
@@ -229,7 +260,7 @@ async def _recover_from_checkpoint(self, checkpoint_bytes):

Performs the following operations:
1) Deserializes the internal state from the checkpoint.
2) Pushes the latest configuration to the HTTP proxy and router
2) Pushes the latest configuration to the routers
in case we crashed before updating them.
3) Starts/stops any worker replicas that are pending creation or
deletion.
@@ -245,6 +276,7 @@ async def _recover_from_checkpoint(self, checkpoint_bytes):
# Load internal state from the checkpoint data.
(
self.routes,
router_node_ids,
self.backends,
self.traffic_policies,
self.replicas,
@@ -254,6 +286,11 @@ async def _recover_from_checkpoint(self, checkpoint_bytes):
self.endpoints_to_remove,
) = pickle.loads(checkpoint_bytes)

for node_id in router_node_ids:
router_name = format_actor_name(SERVE_PROXY_NAME,
self.instance_name, node_id)
self.routers[node_id] = ray.get_actor(router_name)

# Fetch actor handles for all of the backend replicas in the system.
# All of these workers are guaranteed to already exist because they
# would not be written to a checkpoint in self.workers until they
@@ -270,28 +307,28 @@ async def _recover_from_checkpoint(self, checkpoint_bytes):
for endpoint, traffic_policy in self.traffic_policies.items():
await asyncio.gather(*[
router.set_traffic.remote(endpoint, traffic_policy)
for router in self.routers
for router in self.routers.values()
])

for backend_tag, replica_dict in self.workers.items():
for replica_tag, worker in replica_dict.items():
await asyncio.gather(*[
router.add_new_worker.remote(backend_tag, replica_tag,
worker)
for router in self.routers
for router in self.routers.values()
])

for backend, info in self.backends.items():
await asyncio.gather(*[
router.set_backend_config.remote(backend, info.backend_config)
for router in self.routers
for router in self.routers.values()
])
await self.broadcast_backend_config(backend)

# Push configuration state to the HTTP proxy.
# Push configuration state to the routers.
await asyncio.gather(*[
router.set_route_table.remote(self.routes)
for router in self.routers
for router in self.routers.values()
])

# Start/stop any pending backend replicas.
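On recovery the controller re-attaches to each router by its formatted actor name and then re-broadcasts state to all of them in parallel. A hedged sketch of that pattern — `name_for` stands in for `format_actor_name(SERVE_PROXY_NAME, instance_name, node_id)`, and `routes` for the controller's route table:

```python
import asyncio

import ray


async def reattach_and_push(router_node_ids, instance_name, routes, name_for):
    # Re-resolve the named router actors recorded in the checkpoint.
    routers = {
        node_id: ray.get_actor(name_for(instance_name, node_id))
        for node_id in router_node_ids
    }
    # Push the latest route table to every router in parallel; Ray object
    # refs are awaitable, so asyncio.gather works directly on them.
    await asyncio.gather(*[
        router.set_route_table.remote(routes)
        for router in routers.values()
    ])
    return routers
```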
@@ -307,6 +344,16 @@ async def _recover_from_checkpoint(self, checkpoint_bytes):

self.write_lock.release()

async def run_control_loop(self):
while True:
async with self.write_lock:
self._start_routers_if_needed()
checkpoint_required = self._stop_routers_if_needed()
if checkpoint_required:
self._checkpoint()

await asyncio.sleep(CONTROL_LOOP_PERIOD_S)

def get_backend_configs(self):
"""Fetched by the router on startup."""
backend_configs = {}
@@ -368,7 +415,7 @@ async def _start_replica(self, backend_tag, replica_tag):
await asyncio.gather(*[
router.add_new_worker.remote(backend_tag, replica_tag,
worker_handle)
for router in self.routers
for router in self.routers.values()
])

async def _start_pending_replicas(self):
@@ -409,7 +456,7 @@ async def _stop_pending_replicas(self):
# Remove the replica from router. This call is idempotent.
await asyncio.gather(*[
router.remove_worker.remote(backend_tag, replica_tag)
for router in self.routers
for router in self.routers.values()
])

# TODO(edoakes): this logic isn't ideal because there may be
@@ -429,7 +476,7 @@ async def _remove_pending_backends(self):
for backend_tag in self.backends_to_remove:
await asyncio.gather(*[
router.remove_backend.remote(backend_tag)
for router in self.routers
for router in self.routers.values()
])
self.backends_to_remove.clear()

@@ -441,7 +488,7 @@ async def _remove_pending_endpoints(self):
for endpoint_tag in self.endpoints_to_remove:
await asyncio.gather(*[
router.remove_endpoint.remote(endpoint_tag)
for router in self.routers
for router in self.routers.values()
])
self.endpoints_to_remove.clear()

@@ -558,7 +605,7 @@ async def _set_traffic(self, endpoint_name, traffic_dict):
self._checkpoint()
await asyncio.gather(*[
router.set_traffic.remote(endpoint_name, traffic_policy)
for router in self.routers
for router in self.routers.values()
])

async def set_traffic(self, endpoint_name, traffic_dict):
@@ -590,14 +637,14 @@ async def shadow_traffic(self, endpoint_name, backend_tag, proportion):
router.set_traffic.remote(
endpoint_name,
self.traffic_policies[endpoint_name],
) for router in self.routers
) for router in self.routers.values()
])

async def create_endpoint(self, endpoint, traffic_dict, route, methods):
"""Create a new endpoint with the specified route and methods.

If the route is None, this is a "headless" endpoint that will not
be added to the HTTP proxy (can only be accessed via a handle).
be exposed over HTTP and can only be accessed via a handle.
"""
async with self.write_lock:
# If this is a headless endpoint with no route, key the endpoint
@@ -632,7 +679,7 @@ async def create_endpoint(self, endpoint, traffic_dict, route, methods):
await self._set_traffic(endpoint, traffic_dict)
await asyncio.gather(*[
router.set_route_table.remote(self.routes)
for router in self.routers
for router in self.routers.values()
])

async def delete_endpoint(self, endpoint):
Expand Down Expand Up @@ -662,15 +709,13 @@ async def delete_endpoint(self, endpoint):
self.endpoints_to_remove.append(endpoint)

# NOTE(edoakes): we must write a checkpoint before pushing the
# updates to the HTTP proxy and router to avoid inconsistent state
# if we crash after pushing the update.
# updates to the routers to avoid inconsistent state if we crash
# after pushing the update.
self._checkpoint()

# Update the HTTP proxy first to ensure no new requests for the
# endpoint are sent to the router.
await asyncio.gather(*[
router.set_route_table.remote(self.routes)
for router in self.routers
for router in self.routers.values()
])
await self._remove_pending_endpoints()

@@ -698,7 +743,7 @@ async def create_backend(self, backend_tag, backend_config,
# (particularly for max-batch-size).
await asyncio.gather(*[
router.set_backend_config.remote(backend_tag, backend_config)
for router in self.routers
for router in self.routers.values()
])
await self.broadcast_backend_config(backend_tag)

Expand Down Expand Up @@ -757,7 +802,7 @@ async def update_backend_config(self, backend_tag, config_options):
# (particularly for setting max_batch_size).
await asyncio.gather(*[
router.set_backend_config.remote(backend_tag, backend_config)
for router in self.routers
for router in self.routers.values()
])

await self._start_pending_replicas()
@@ -788,7 +833,7 @@ def get_backend_config(self, backend_tag):
async def shutdown(self):
"""Shuts down the serve instance completely."""
async with self.write_lock:
for router in self.routers:
for router in self.routers.values():
ray.kill(router, no_restart=True)
ray.kill(self.metric_exporter, no_restart=True)
for replica_dict in self.workers.values():
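Taken together, the new `run_control_loop` turns router management into a reconcile loop: every `CONTROL_LOOP_PERIOD_S` seconds it starts routers on newly added nodes, removes routers whose nodes are gone, and checkpoints when membership changed (the PR itself only forces a checkpoint when routers are removed). A condensed, self-contained sketch of that pattern, with the cluster query and start/stop steps passed in as stand-ins for the controller's own methods:

```python
import asyncio

CONTROL_LOOP_PERIOD_S = 1.0


async def run_control_loop(get_node_ids, routers, start_router, stop_router,
                           checkpoint, lock):
    """Reconcile per-node routers against cluster membership.

    Every argument is a stand-in for the controller's own state or methods.
    """
    while True:
        async with lock:
            changed = False
            node_ids = set(get_node_ids())
            # Start a router on every node that doesn't have one yet.
            for node_id in node_ids - routers.keys():
                routers[node_id] = start_router(node_id)
                changed = True
            # Drop routers whose nodes have left the cluster.
            for node_id in routers.keys() - node_ids:
                stop_router(routers.pop(node_id))
                changed = True
            if changed:
                checkpoint()
        await asyncio.sleep(CONTROL_LOOP_PERIOD_S)
```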
4 changes: 2 additions & 2 deletions python/ray/serve/tests/test_failure.py
@@ -76,8 +76,8 @@ def function():


def _kill_routers():
routers = ray.get(serve.api._get_controller().get_router.remote())
for router in routers:
routers = ray.get(serve.api._get_controller().get_routers.remote())
for router in routers.values():
ray.kill(router, no_restart=False)


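The test helper now mirrors the controller's bookkeeping: it fetches the node-to-router dictionary and kills each router with `no_restart=False`, relying on `max_restarts=-1` to bring them back. A sketch of how such a failure test is typically exercised — the polling helper, URL, and timeout are assumptions, not taken from `test_failure.py`:

```python
import time

import requests
import ray
from ray import serve


def _kill_routers():
    # Same helper as in the test above: fetch the node_id -> router mapping
    # and kill each router; no_restart=False lets Ray restart the actors
    # because they were created with max_restarts=-1.
    routers = ray.get(serve.api._get_controller().get_routers.remote())
    for router in routers.values():
        ray.kill(router, no_restart=False)


def wait_for_http(url, timeout_s=30):
    # Illustrative helper: poll until the restarted routers serve traffic
    # again. URL and timeout are assumptions for the sketch.
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            if requests.get(url, timeout=1).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            pass
        time.sleep(0.5)
    return False
```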