Skip to content

Commit

Permalink
[serve] Add basic REST API to dashboard (ray-project#22257)
Browse files Browse the repository at this point in the history
  • Loading branch information
edoakes authored Feb 15, 2022
1 parent 9c07eab commit f37f35c
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 58 deletions.
62 changes: 10 additions & 52 deletions dashboard/modules/job/job_head.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import aiohttp.web
from aiohttp.web import Request, Response
import dataclasses
from functools import wraps
import logging
from typing import Any, Callable
from typing import Any
import json
import traceback
from dataclasses import dataclass

import ray
import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as dashboard_optional_utils
from ray._private.gcs_utils import use_gcs_for_bootstrap
import ray.dashboard.optional_utils as optional_utils
from ray._private.runtime_env.packaging import package_exists, upload_package_to_gcs
from ray.dashboard.modules.job.common import (
CURRENT_VERSION,
Expand All @@ -30,47 +28,7 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

routes = dashboard_optional_utils.ClassMethodRouteTable

# Ray namespace passed to ray.init() when a job handler has to connect to Ray.
RAY_INTERNAL_JOBS_NAMESPACE = "_ray_internal_jobs"


def _init_ray_and_catch_exceptions(f: Callable) -> Callable:
    """Decorator for handler methods that need a connection to Ray.

    The wrapping coroutine connects to Ray when ``ray.is_initialized()`` is
    false and then awaits ``f``. Any exception raised while connecting or
    while running ``f`` is logged and turned into an HTTP 500 response whose
    body is the formatted traceback.

    Args:
        f: Coroutine method. Its ``self`` must expose ``_dashboard_head``,
            which is read for the GCS or Redis connection details.

    Returns:
        The wrapping coroutine function.
    """

    @wraps(f)
    async def check(self, *args, **kwargs):
        try:
            if not ray.is_initialized():
                try:
                    if use_gcs_for_bootstrap():
                        address = self._dashboard_head.gcs_address
                        redis_pw = None
                    else:
                        ip, port = self._dashboard_head.redis_address
                        redis_pw = self._dashboard_head.redis_password
                        address = f"{ip}:{port}"
                    # Only the address is logged. The Redis password is a
                    # credential and must not be written to log files.
                    logger.info(f"Connecting to ray with address={address}")
                    ray.init(
                        address=address,
                        namespace=RAY_INTERNAL_JOBS_NAMESPACE,
                        _redis_password=redis_pw,
                    )
                except Exception as e:
                    # Call ray.shutdown() before handing the failure to the
                    # outer handler, which converts it into a 500 response.
                    ray.shutdown()
                    raise e from None

            return await f(self, *args, **kwargs)
        except Exception as e:
            logger.exception(f"Unexpected error in handler: {e}")
            return Response(
                text=traceback.format_exc(),
                status=aiohttp.web.HTTPInternalServerError.status_code,
            )

    return check
routes = optional_utils.ClassMethodRouteTable


class JobHead(dashboard_utils.DashboardHeadModule):
Expand Down Expand Up @@ -113,7 +71,7 @@ async def get_version(self, req: Request) -> Response:
)

@routes.get("/api/packages/{protocol}/{package_name}")
@_init_ray_and_catch_exceptions
@optional_utils.init_ray_and_catch_exceptions(connect_to_serve=False)
async def get_package(self, req: Request) -> Response:
package_uri = http_uri_components_to_uri(
protocol=req.match_info["protocol"],
Expand All @@ -129,7 +87,7 @@ async def get_package(self, req: Request) -> Response:
return Response()

@routes.put("/api/packages/{protocol}/{package_name}")
@_init_ray_and_catch_exceptions
@optional_utils.init_ray_and_catch_exceptions(connect_to_serve=False)
async def upload_package(self, req: Request):
package_uri = http_uri_components_to_uri(
protocol=req.match_info["protocol"],
Expand All @@ -147,7 +105,7 @@ async def upload_package(self, req: Request):
return Response(status=aiohttp.web.HTTPOk.status_code)

@routes.post("/api/jobs/")
@_init_ray_and_catch_exceptions
@optional_utils.init_ray_and_catch_exceptions(connect_to_serve=False)
async def submit_job(self, req: Request) -> Response:
result = await self._parse_and_validate_request(req, JobSubmitRequest)
# Request parsing failed, returned with Response object.
Expand Down Expand Up @@ -183,7 +141,7 @@ async def submit_job(self, req: Request) -> Response:
)

@routes.post("/api/jobs/{job_id}/stop")
@_init_ray_and_catch_exceptions
@optional_utils.init_ray_and_catch_exceptions(connect_to_serve=False)
async def stop_job(self, req: Request) -> Response:
job_id = req.match_info["job_id"]
if not self.job_exists(job_id):
Expand All @@ -206,7 +164,7 @@ async def stop_job(self, req: Request) -> Response:
)

@routes.get("/api/jobs/{job_id}")
@_init_ray_and_catch_exceptions
@optional_utils.init_ray_and_catch_exceptions(connect_to_serve=False)
async def get_job_status(self, req: Request) -> Response:
job_id = req.match_info["job_id"]
if not self.job_exists(job_id):
Expand All @@ -222,7 +180,7 @@ async def get_job_status(self, req: Request) -> Response:
)

