[Serve] Don't recover from current state checkpoint (ray-project#19998)
simon-mo authored Nov 12, 2021
1 parent ce8504b commit b6bd4fd
Showing 6 changed files with 87 additions and 80 deletions.
34 changes: 30 additions & 4 deletions python/ray/serve/common.py
@@ -1,9 +1,9 @@
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, Optional
from uuid import UUID

import ray
from ray.actor import ActorClass, ActorHandle
from ray.actor import ActorHandle
from ray.serve.config import DeploymentConfig, ReplicaConfig
from ray.serve.autoscaling_policy import AutoscalingPolicy

@@ -25,7 +25,8 @@ def __init__(self,
deployment_config: DeploymentConfig,
replica_config: ReplicaConfig,
start_time_ms: int,
actor_def: Optional[ActorClass] = None,
actor_name: Optional[str] = None,
serialized_deployment_def: Optional[bytes] = None,
version: Optional[str] = None,
deployer_job_id: "Optional[ray._raylet.JobID]" = None,
end_time_ms: Optional[int] = None,
@@ -34,13 +35,38 @@ def __init__(self,
self.replica_config = replica_config
# The time when .deploy() was first called for this deployment.
self.start_time_ms = start_time_ms
self.actor_def = actor_def
self.actor_name = actor_name
self.serialized_deployment_def = serialized_deployment_def
self.version = version
self.deployer_job_id = deployer_job_id
# The time when this deployment was deleted.
self.end_time_ms = end_time_ms
self.autoscaling_policy = autoscaling_policy

# ephemeral state
self._cached_actor_def = None

def __getstate__(self) -> Dict[Any, Any]:
clean_dict = self.__dict__.copy()
del clean_dict["_cached_actor_def"]
return clean_dict

def __setstate__(self, d: Dict[Any, Any]) -> None:
self.__dict__ = d
self._cached_actor_def = None

@property
def actor_def(self):
# Delayed import as replica depends on this file.
from ray.serve.replica import create_replica_wrapper
if self._cached_actor_def is None:
assert self.actor_name is not None
assert self.serialized_deployment_def is not None
self._cached_actor_def = ray.remote(
create_replica_wrapper(self.actor_name,
self.serialized_deployment_def))
return self._cached_actor_def


@dataclass
class ReplicaName:
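Taken together, the common.py hunks replace the eagerly constructed actor_def with a deployment name plus serialized deployment definition; the Ray actor class is rebuilt lazily and excluded from pickled state. A minimal, self-contained sketch of that pattern, using hypothetical names (LazyReplicaInfo is not Serve code, and a plain tuple stands in for the real ray.remote wrapper):

```python
from typing import Any, Dict, Optional


class LazyReplicaInfo:
    """Hypothetical stand-in for DeploymentInfo: it caches a derived object
    that must never end up inside a checkpoint."""

    def __init__(self, actor_name: str, serialized_def: bytes) -> None:
        self.actor_name = actor_name
        self.serialized_def = serialized_def
        # Ephemeral, cluster-specific state: excluded from pickling.
        self._cached_actor_def: Optional[Any] = None

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        del state["_cached_actor_def"]   # never serialize the cache
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__ = state
        self._cached_actor_def = None    # rebuilt on demand after restore

    @property
    def actor_def(self) -> Any:
        if self._cached_actor_def is None:
            # The real property wraps ray.remote(create_replica_wrapper(...));
            # a plain tuple stands in here to keep the sketch runnable.
            self._cached_actor_def = ("actor-class-for", self.actor_name)
        return self._cached_actor_def
```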
6 changes: 2 additions & 4 deletions python/ray/serve/controller.py
@@ -24,7 +24,6 @@
from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY
from ray.serve.endpoint_state import EndpointState
from ray.serve.http_state import HTTPState
from ray.serve.replica import create_replica_wrapper
from ray.serve.storage.checkpoint_path import make_kv_store
from ray.serve.long_poll import LongPollHost
from ray.serve.storage.kv_store import RayInternalKVStore
@@ -320,9 +319,8 @@ def deploy(self,
autoscaling_policy = None

deployment_info = DeploymentInfo(
actor_def=ray.remote(
create_replica_wrapper(
name, replica_config.serialized_deployment_def)),
actor_name=name,
serialized_deployment_def=replica_config.serialized_deployment_def,
version=version,
deployment_config=deployment_config,
replica_config=replica_config,
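The controller.py hunk is the other half of the same change: deploy() now hands DeploymentInfo the deployment name and the already-serialized deployment definition instead of a ready-made Ray actor class, so the resulting object survives a plain-pickle round trip. Continuing the hypothetical LazyReplicaInfo sketch above:

```python
import pickle

# Continues the LazyReplicaInfo sketch above; values are illustrative only.
info = LazyReplicaInfo("echo", b"<cloudpickled deployment def>")
restored = pickle.loads(pickle.dumps(info))   # cache is dropped, then reset
assert restored._cached_actor_def is None     # nothing cluster-specific was saved
assert restored.actor_def == ("actor-class-for", "echo")  # rebuilt on first use
```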
59 changes: 10 additions & 49 deletions python/ray/serve/deployment_state.py
@@ -1,13 +1,14 @@
import math
import json
import pickle
import time
from collections import defaultdict, OrderedDict
from enum import Enum
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import ray
from ray import cloudpickle, ObjectRef
from ray import ObjectRef
from ray.actor import ActorHandle
from ray.serve.async_goal_manager import AsyncGoalManager
from ray.serve.common import (DeploymentInfo, Duration, GoalId, ReplicaTag,
@@ -110,17 +111,6 @@ def __init__(self, actor_name: str, detached: bool, controller_name: str,
# Populated in self.stop().
self._graceful_shutdown_ref: ObjectRef = None

def __get_state__(self) -> Dict[Any, Any]:
clean_dict = self.__dict__.copy()
del clean_dict["_ready_obj_ref"]
del clean_dict["_graceful_shutdown_ref"]
return clean_dict

def __set_state__(self, d: Dict[Any, Any]) -> None:
self.__dict__ = d
self._ready_obj_ref = None
self._graceful_shutdown_ref = None

@property
def replica_tag(self) -> str:
return self._replica_tag
@@ -372,12 +362,6 @@ def __init__(self, controller_name: str, detached: bool,
self._start_time = None
self._prev_slow_startup_warning_time = None

def __get_state__(self) -> Dict[Any, Any]:
return self.__dict__.copy()

def __set_state__(self, d: Dict[Any, Any]) -> None:
self.__dict__ = d

def get_running_replica_info(self) -> RunningReplicaInfo:
return RunningReplicaInfo(
deployment_name=self._deployment_name,
@@ -664,37 +648,15 @@ def get_target_state_checkpoint_data(self):
"""
return (self._target_info, self._target_replicas, self._target_version)

def get_current_state_checkpoint_data(self):
"""
Return deployment's current state specific to the ray cluster it's
running in. Might be lost or re-constructed upon ray cluster failure.
"""
return (self._rollback_info, self._curr_goal,
self._prev_startup_warning,
self._replica_constructor_retry_counter, self._replicas)

def get_checkpoint_data(self):
return (self.get_target_state_checkpoint_data(),
self.get_current_state_checkpoint_data())
return self.get_target_state_checkpoint_data()

def recover_target_state_from_checkpoint(self, target_state_checkpoint):
logger.info("Recovering target state for deployment "
f"{self._name} from checkpoint..")
(self._target_info, self._target_replicas,
self._target_version) = target_state_checkpoint

def recover_current_state_from_checkpoint(self, current_state_checkpoint):
logger.info("Recovering current state for deployment "
f"{self._name} from checkpoint..")
(self._rollback_info, self._curr_goal, self._prev_startup_warning,
self._replica_constructor_retry_counter,
self._replicas) = current_state_checkpoint

if self._curr_goal is not None:
self._goal_manager.create_goal(self._curr_goal)

self._notify_running_replicas_changed()

def recover_current_state_from_replica_actor_names(
self, replica_actor_names: List[str]):
assert (
@@ -1288,23 +1250,19 @@ def _recover_from_checkpoint(self,
checkpoint = self._kv_store.get(CHECKPOINT_KEY)
if checkpoint is not None:
(deployment_state_info,
self._deleted_deployment_metadata) = cloudpickle.loads(checkpoint)
self._deleted_deployment_metadata) = pickle.loads(checkpoint)

for deployment_tag, checkpoint_data in deployment_state_info.items(
):
deployment_state = self._create_deployment_state(
deployment_tag)
(target_state_checkpoint,
current_state_checkpoint) = checkpoint_data

target_state_checkpoint = checkpoint_data
deployment_state.recover_target_state_from_checkpoint(
target_state_checkpoint)
if len(deployment_to_current_replicas[deployment_tag]) > 0:
deployment_state.recover_current_state_from_replica_actor_names( # noqa: E501
deployment_to_current_replicas[deployment_tag])
else:
deployment_state.recover_current_state_from_checkpoint(
current_state_checkpoint)
self._deployment_states[deployment_tag] = deployment_state

def shutdown(self) -> List[GoalId]:
@@ -1342,8 +1300,11 @@ def _save_checkpoint_func(self) -> None:
}
self._kv_store.put(
CHECKPOINT_KEY,
cloudpickle.dumps((deployment_state_info,
self._deleted_deployment_metadata)))
# NOTE(simon): Make sure to use pickle so we don't save any Ray
# object that relies on external state (e.g. GCS). For code objects,
# we explicitly use cloudpickle to serialize them.
pickle.dumps((deployment_state_info,
self._deleted_deployment_metadata)))

def get_running_replica_infos(
self,
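The deployment_state.py changes are the core of the commit: the checkpoint now stores only target state (what each deployment should look like), written with plain pickle so nothing tied to a live Ray cluster can sneak in, while current state is reconstructed from whichever replica actors survived, or simply rebuilt from scratch toward the target. A rough sketch of that flow, with hypothetical function names and a dict standing in for the kv store (none of these identifiers are Serve's):

```python
import pickle
from typing import Any, Dict, List, Optional, Tuple

# (target_info, target_replica_count, target_version), mirroring
# get_target_state_checkpoint_data() above.
TargetState = Tuple[Any, Optional[int], Any]


def save_checkpoint(kv_store: Dict[str, bytes],
                    targets: Dict[str, TargetState],
                    deleted_metadata: Dict[str, Any]) -> None:
    # Plain pickle refuses live handles and code objects, which is the point:
    # deployment defs are already cloudpickled bytes by the time they get here.
    kv_store["checkpoint"] = pickle.dumps((targets, deleted_metadata))


def recover(kv_store: Dict[str, bytes],
            live_replicas: Dict[str, List[str]]) -> Dict[str, dict]:
    blob = kv_store.get("checkpoint")
    if blob is None:
        return {}
    targets, _deleted = pickle.loads(blob)
    recovered = {}
    for name, target_state in targets.items():
        replicas = live_replicas.get(name, [])
        # If replica actors survived the controller crash, adopt them;
        # otherwise start fresh toward the target. No "current state"
        # (goals, retry counters, replica objects) is read from the checkpoint.
        recovered[name] = {"target": target_state, "replicas": replicas}
    return recovered
```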
1 change: 0 additions & 1 deletion python/ray/serve/tests/test_deployment_state.py
@@ -154,7 +154,6 @@ def deployment_info(version: Optional[str] = None,
user_config: Optional[Any] = None,
**config_opts) -> Tuple[DeploymentInfo, DeploymentVersion]:
info = DeploymentInfo(
actor_def=None,
version=version,
start_time_ms=0,
deployment_config=DeploymentConfig(
32 changes: 27 additions & 5 deletions python/ray/serve/tests/test_standalone.py
@@ -520,28 +520,50 @@ def test_local_store_recovery(ray_shutdown):
def hello(_):
return "hello"

def check():
# https://github.com/ray-project/ray/issues/19987
@serve.deployment
def world(_):
return "world"

def check(name):
try:
resp = requests.get("http://localhost:8000/hello")
assert resp.text == "hello"
resp = requests.get(f"http://localhost:8000/{name}")
assert resp.text == name
return True
except Exception:
return False

# https://github.com/ray-project/ray/issues/20159
# https://github.com/ray-project/ray/issues/20158
def clean_up_leaked_processes():
import psutil
for proc in psutil.process_iter():
try:
cmdline = " ".join(proc.cmdline())
if "ray::" in cmdline:
print(f"Kill {proc} {cmdline}")
proc.kill()
except Exception:
pass

def crash():
subprocess.call(["ray", "stop", "--force"])
clean_up_leaked_processes()
ray.shutdown()
serve.shutdown()

serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}")
hello.deploy()
assert check()
world.deploy()
assert check("hello")
assert check("world")
crash()

# Simulate a crash

serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}")
wait_for_condition(check)
wait_for_condition(lambda: check("hello"))
# wait_for_condition(lambda: check("world"))
crash()


35 changes: 18 additions & 17 deletions release/serve_tests/workloads/serve_cluster_fault_tolerance.py
@@ -12,6 +12,7 @@
import requests
import uuid
import os
from pathlib import Path

from serve_test_cluster_utils import setup_local_single_node_cluster

@@ -22,7 +23,7 @@
from ray.serve.utils import logger

# Deployment configs
DEFAULT_NUM_REPLICAS = 4
DEFAULT_NUM_REPLICAS = 2
DEFAULT_MAX_BATCH_SIZE = 16


@@ -49,28 +50,27 @@ def main():
# IS_SMOKE_TEST is set by args of releaser's e2e.py
smoke_test = os.environ.get("IS_SMOKE_TEST", "1")
if smoke_test == "1":
checkpoint_path = "file://checkpoint.db"
path = Path("checkpoint.db")
checkpoint_path = f"file://{path}"
if path.exists():
path.unlink()
else:
checkpoint_path = "s3://serve-nightly-tests/fault-tolerant-test-checkpoint" # noqa: E501

_, cluster = setup_local_single_node_cluster(
1, checkpoint_path=checkpoint_path, namespace=namespace)

# Deploy for the first time
@serve.deployment(name="echo", num_replicas=DEFAULT_NUM_REPLICAS)
class Echo:
def __init__(self):
return True
@serve.deployment(num_replicas=DEFAULT_NUM_REPLICAS)
def hello():
return serve.get_replica_context().deployment

def __call__(self, request):
return "hii"
for name in ["hello", "world"]:
hello.options(name=name).deploy()

Echo.deploy()

# Ensure endpoint is working
for _ in range(5):
response = request_with_retries("/echo/", timeout=3)
assert response.text == "hii"
for _ in range(5):
response = request_with_retries(f"/{name}/", timeout=3)
assert response.text == name

logger.info("Initial deployment successful with working endpoint.")

@@ -87,9 +87,10 @@ def __call__(self, request):
setup_local_single_node_cluster(
1, checkpoint_path=checkpoint_path, namespace=namespace)

for _ in range(5):
response = request_with_retries("/echo/", timeout=3)
assert response.text == "hii"
for name in ["hello", "world"]:
for _ in range(5):
response = request_with_retries(f"/{name}/", timeout=3)
assert response.text == name

logger.info("Deployment recovery from s3 checkpoint is successful "
"with working endpoint.")
