Skip to content

Commit

Permalink
[Serve] Make handle serializable (ray-project#22473)
Browse files Browse the repository at this point in the history
  • Loading branch information
architkulkarni authored and simonsays1980 committed Feb 27, 2022
1 parent 7d31d62 commit f0146b0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/ray/serve/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,19 @@ async def remote(self, *args, **kwargs):
def __repr__(self):
return f"{self.__class__.__name__}" f"(deployment='{self.deployment_name}')"

@classmethod
def _deserialize(cls, kwargs):
"""Required for this class's __reduce__ method to be picklable."""
return cls(**kwargs)

def __reduce__(self):
serialized_data = {
"controller_handle": self.controller_handle,
"deployment_name": self.deployment_name,
"handle_options": self.handle_options,
"_internal_pickled_http_request": self._pickled_http_request,
}
return lambda kwargs: RayServeHandle(**kwargs), (serialized_data,)
return RayServeHandle._deserialize, (serialized_data,)

def __getattr__(self, name):
return self.options(method_name=name)
Expand Down Expand Up @@ -228,4 +233,4 @@ def __reduce__(self):
"handle_options": self.handle_options,
"_internal_pickled_http_request": self._pickled_http_request,
}
return lambda kwargs: RayServeSyncHandle(**kwargs), (serialized_data,)
return RayServeSyncHandle._deserialize, (serialized_data,)
25 changes: 25 additions & 0 deletions python/ray/serve/tests/test_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,31 @@ def task(handle):
assert ray.get(result_ref) == "hello"


def test_handle_serializable_in_deployment_init(serve_instance):
"""Test that a handle can be passed into a constructor (#22110)"""

@serve.deployment
class RayServer1:
def __init__(self):
pass

def __call__(self, *args):
return {"count": self.count}

@serve.deployment
class RayServer2:
def __init__(self, handle):
self.handle = handle

def __call__(self, *args):
return {"count": self.count}

RayServer1.deploy()
for sync in [True, False]:
rs1_handle = RayServer1.get_handle(sync=sync)
RayServer2.deploy(rs1_handle)


def test_sync_handle_in_thread(serve_instance):
@serve.deployment
def f():
Expand Down

0 comments on commit f0146b0

Please sign in to comment.