@routes.get("/api/jobs/{job_id}/logs")
@_init_ray_and_catch_exceptions
@optional_utils.init_ray_and_catch_exceptions(connect_to_serve=False)
async def get_job_logs(self, req: Request) -> Response:
job_id = req.match_info["job_id"]
if not self.job_exists(job_id):
Expand All @@ -237,7 +195,7 @@ async def get_job_logs(self, req: Request) -> Response:
)

@routes.get("/api/jobs/{job_id}/logs/tail")
@_init_ray_and_catch_exceptions
@optional_utils.init_ray_and_catch_exceptions(connect_to_serve=False)
async def tail_job_logs(self, req: Request) -> Response:
job_id = req.match_info["job_id"]
if not self.job_exists(job_id):
Expand Down
Empty file.
76 changes: 76 additions & 0 deletions dashboard/modules/serve/serve_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import aiohttp.web
from aiohttp.web import Request, Response
import json
import logging
from typing import Optional

import ray.dashboard.utils as dashboard_utils
import ray.dashboard.optional_utils as optional_utils

from ray import serve

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

routes = optional_utils.ClassMethodRouteTable


class ServeHead(dashboard_utils.DashboardHeadModule):
    """Dashboard head module exposing REST endpoints for Serve deployments.

    Every route is wrapped with ``init_ray_and_catch_exceptions`` using
    ``connect_to_serve=True``, so that wrapper runs before each handler body.
    """

    def __init__(self, dashboard_head):
        super().__init__(dashboard_head)

    def _get_deployment_by_name(self, name: str) -> Optional[serve.api.Deployment]:
        """Return the deployment called ``name``, or ``None`` on ``KeyError``."""
        try:
            return serve.get_deployment(name)
        except KeyError:
            return None

    @routes.get("/api/serve/deployments/")
    @optional_utils.init_ray_and_catch_exceptions(connect_to_serve=True)
    async def get_all_deployments(self, req: Request) -> Response:
        """Return a JSON object mapping each deployment name to ``str(deployment)``."""
        dict_response = {
            name: str(deployment)
            for name, deployment in serve.list_deployments().items()
        }

        return Response(text=json.dumps(dict_response), content_type="application/json")

    @routes.get("/api/serve/deployments/{name}")
    @optional_utils.init_ray_and_catch_exceptions(connect_to_serve=True)
    async def get_single_deployment(self, req: Request) -> Response:
        """Return ``str(deployment)`` as JSON, or 404 if ``name`` is unknown."""
        name = req.match_info["name"]
        deployment = self._get_deployment_by_name(name)
        if deployment is None:
            return Response(
                text=f"Deployment {name} does not exist.",
                status=aiohttp.web.HTTPNotFound.status_code,
            )
        return Response(
            text=json.dumps(str(deployment)), content_type="application/json"
        )

    @routes.delete("/api/serve/deployments/")
    @optional_utils.init_ray_and_catch_exceptions(connect_to_serve=True)
    async def delete_all_deployments(self, req: Request) -> Response:
        """Call ``serve.shutdown()`` and return an empty response."""
        serve.shutdown()
        # An aiohttp handler must return a response object. Falling off the
        # end would return None and be reported as a server error even though
        # the shutdown succeeded. This mirrors delete_single_deployment.
        return Response()

    @routes.delete("/api/serve/deployments/{name}")
    @optional_utils.init_ray_and_catch_exceptions(connect_to_serve=True)
    async def delete_single_deployment(self, req: Request) -> Response:
        """Delete the deployment called ``name``, or return 404 if it is unknown."""
        name = req.match_info["name"]
        deployment = self._get_deployment_by_name(name)
        if deployment is None:
            return Response(
                text=f"Deployment {name} does not exist.",
                status=aiohttp.web.HTTPNotFound.status_code,
            )

        deployment.delete()
        return Response()

    async def run(self, server):
        # This module does no work in run(); it only registers routes.
        pass

    @staticmethod
    def is_minimal_module():
        return False
