From b6bd4fd5f319e6a66f2efd43308cc810ba16a6c7 Mon Sep 17 00:00:00 2001
From: Simon Mo
Date: Fri, 12 Nov 2021 09:02:27 -0800
Subject: [PATCH] [Serve] Don't recover from current state checkpoint (#19998)

---
 python/ray/serve/common.py                    | 34 +++++++++--
 python/ray/serve/controller.py                |  6 +-
 python/ray/serve/deployment_state.py          | 59 ++++-----------
 .../ray/serve/tests/test_deployment_state.py  |  1 -
 python/ray/serve/tests/test_standalone.py     | 32 ++++++++--
 .../serve_cluster_fault_tolerance.py          | 35 +++++------
 6 files changed, 87 insertions(+), 80 deletions(-)

diff --git a/python/ray/serve/common.py b/python/ray/serve/common.py
index d43c384f904f..8be7e47d97e5 100644
--- a/python/ray/serve/common.py
+++ b/python/ray/serve/common.py
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Dict, Optional
 from uuid import UUID
 
 import ray
-from ray.actor import ActorClass, ActorHandle
+from ray.actor import ActorHandle
 from ray.serve.config import DeploymentConfig, ReplicaConfig
 from ray.serve.autoscaling_policy import AutoscalingPolicy
 
@@ -25,7 +25,8 @@ def __init__(self,
                  deployment_config: DeploymentConfig,
                  replica_config: ReplicaConfig,
                  start_time_ms: int,
-                 actor_def: Optional[ActorClass] = None,
+                 actor_name: Optional[str] = None,
+                 serialized_deployment_def: Optional[bytes] = None,
                  version: Optional[str] = None,
                  deployer_job_id: "Optional[ray._raylet.JobID]" = None,
                  end_time_ms: Optional[int] = None,
@@ -34,13 +35,38 @@ def __init__(self,
         self.replica_config = replica_config
         # The time when .deploy() was first called for this deployment.
         self.start_time_ms = start_time_ms
-        self.actor_def = actor_def
+        self.actor_name = actor_name
+        self.serialized_deployment_def = serialized_deployment_def
         self.version = version
         self.deployer_job_id = deployer_job_id
         # The time when this deployment was deleted.
         self.end_time_ms = end_time_ms
         self.autoscaling_policy = autoscaling_policy
 
+        # Ephemeral state, not persisted in checkpoints.
+        self._cached_actor_def = None
+
+    def __getstate__(self) -> Dict[Any, Any]:
+        clean_dict = self.__dict__.copy()
+        del clean_dict["_cached_actor_def"]
+        return clean_dict
+
+    def __setstate__(self, d: Dict[Any, Any]) -> None:
+        self.__dict__ = d
+        self._cached_actor_def = None
+
+    @property
+    def actor_def(self):
+        # Delayed import as replica depends on this file.
+        from ray.serve.replica import create_replica_wrapper
+        if self._cached_actor_def is None:
+            assert self.actor_name is not None
+            assert self.serialized_deployment_def is not None
+            self._cached_actor_def = ray.remote(
+                create_replica_wrapper(self.actor_name,
+                                       self.serialized_deployment_def))
+        return self._cached_actor_def
+
 
 @dataclass
 class ReplicaName:
diff --git a/python/ray/serve/controller.py b/python/ray/serve/controller.py
index 6c31655efe61..af5b3d8856b3 100644
--- a/python/ray/serve/controller.py
+++ b/python/ray/serve/controller.py
@@ -24,7 +24,6 @@
 from ray.serve.constants import CONTROL_LOOP_PERIOD_S, SERVE_ROOT_URL_ENV_KEY
 from ray.serve.endpoint_state import EndpointState
 from ray.serve.http_state import HTTPState
-from ray.serve.replica import create_replica_wrapper
 from ray.serve.storage.checkpoint_path import make_kv_store
 from ray.serve.long_poll import LongPollHost
 from ray.serve.storage.kv_store import RayInternalKVStore
@@ -320,9 +319,8 @@ def deploy(self,
             autoscaling_policy = None
 
         deployment_info = DeploymentInfo(
-            actor_def=ray.remote(
-                create_replica_wrapper(
-                    name, replica_config.serialized_deployment_def)),
+            actor_name=name,
+            serialized_deployment_def=replica_config.serialized_deployment_def,
             version=version,
             deployment_config=deployment_config,
             replica_config=replica_config,
diff --git a/python/ray/serve/deployment_state.py b/python/ray/serve/deployment_state.py
index 40d34642bd7a..ff25d93cd8c5 100644
--- a/python/ray/serve/deployment_state.py
+++ b/python/ray/serve/deployment_state.py
@@ -1,5 +1,6 @@
 import math
 import json
+import pickle
 import time
 from collections import defaultdict, OrderedDict
 from enum import Enum
@@ -7,7 +8,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import ray
-from ray import cloudpickle, ObjectRef
+from ray import ObjectRef
 from ray.actor import ActorHandle
 from ray.serve.async_goal_manager import AsyncGoalManager
 from ray.serve.common import (DeploymentInfo, Duration, GoalId, ReplicaTag,
@@ -110,17 +111,6 @@ def __init__(self, actor_name: str, detached: bool, controller_name: str,
         # Populated in self.stop().
         self._graceful_shutdown_ref: ObjectRef = None
 
-    def __get_state__(self) -> Dict[Any, Any]:
-        clean_dict = self.__dict__.copy()
-        del clean_dict["_ready_obj_ref"]
-        del clean_dict["_graceful_shutdown_ref"]
-        return clean_dict
-
-    def __set_state__(self, d: Dict[Any, Any]) -> None:
-        self.__dict__ = d
-        self._ready_obj_ref = None
-        self._graceful_shutdown_ref = None
-
     @property
     def replica_tag(self) -> str:
         return self._replica_tag
@@ -372,12 +362,6 @@ def __init__(self, controller_name: str, detached: bool,
         self._start_time = None
         self._prev_slow_startup_warning_time = None
 
-    def __get_state__(self) -> Dict[Any, Any]:
-        return self.__dict__.copy()
-
-    def __set_state__(self, d: Dict[Any, Any]) -> None:
-        self.__dict__ = d
-
     def get_running_replica_info(self) -> RunningReplicaInfo:
         return RunningReplicaInfo(
             deployment_name=self._deployment_name,
@@ -664,18 +648,8 @@ def get_target_state_checkpoint_data(self):
         """
         return (self._target_info, self._target_replicas,
                 self._target_version)
 
-    def get_current_state_checkpoint_data(self):
-        """
-        Return deployment's current state specific to the ray cluster it's
-        running in. Might be lost or re-constructed upon ray cluster failure.
-        """
-        return (self._rollback_info, self._curr_goal,
-                self._prev_startup_warning,
-                self._replica_constructor_retry_counter, self._replicas)
-
     def get_checkpoint_data(self):
-        return (self.get_target_state_checkpoint_data(),
-                self.get_current_state_checkpoint_data())
+        return self.get_target_state_checkpoint_data()
 
     def recover_target_state_from_checkpoint(self, target_state_checkpoint):
         logger.info("Recovering target state for deployment "
@@ -683,18 +657,6 @@ def recover_target_state_from_checkpoint(self, target_state_checkpoint):
         (self._target_info, self._target_replicas,
          self._target_version) = target_state_checkpoint
 
-    def recover_current_state_from_checkpoint(self, current_state_checkpoint):
-        logger.info("Recovering current state for deployment "
-                    f"{self._name} from checkpoint..")
-        (self._rollback_info, self._curr_goal, self._prev_startup_warning,
-         self._replica_constructor_retry_counter,
-         self._replicas) = current_state_checkpoint
-
-        if self._curr_goal is not None:
-            self._goal_manager.create_goal(self._curr_goal)
-
-        self._notify_running_replicas_changed()
-
     def recover_current_state_from_replica_actor_names(
             self, replica_actor_names: List[str]):
         assert (
@@ -1288,23 +1250,19 @@ def _recover_from_checkpoint(self,
         checkpoint = self._kv_store.get(CHECKPOINT_KEY)
         if checkpoint is not None:
             (deployment_state_info,
-             self._deleted_deployment_metadata) = cloudpickle.loads(checkpoint)
+             self._deleted_deployment_metadata) = pickle.loads(checkpoint)
 
             for deployment_tag, checkpoint_data in deployment_state_info.items(
             ):
                 deployment_state = self._create_deployment_state(
                     deployment_tag)
-                (target_state_checkpoint,
-                 current_state_checkpoint) = checkpoint_data
+                target_state_checkpoint = checkpoint_data
                 deployment_state.recover_target_state_from_checkpoint(
                     target_state_checkpoint)
 
                 if len(deployment_to_current_replicas[deployment_tag]) > 0:
                     deployment_state.recover_current_state_from_replica_actor_names(  # noqa: E501
                         deployment_to_current_replicas[deployment_tag])
-                else:
-                    deployment_state.recover_current_state_from_checkpoint(
-                        current_state_checkpoint)
 
                 self._deployment_states[deployment_tag] = deployment_state
 
     def shutdown(self) -> List[GoalId]:
@@ -1342,8 +1300,11 @@ def _save_checkpoint_func(self) -> None:
         }
         self._kv_store.put(
             CHECKPOINT_KEY,
-            cloudpickle.dumps((deployment_state_info,
-                               self._deleted_deployment_metadata)))
+            # NOTE(simon): Make sure to use pickle so we don't save any ray
+            # object that relies on external state (e.g. gcs). For code
+            # objects, we are explicitly using cloudpickle to serialize them.
+            pickle.dumps((deployment_state_info,
+                          self._deleted_deployment_metadata)))
 
     def get_running_replica_infos(
             self,
diff --git a/python/ray/serve/tests/test_deployment_state.py b/python/ray/serve/tests/test_deployment_state.py
index 404621e53720..261f5a0581ca 100644
--- a/python/ray/serve/tests/test_deployment_state.py
+++ b/python/ray/serve/tests/test_deployment_state.py
@@ -154,7 +154,6 @@ def deployment_info(version: Optional[str] = None,
                     user_config: Optional[Any] = None,
                     **config_opts) -> Tuple[DeploymentInfo, DeploymentVersion]:
     info = DeploymentInfo(
-        actor_def=None,
         version=version,
         start_time_ms=0,
         deployment_config=DeploymentConfig(
diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py
index 05963614e62c..68fbc2104c81 100644
--- a/python/ray/serve/tests/test_standalone.py
+++ b/python/ray/serve/tests/test_standalone.py
@@ -520,28 +520,50 @@ def test_local_store_recovery(ray_shutdown):
     def hello(_):
         return "hello"
 
-    def check():
+    # https://github.com/ray-project/ray/issues/19987
+    @serve.deployment
+    def world(_):
+        return "world"
+
+    def check(name):
         try:
-            resp = requests.get("http://localhost:8000/hello")
-            assert resp.text == "hello"
+            resp = requests.get(f"http://localhost:8000/{name}")
+            assert resp.text == name
             return True
         except Exception:
             return False
 
+    # https://github.com/ray-project/ray/issues/20159
+    # https://github.com/ray-project/ray/issues/20158
+    def clean_up_leaked_processes():
+        import psutil
+        for proc in psutil.process_iter():
+            try:
+                cmdline = " ".join(proc.cmdline())
+                if "ray::" in cmdline:
+                    print(f"Kill {proc} {cmdline}")
+                    proc.kill()
+            except Exception:
+                pass
+
     def crash():
         subprocess.call(["ray", "stop", "--force"])
+        clean_up_leaked_processes()
         ray.shutdown()
         serve.shutdown()
 
     serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}")
     hello.deploy()
-    assert check()
+    world.deploy()
+    assert check("hello")
+    assert check("world")
     crash()
 
     # Simulate a crash
 
     serve.start(detached=True, _checkpoint_path=f"file://{tmp_path}")
-    wait_for_condition(check)
+    wait_for_condition(lambda: check("hello"))
+    # wait_for_condition(lambda: check("world"))
     crash()
diff --git a/release/serve_tests/workloads/serve_cluster_fault_tolerance.py b/release/serve_tests/workloads/serve_cluster_fault_tolerance.py
index 431c78b9c5df..696741ae3c61 100644
--- a/release/serve_tests/workloads/serve_cluster_fault_tolerance.py
+++ b/release/serve_tests/workloads/serve_cluster_fault_tolerance.py
@@ -12,6 +12,7 @@
 import requests
 import uuid
 import os
+from pathlib import Path
 
 from serve_test_cluster_utils import setup_local_single_node_cluster
 
@@ -22,7 +23,7 @@
 from ray.serve.utils import logger
 
 # Deployment configs
-DEFAULT_NUM_REPLICAS = 4
+DEFAULT_NUM_REPLICAS = 2
 DEFAULT_MAX_BATCH_SIZE = 16
 
@@ -49,7 +50,10 @@ def main():
     # IS_SMOKE_TEST is set by args of releaser's e2e.py
    smoke_test = os.environ.get("IS_SMOKE_TEST", "1")
     if smoke_test == "1":
-        checkpoint_path = "file://checkpoint.db"
+        path = Path("checkpoint.db")
+        checkpoint_path = f"file://{path}"
+        if path.exists():
+            path.unlink()
     else:
         checkpoint_path = "s3://serve-nightly-tests/fault-tolerant-test-checkpoint"  # noqa: E501
 
@@ -57,20 +61,16 @@ def main():
         1, checkpoint_path=checkpoint_path, namespace=namespace)
 
     # Deploy for the first time
-    @serve.deployment(name="echo", num_replicas=DEFAULT_NUM_REPLICAS)
-    class Echo:
-        def __init__(self):
-            return True
-
-        def __call__(self, request):
-            return "hii"
+    @serve.deployment(num_replicas=DEFAULT_NUM_REPLICAS)
+    def hello():
+        return serve.get_replica_context().deployment
 
-    Echo.deploy()
-
-    # Ensure endpoint is working
-    for _ in range(5):
-        response = request_with_retries("/echo/", timeout=3)
-        assert response.text == "hii"
+    for name in ["hello", "world"]:
+        hello.options(name=name).deploy()
 
+        for _ in range(5):
+            response = request_with_retries(f"/{name}/", timeout=3)
+            assert response.text == name
 
     logger.info("Initial deployment successful with working endpoint.")
 
@@ -87,9 +87,10 @@ def __call__(self, request):
     setup_local_single_node_cluster(
         1, checkpoint_path=checkpoint_path, namespace=namespace)
 
-    for _ in range(5):
-        response = request_with_retries("/echo/", timeout=3)
-        assert response.text == "hii"
+    for name in ["hello", "world"]:
+        for _ in range(5):
+            response = request_with_retries(f"/{name}/", timeout=3)
+            assert response.text == name
 
     logger.info("Deployment recovery from s3 checkpoint is successful "
                 "with working endpoint.")
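
Note on the DeploymentInfo change in python/ray/serve/common.py above: the controller checkpoint no longer carries a live actor class, only the picklable inputs (actor_name, serialized_deployment_def), and the actor definition is rebuilt lazily on first access after a restore. Below is a minimal standalone sketch of that pattern in plain Python; LazyActorDefInfo and the tuple standing in for ray.remote(create_replica_wrapper(...)) are illustrative only, not Serve code.

import pickle


class LazyActorDefInfo:
    """Illustrative stand-in for DeploymentInfo's lazy actor_def caching."""

    def __init__(self, actor_name, serialized_deployment_def):
        self.actor_name = actor_name
        self.serialized_deployment_def = serialized_deployment_def
        # Ephemeral cache: rebuilt on demand, never written to a checkpoint.
        self._cached_actor_def = None

    def __getstate__(self):
        # Drop the non-picklable cache before serialization.
        state = self.__dict__.copy()
        del state["_cached_actor_def"]
        return state

    def __setstate__(self, state):
        # Restore plain fields; the cache is rebuilt on first access.
        self.__dict__ = state
        self._cached_actor_def = None

    @property
    def actor_def(self):
        if self._cached_actor_def is None:
            # Stand-in for ray.remote(create_replica_wrapper(name, body)).
            self._cached_actor_def = (self.actor_name,
                                      self.serialized_deployment_def)
        return self._cached_actor_def


restored = pickle.loads(pickle.dumps(LazyActorDefInfo("echo", b"payload")))
assert restored.actor_def == ("echo", b"payload")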
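
Note on the NOTE(simon) comment in _save_checkpoint_func above: as I read it, the checkpoint envelope now holds only plain data plus the already-cloudpickled deployment body, so the stricter stdlib pickle acts as a guard against accidentally persisting code objects or other state that only exists inside one cluster. A rough sketch of that contrast (assumes a local Python environment with ray installed; deployment_def and the checkpoint layout below are illustrative, not the controller's actual schema):

import pickle

from ray import cloudpickle

# Stand-in for a user's deployment body: stdlib pickle cannot serialize the
# lambda by value, but cloudpickle can, which is why the body is cloudpickled
# into serialized_deployment_def before it ever reaches the checkpoint.
deployment_def = lambda request: "hello"
serialized_deployment_def = cloudpickle.dumps(deployment_def)

try:
    pickle.dumps(deployment_def)
except Exception as exc:
    print(f"stdlib pickle refused the code object: {exc}")

# The checkpoint envelope itself is plain data plus pre-serialized bytes,
# so stdlib pickle is sufficient for it.
checkpoint = pickle.dumps(
    {"echo": ("target_info", 2, serialized_deployment_def)})
assert pickle.loads(checkpoint)["echo"][1] == 2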