Skip to content

Commit

Permalink
[Serve] Refactor RequestMetadata and Query objects (#10483)
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-mo authored Sep 2, 2020
1 parent 3b10b67 commit 65f17f2
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 136 deletions.
7 changes: 4 additions & 3 deletions python/ray/serve/backend_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from collections.abc import Iterable
from collections import defaultdict
from itertools import groupby
from operator import attrgetter
from typing import Union, List, Any, Callable, Type
import time

Expand Down Expand Up @@ -185,7 +184,7 @@ def __init__(self, backend_tag: str, replica_tag: str, _callable: Callable,
asyncio.get_event_loop().create_task(self.main_loop())

def get_runner_method(self, request_item: Query) -> Callable:
method_name = request_item.call_method
method_name = request_item.metadata.call_method
if not hasattr(self.callable, method_name):
raise RayServeException("Backend doesn't have method {} "
"which is specified in the request. "
Expand Down Expand Up @@ -325,7 +324,9 @@ async def main_loop(self) -> None:
all_evaluated_futures = [evaluated]
chain_future(evaluated, query.async_future)
else:
get_call_method = attrgetter("call_method")
get_call_method = (
lambda query: query.metadata.call_method # noqa: E731
)
sorted_batch = sorted(batch, key=get_call_method)
for _, group in groupby(sorted_batch, key=get_call_method):
group = list(group)
Expand Down
6 changes: 3 additions & 3 deletions python/ray/serve/endpoint_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def flush(self, endpoint_queue, backend_queues):
assigned_backends = set()
while len(endpoint_queue) > 0:
query = endpoint_queue.pop()
if query.shard_key is None:
if query.metadata.shard_key is None:
rstate = np.random
else:
sha256_seed = sha256(query.shard_key.encode("utf-8"))
sha256_seed = sha256(query.metadata.shard_key.encode("utf-8"))
seed = np.frombuffer(sha256_seed.digest(), dtype=np.uint32)
# Note(simon): This constructor takes 100+us, maybe cache this?
rstate = np.random.RandomState(seed)
Expand All @@ -93,7 +93,7 @@ def flush(self, endpoint_queue, backend_queues):
if len(shadow_backends) > 0:
shadow_query = copy.copy(query)
shadow_query.async_future = None
shadow_query.is_shadow_query = True
shadow_query.metadata.is_shadow_query = True
for shadow_backend in shadow_backends:
assigned_backends.add(shadow_backend)
backend_queues[shadow_backend].appendleft(shadow_query)
Expand Down
17 changes: 2 additions & 15 deletions python/ray/serve/examples/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,15 @@
Example service that prints out http context.
"""

import json
import time

from pygments import formatters, highlight, lexers

import requests

from ray import serve


def pformat_color_json(d):
    """Serialize *d* to pretty-printed JSON and colorize it for terminal output."""
    as_text = json.dumps(d, sort_keys=True, indent=4)
    # Pygments renders the JSON with ANSI color codes suitable for a terminal.
    return highlight(as_text, lexers.JsonLexer(),
                     formatters.TerminalFormatter())


def echo(flask_request):
return "hello " + flask_request.args.get("name", "serve!")
return ["hello " + flask_request.args.get("name", "serve!")]


serve.init()
Expand All @@ -33,7 +20,7 @@ def echo(flask_request):

while True:
resp = requests.get("http://127.0.0.1:8000/echo").json()
print(pformat_color_json(resp))
print(resp)

print("...Sleeping for 2 seconds...")
time.sleep(2)
55 changes: 21 additions & 34 deletions python/ray/serve/handle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Optional

import ray
from ray import serve
from ray.serve.context import TaskContext
from ray.serve.exceptions import RayServeException
from ray.serve.request_params import RequestMetadata


Expand All @@ -16,7 +17,6 @@ class RayServeHandle:
>>> handle
RayServeHandle(
Endpoint="my_endpoint",
URL="...",
Traffic=...
)
>>> handle.remote(my_request_content)
Expand All @@ -31,61 +31,48 @@ def __init__(
self,
router_handle,
endpoint_name,
http_method=None,
method_name=None,
shard_key=None,
):
self.router_handle = router_handle
self.endpoint_name = endpoint_name
self.http_method = http_method
self.method_name = method_name
self.shard_key = shard_key

def remote(self, *args, **kwargs):
if len(args) != 0:
raise RayServeException(
if len(args) > 0:
raise ValueError(
"handle.remote must be invoked with keyword arguments.")

method_name = self.method_name
if method_name is None:
method_name = "__call__"

# create RequestMetadata instance
request_in_object = RequestMetadata(
request_metadata = RequestMetadata(
self.endpoint_name,
TaskContext.Python,
call_method=method_name,
http_method=self.http_method or "GET",
call_method=self.method_name or "__call__",
shard_key=self.shard_key,
)
return self.router_handle.enqueue_request.remote(
request_in_object, **kwargs)

def options(self, method_name=None, shard_key=None):

# Don't override existing method
if method_name is None and self.method_name is not None:
method_name = self.method_name

if shard_key is None and self.shard_key is not None:
shard_key = self.shard_key
request_metadata, **kwargs)

def options(self,
method_name: Optional[str] = None,
http_method: Optional[str] = None,
shard_key: Optional[str] = None):
return RayServeHandle(
self.router_handle,
self.endpoint_name,
method_name=method_name,
shard_key=shard_key,
# Don't override existing method
http_method=self.http_method or http_method,
method_name=self.method_name or method_name,
shard_key=self.shard_key or shard_key,
)

def get_traffic_policy(self):
def _get_traffic_policy(self):
controller = serve.api._get_controller()
return ray.get(
controller.get_traffic_policy.remote(self.endpoint_name))

def __repr__(self):
return """
RayServeHandle(
Endpoint="{endpoint_name}",
Traffic={traffic_policy}
)
""".format(
endpoint_name=self.endpoint_name,
traffic_policy=self.get_traffic_policy(),
)
return (f"RayServeHandle(Endpoint='{self.endpoint_name}', "
f"Traffic={self._get_traffic_policy()})")
1 change: 1 addition & 0 deletions python/ray/serve/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ async def __call__(self, scope, receive, send):
request_metadata = RequestMetadata(
endpoint_name,
TaskContext.Web,
http_method=scope["method"].upper(),
call_method=headers.get("X-SERVE-CALL-METHOD".lower(), "__call__"),
shard_key=headers.get("X-SERVE-SHARD-KEY".lower(), None),
)
Expand Down
37 changes: 11 additions & 26 deletions python/ray/serve/request_params.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,15 @@
import ray.cloudpickle as pickle
from dataclasses import dataclass
from typing import Optional

from ray.serve.context import TaskContext

class RequestMetadata:
"""
Request arguments required for enqueuing a request to the endpoint queue.
Args:
endpoint(str): A registered endpoint.
request_context(TaskContext): Context of a request.
"""

def __init__(self,
endpoint,
request_context,
call_method="__call__",
shard_key=None):

self.endpoint = endpoint
self.request_context = request_context
self.call_method = call_method
self.shard_key = shard_key

def ray_serialize(self):
return pickle.dumps(self.__dict__)
@dataclass
class RequestMetadata:
endpoint: str
request_context: TaskContext

@staticmethod
def ray_deserialize(value):
kwargs = pickle.loads(value)
return RequestMetadata(**kwargs)
call_method: str = "__call__"
shard_key: Optional[str] = None
http_method: str = "GET"
is_shadow_query: bool = False
57 changes: 13 additions & 44 deletions python/ray/serve/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,31 @@
import copy
from collections import defaultdict, deque
import time
from typing import DefaultDict, List
from typing import DefaultDict, List, Dict, Any, Optional
import pickle
from dataclasses import dataclass

from ray.exceptions import RayTaskError

import ray
from ray import serve
from ray.experimental import metrics
from ray.serve.context import TaskContext
from ray.serve.endpoint_policy import RandomEndpointPolicy
from ray.serve.request_params import RequestMetadata
from ray.serve.utils import logger, chain_future

REPORT_QUEUE_LENGTH_PERIOD_S = 1.0


@dataclass
class Query:
def __init__(
self,
request_args,
request_kwargs,
request_context,
call_method="__call__",
shard_key=None,
async_future=None,
is_shadow_query=False,
):
self.request_args = request_args
self.request_kwargs = request_kwargs
self.request_context = request_context
args: List[Any]
kwargs: Dict[Any, Any]
context: TaskContext

self.async_future = async_future

self.call_method = call_method
self.shard_key = shard_key
self.is_shadow_query = is_shadow_query
metadata: RequestMetadata
async_future: Optional[asyncio.Future] = None

def __reduce__(self):
return type(self).ray_deserialize, (self.ray_serialize(), )
Expand All @@ -56,27 +47,6 @@ def ray_deserialize(value):
return Query(**kwargs)


def _make_future_unwrapper(client_futures: List[asyncio.Future],
host_future: asyncio.Future):
"""Distribute the result of host_future to each of client_future"""
for client_future in client_futures:
# Keep a reference to host future so the host future won't get
# garbage collected.
client_future.host_ref = host_future

def unwrap_future(_):
result = host_future.result()

if isinstance(result, list):
for client_future, result_item in zip(client_futures, result):
client_future.set_result(result_item)
else: # Result is an exception.
for client_future in client_futures:
client_future.set_result(result)

return unwrap_future


class Router:
"""A router that routes request to available workers."""

Expand Down Expand Up @@ -175,8 +145,7 @@ async def enqueue_request(self, request_meta, *request_args,
request_args,
request_kwargs,
request_context,
call_method=request_meta.call_method,
shard_key=request_meta.shard_key,
metadata=request_meta,
async_future=asyncio.get_event_loop().create_future())
async with self.flush_lock:
self.endpoint_queues[endpoint].appendleft(query)
Expand Down Expand Up @@ -301,7 +270,7 @@ async def _do_query(self, backend, backend_replica_tag, req):
worker = self.replicas[backend_replica_tag]
try:
object_ref = worker.handle_request.remote(req.ray_serialize())
if req.is_shadow_query:
if req.metadata.is_shadow_query:
# No need to actually get the result, but we do need to wait
# until the call completes to mark the worker idle.
await asyncio.wait([object_ref])
Expand Down Expand Up @@ -351,7 +320,7 @@ def _assign_query_to_worker(self, backend, buffer_queue, worker_queue):
self._do_query(backend, backend_replica_tag, request))

# For shadow queries, just ignore the result.
if not request.is_shadow_query:
if not request.metadata.is_shadow_query:
chain_future(future, request.async_future)

worker_queue.appendleft(backend_replica_tag)
Expand Down
14 changes: 7 additions & 7 deletions python/ray/serve/tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ async def test_single_prod_cons_queue(serve_instance, task_runner_mock_actor):

# Make sure it's the right request
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.request_args[0] == 1
assert got_work.request_kwargs == {}
assert got_work.args[0] == 1
assert got_work.kwargs == {}


async def test_alter_backend(serve_instance, task_runner_mock_actor):
Expand All @@ -74,14 +74,14 @@ async def test_alter_backend(serve_instance, task_runner_mock_actor):
task_runner_mock_actor)
await q.enqueue_request.remote(RequestMetadata("svc", None), 1)
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.request_args[0] == 1
assert got_work.args[0] == 1

await q.set_traffic.remote("svc", TrafficPolicy({"backend-alter-2": 1}))
await q.add_new_worker.remote("backend-alter-2", "replica-1",
task_runner_mock_actor)
await q.enqueue_request.remote(RequestMetadata("svc", None), 2)
got_work = await task_runner_mock_actor.get_recent_call.remote()
assert got_work.request_args[0] == 2
assert got_work.args[0] == 2


async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
Expand All @@ -106,7 +106,7 @@ async def test_split_traffic_random(serve_instance, task_runner_mock_actor):
await runner.get_recent_call.remote()
for runner in (runner_1, runner_2)
]
assert [g.request_args[0] for g in got_work] == [1, 1]
assert [g.args[0] for g in got_work] == [1, 1]


async def test_queue_remove_replicas(serve_instance):
Expand Down Expand Up @@ -146,7 +146,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
for i, runner in enumerate(runners):
calls = await runner.get_all_calls.remote()
for call in calls:
runner_shard_keys[i].add(call.request_args[0])
runner_shard_keys[i].add(call.args[0])
await runner.clear_calls.remote()

# Send queries with the same shard keys a second time.
Expand All @@ -158,7 +158,7 @@ async def test_shard_key(serve_instance, task_runner_mock_actor):
for i, runner in enumerate(runners):
calls = await runner.get_all_calls.remote()
for call in calls:
assert call.request_args[0] in runner_shard_keys[i]
assert call.args[0] in runner_shard_keys[i]


async def test_router_use_max_concurrency(serve_instance):
Expand Down
Loading

0 comments on commit 65f17f2

Please sign in to comment.