57 changes: 56 additions & 1 deletion dashboard/optional_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Optional utils module contains utility methods
that require optional dependencies.
"""
from aiohttp.web import Response
import asyncio
import collections
import functools
Expand All @@ -12,10 +13,14 @@
import time
import traceback
from collections import namedtuple
from typing import Any
from typing import Any, Callable

import ray
import ray.dashboard.consts as dashboard_consts
from ray.ray_constants import env_bool
from ray._private.gcs_utils import use_gcs_for_bootstrap

from ray import serve

try:
create_task = asyncio.create_task
Expand All @@ -30,6 +35,8 @@

logger = logging.getLogger(__name__)

# Ray namespace passed to ray.init() by init_ray_and_catch_exceptions.
RAY_INTERNAL_DASHBOARD_NAMESPACE = "_ray_internal_dashboard"


class ClassMethodRouteTable:
"""A helper class to bind http route to class method."""
Expand Down Expand Up @@ -242,3 +249,51 @@ def _update_cache(task):
return _wrapper(target_func)
else:
return _wrapper


def init_ray_and_catch_exceptions(connect_to_serve: bool = False) -> Callable:
    """Decorator to be used on methods that require being connected to Ray.

    The wrapping coroutine connects to Ray when ``ray.is_initialized()`` is
    false, optionally calls ``serve.start(detached=True)``, and then awaits
    the wrapped method. Any exception raised along the way is logged and
    turned into an HTTP 500 response whose body is the formatted traceback.

    Args:
        connect_to_serve: When true, ``serve.start(detached=True)`` is called
            before every invocation of the wrapped method.

    Returns:
        A decorator for coroutine methods. The method's ``self`` must expose
        ``_dashboard_head``, which is read for the GCS or Redis connection
        details.
    """

    def decorator_factory(f: Callable) -> Callable:
        @functools.wraps(f)
        async def decorator(self, *args, **kwargs):
            try:
                if not ray.is_initialized():
                    try:
                        if use_gcs_for_bootstrap():
                            address = self._dashboard_head.gcs_address
                            redis_pw = None
                        else:
                            ip, port = self._dashboard_head.redis_address
                            redis_pw = self._dashboard_head.redis_password
                            address = f"{ip}:{port}"
                        # Only the address is logged. The Redis password is a
                        # credential and must not be written to log files.
                        logger.info(f"Connecting to ray with address={address}")
                        ray.init(
                            address=address,
                            namespace=RAY_INTERNAL_DASHBOARD_NAMESPACE,
                            _redis_password=redis_pw,
                        )
                    except Exception as e:
                        # Call ray.shutdown() before handing the failure to
                        # the outer handler, which converts it into a 500.
                        ray.shutdown()
                        raise e from None

                if connect_to_serve:
                    # TODO(edoakes): this should probably run in the `serve`
                    # namespace.
                    serve.start(detached=True)

                return await f(self, *args, **kwargs)
            except Exception as e:
                logger.exception(f"Unexpected error in handler: {e}")
                return Response(
                    text=traceback.format_exc(),
                    status=aiohttp.web.HTTPInternalServerError.status_code,
                )

        return decorator

    return decorator_factory
8 changes: 4 additions & 4 deletions dashboard/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import abc
from abc import ABCMeta, abstractmethod
import asyncio
from base64 import b64decode
from collections import namedtuple
from collections.abc import MutableMapping, Mapping, Sequence
import datetime
import functools
import importlib
import json
import logging
import pkgutil
import socket
from abc import ABCMeta, abstractmethod
from base64 import b64decode
from collections import namedtuple
from collections.abc import MutableMapping, Mapping, Sequence

import aioredis # noqa: F401
import aiosignal # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def start(
http_options = HTTPOptions()

controller = ServeController.options(
num_cpus=(1 if dedicated_cpu else 0),
num_cpus=1 if dedicated_cpu else 0,
name=controller_name,
lifetime="detached" if detached else None,
max_restarts=-1,
Expand Down

0 comments on commit f37f35c

Please sign in to comment.