Skip to content

Commit

Permalink
[serve] Clean up EndpointState interface, move checkpointing inside o…
Browse files Browse the repository at this point in the history
…f EndpointState (ray-project#13215)
  • Loading branch information
edoakes authored Jan 9, 2021
1 parent c5ae30d commit d434ba6
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 161 deletions.
6 changes: 3 additions & 3 deletions python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def delete_endpoint(self, endpoint: str) -> None:
Does not delete any associated backends.
"""
self._get_result(self._controller.delete_endpoint.remote(endpoint))
ray.get(self._controller.delete_endpoint.remote(endpoint))

@_ensure_connected
def list_endpoints(self) -> Dict[str, Dict[str, Any]]:
Expand Down Expand Up @@ -447,7 +447,7 @@ def set_traffic(self, endpoint_name: str,
traffic_policy_dictionary (dict): a dictionary mapping backend names
to their traffic weights. The weights must sum to 1.
"""
self._get_result(
ray.get(
self._controller.set_traffic.remote(endpoint_name,
traffic_policy_dictionary))

Expand All @@ -473,7 +473,7 @@ def shadow_traffic(self, endpoint_name: str, backend_tag: str,
(float, int)) or not 0 <= proportion <= 1:
raise TypeError("proportion must be a float from 0 to 1.")

self._get_result(
ray.get(
self._controller.shadow_traffic.remote(endpoint_name, backend_tag,
proportion))

Expand Down
6 changes: 6 additions & 0 deletions python/ray/serve/backend_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@


class BackendState:
"""Manages all state for backends in the system.
This class is *not* thread safe, so any state-modifying methods should be
called with a lock held.
"""

def __init__(self,
controller_name: str,
detached: bool,
Expand Down
198 changes: 50 additions & 148 deletions python/ray/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,32 @@
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Optional
from uuid import UUID, uuid4
from typing import Dict, Any, List, Optional
from uuid import uuid4, UUID

import ray
import ray.cloudpickle as pickle
from ray.actor import ActorHandle
from ray.serve.backend_state import BackendState
from ray.serve.backend_worker import create_backend_replica
from ray.serve.common import (BackendInfo, BackendTag, EndpointTag, GoalId,
NodeId, ReplicaTag, TrafficPolicy)
from ray.serve.config import (BackendConfig, HTTPOptions, ReplicaConfig)
from ray.serve.constants import LongPollKey
from ray.serve.common import (
BackendInfo,
BackendTag,
EndpointTag,
GoalId,
NodeId,
ReplicaTag,
TrafficPolicy,
)
from ray.serve.config import BackendConfig, HTTPOptions, ReplicaConfig
from ray.serve.endpoint_state import EndpointState
from ray.serve.exceptions import RayServeException
from ray.serve.http_state import HTTPState
from ray.serve.kv_store import RayInternalKVStore
from ray.serve.long_poll import LongPollHost
from ray.serve.utils import logger

import ray

# Used for testing purposes only. If this is set, the controller will crash
# after writing each checkpoint with the specified probability.
_CRASH_AFTER_CHECKPOINT_PROBABILITY = 0
Expand All @@ -32,8 +38,6 @@
# How often to call the control loop on the controller.
CONTROL_LOOP_PERIOD_S = 1.0

REPLICA_STARTUP_TIME_WARNING_S = 5


@dataclass
class FutureResult:
Expand All @@ -43,7 +47,6 @@ class FutureResult:

@dataclass
class Checkpoint:
endpoint_state_checkpoint: bytes
backend_state_checkpoint: bytes
# TODO(ilr) Rename reconciler to PendingState
inflight_reqs: Dict[uuid4, FutureResult]
Expand Down Expand Up @@ -94,39 +97,34 @@ async def __init__(self,
self.inflight_results: Dict[UUID, asyncio.Event] = dict()
self._serializable_inflight_results: Dict[UUID, FutureResult] = dict()

# HTTP state doesn't currently require a checkpoint.
# NOTE(simon): Currently we do an all-to-all broadcast. This means
# any listener will receive notifications for all changes. This
# can be a problem at scale, e.g. updating a single backend config
# will send over the entire set of configs. In the future, we should
# optimize the logic to support subscription by key.
self.long_poll_host = LongPollHost()

self.http_state = HTTPState(controller_name, detached, http_config)
self.endpoint_state = EndpointState(self.kv_store, self.long_poll_host)

checkpoint_bytes = self.kv_store.get(CHECKPOINT_KEY)
if checkpoint_bytes is None:
logger.debug("No checkpoint found")
self.backend_state = BackendState(controller_name, detached)
self.endpoint_state = EndpointState()
else:
checkpoint: Checkpoint = pickle.loads(checkpoint_bytes)
self.backend_state = BackendState(
controller_name,
detached,
checkpoint=checkpoint.backend_state_checkpoint)
self.endpoint_state = EndpointState(
checkpoint=checkpoint.endpoint_state_checkpoint)

self._serializable_inflight_results = checkpoint.inflight_reqs
for uuid, fut_result in self._serializable_inflight_results.items(
):
self._create_event_with_result(fut_result.requested_goal, uuid)

# NOTE(simon): Currently we do an all-to-all broadcast. This means
# any listener will receive notifications for all changes. This
# can be a problem at scale, e.g. updating a single backend config
# will send over the entire set of configs. In the future, we should
# optimize the logic to support subscription by key.
self.long_poll_host = LongPollHost()

self.notify_backend_configs_changed()
self.notify_replica_handles_changed()
self.notify_traffic_policies_changed()
self.notify_route_table_changed()

asyncio.get_event_loop().create_task(self.run_control_loop())

Expand Down Expand Up @@ -168,21 +166,11 @@ def notify_replica_handles_changed(self):
self.backend_state.backend_replicas.items()
})

def notify_traffic_policies_changed(self):
self.long_poll_host.notify_changed(
LongPollKey.TRAFFIC_POLICIES,
self.endpoint_state.traffic_policies,
)

def notify_backend_configs_changed(self):
self.long_poll_host.notify_changed(
LongPollKey.BACKEND_CONFIGS,
self.backend_state.get_backend_configs())

def notify_route_table_changed(self):
self.long_poll_host.notify_changed(LongPollKey.ROUTE_TABLE,
self.endpoint_state.routes)

async def listen_for_change(self, keys_to_snapshot_ids: Dict[str, int]):
"""Proxy a long poll client's listen request.
Expand All @@ -205,8 +193,7 @@ def _checkpoint(self) -> None:
start = time.time()

checkpoint = pickle.dumps(
Checkpoint(self.endpoint_state.checkpoint(),
self.backend_state.checkpoint(),
Checkpoint(self.backend_state.checkpoint(),
self._serializable_inflight_results))

self.kv_store.put(CHECKPOINT_KEY, checkpoint)
Expand Down Expand Up @@ -254,155 +241,74 @@ def get_all_endpoints(self) -> Dict[EndpointTag, Dict[BackendTag, Any]]:
"""Returns a dictionary of endpoint tag to endpoint info."""
return self.endpoint_state.get_endpoints()

async def _set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> UUID:
if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to assign traffic for an endpoint '{}'"
" that is not registered.".format(endpoint_name))

assert isinstance(traffic_dict,
dict), "Traffic policy must be a dictionary."

def _set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> UUID:
for backend in traffic_dict:
if self.backend_state.get_backend(backend) is None:
raise ValueError(
"Attempted to assign traffic to a backend '{}' that "
"is not registered.".format(backend))

traffic_policy = TrafficPolicy(traffic_dict)
self.endpoint_state.traffic_policies[endpoint_name] = traffic_policy
self.endpoint_state.set_traffic_policy(endpoint_name,
TrafficPolicy(traffic_dict))

return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy
})
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
self.notify_traffic_policies_changed()
self.set_goal_id(return_uuid)
return return_uuid
def _validate_traffic_dict(self, traffic_dict: Dict[str, float]):
for backend in traffic_dict:
if self.backend_state.get_backend(backend) is None:
raise ValueError(
"Attempted to assign traffic to a backend '{}' that "
"is not registered.".format(backend))

async def set_traffic(self, endpoint_name: str,
traffic_dict: Dict[str, float]) -> UUID:
traffic_dict: Dict[str, float]) -> None:
"""Sets the traffic policy for the specified endpoint."""
async with self.write_lock:
return_uuid = await self._set_traffic(endpoint_name, traffic_dict)
return return_uuid
self._validate_traffic_dict(traffic_dict)
self._set_traffic(endpoint_name, traffic_dict)

async def shadow_traffic(self, endpoint_name: str, backend_tag: BackendTag,
proportion: float) -> UUID:
"""Shadow traffic from the endpoint to the backend."""
async with self.write_lock:
if endpoint_name not in self.endpoint_state.get_endpoints():
raise ValueError("Attempted to shadow traffic from an "
"endpoint '{}' that is not registered."
.format(endpoint_name))

if self.backend_state.get_backend(backend_tag) is None:
raise ValueError(
"Attempted to shadow traffic to a backend '{}' that "
"is not registered.".format(backend_tag))

self.endpoint_state.traffic_policies[endpoint_name].set_shadow(
backend_tag, proportion)

traffic_policy = self.endpoint_state.traffic_policies[
endpoint_name]
logger.info(
"Shadowing '{}' of traffic to endpoint '{}' to backend '{}'.".
format(proportion, endpoint_name, backend_tag))

return_uuid = self._create_event_with_result({
endpoint_name: traffic_policy
})
# NOTE(edoakes): we must write a checkpoint before pushing the
# update to avoid inconsistent state if we crash after pushing the
# update.
self._checkpoint()
self.notify_traffic_policies_changed()
self.set_goal_id(return_uuid)
return return_uuid
self.endpoint_state.shadow_traffic(endpoint_name, backend_tag,
proportion)

# TODO(architkulkarni): add Optional for route after cloudpickle upgrade
async def create_endpoint(self, endpoint: str,
traffic_dict: Dict[str, float], route,
methods) -> UUID:
methods: List[str]) -> UUID:
"""Create a new endpoint with the specified route and methods.
If the route is None, this is a "headless" endpoint that will not
be exposed over HTTP and can only be accessed via a handle.
"""
async with self.write_lock:
# If this is a headless endpoint with no route, key the endpoint
# based on its name.
# TODO(edoakes): we should probably just store routes and endpoints
# separately.
if route is None:
route = endpoint

# TODO(edoakes): move this to client side.
err_prefix = "Cannot create endpoint."
if route in self.endpoint_state.routes:

# Ensures this method is idempotent
if self.endpoint_state.routes[route] == (endpoint, methods):
return

else:
raise ValueError(
"{} Route '{}' is already registered.".format(
err_prefix, route))

if endpoint in self.endpoint_state.get_endpoints():
raise ValueError(
"{} Endpoint '{}' is already registered.".format(
err_prefix, endpoint))
self._validate_traffic_dict(traffic_dict)

logger.info(
"Registering route '{}' to endpoint '{}' with methods '{}'.".
format(route, endpoint, methods))

self.endpoint_state.routes[route] = (endpoint, methods)
self.endpoint_state.create_endpoint(endpoint, route, methods,
TrafficPolicy(traffic_dict))

# NOTE(edoakes): checkpoint is written in self._set_traffic.
return_uuid = await self._set_traffic(endpoint, traffic_dict)
self.notify_route_table_changed()
return return_uuid

async def delete_endpoint(self, endpoint: str) -> UUID:
async def delete_endpoint(self, endpoint: str) -> None:
"""Delete the specified endpoint.
Does not modify any corresponding backends.
"""
logger.info("Deleting endpoint '{}'".format(endpoint))
async with self.write_lock:
# This method must be idempotent. We should validate that the
# specified endpoint exists on the client.
for route, (route_endpoint,
_) in self.endpoint_state.routes.items():
if route_endpoint == endpoint:
route_to_delete = route
break
else:
logger.info("Endpoint '{}' doesn't exist".format(endpoint))
return

# Remove the routing entry.
del self.endpoint_state.routes[route_to_delete]

# Remove the traffic policy entry if it exists.
if endpoint in self.endpoint_state.traffic_policies:
del self.endpoint_state.traffic_policies[endpoint]

return_uuid = self._create_event_with_result({
route_to_delete: None,
endpoint: None
})
# NOTE(edoakes): we must write a checkpoint before pushing the
# updates to the proxies to avoid inconsistent state if we crash
# after pushing the update.
self._checkpoint()
self.notify_route_table_changed()
self.set_goal_id(return_uuid)
return return_uuid
self.endpoint_state.delete_endpoint(endpoint)

async def set_backend_goal(self, backend_tag: BackendTag,
backend_info: BackendInfo,
Expand Down Expand Up @@ -466,19 +372,15 @@ async def delete_backend(self,
return

# Check that the specified backend isn't used by any endpoints.
for endpoint, traffic_policy in self.endpoint_state.\
traffic_policies.items():
if (backend_tag in traffic_policy.traffic_dict
or backend_tag in traffic_policy.shadow_dict):
for endpoint, info in self.endpoint_state.get_endpoints().items():
if (backend_tag in info["traffic"]
or backend_tag in info["shadows"]):
raise ValueError("Backend '{}' is used by endpoint '{}' "
"and cannot be deleted. Please remove "
"the backend from all endpoints and try "
"again.".format(backend_tag, endpoint))

# Scale its replicas down to 0. This will also remove the backend
# from self.backend_state.backends and

# This should be a call to the control loop
# Scale its replicas down to 0.
self.backend_state.scale_backend_replicas(backend_tag, 0,
force_kill)

Expand Down
Loading

0 comments on commit d434ba6

Please sign in to comment.