Commit e46fe7e

more core code cherrypicks (#57557)

cherrypick #57247 #57253 #57138

Signed-off-by: Lonnie Liu <lonnie@anyscale.com>

1 parent 8f3bc35 commit e46fe7e

9 files changed: +114 additions, −30 deletions

python/ray/_raylet.pxd

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ cdef extern from "Python.h":
     ctypedef struct CPyThreadState "PyThreadState":
         int recursion_limit
         int recursion_remaining
+        int c_recursion_remaining

     # From Include/ceval.h#67
     int Py_GetRecursionLimit()

python/ray/_raylet.pyx

Lines changed: 52 additions & 14 deletions

@@ -679,29 +679,63 @@ def compute_task_id(ObjectRef object_ref):
 
 
 cdef increase_recursion_limit():
-    """Double the recusion limit if current depth is close to the limit"""
+    """
+    Ray does some weird things with asio fibers and asyncio to run asyncio actors.
+    This results in the Python interpreter thinking there's a lot of recursion depth,
+    so we need to increase the limit when we start getting close.
+
+    0x30C0000 is Python 3.12
+    On 3.12, when recursion depth increases, c_recursion_remaining will decrease,
+    and that's what's actually compared to raise a RecursionError. So increasing
+    it by 1000 when it drops below 1000 will keep us from raising the RecursionError.
+    https://github.com/python/cpython/blob/bfb9e2f4a4e690099ec2ec53c08b90f4d64fde36/Python/pystate.c#L1353
+    0x30B00A4 is Python 3.11
+    On 3.11, the recursion depth can be calculated with recursion_limit - recursion_remaining.
+    We can get the current limit with Py_GetRecursionLimit and set it with Py_SetRecursionLimit.
+    We'll double the limit when there's less than 500 remaining.
+    On older versions
+    There's simply a recursion_depth variable and we'll increase the max the same
+    way we do for 3.11.
+    """
     cdef:
-        CPyThreadState * s = <CPyThreadState *> PyThreadState_Get()
-        int current_limit = Py_GetRecursionLimit()
-        int new_limit = current_limit * 2
         cdef extern from *:
             """
             #if PY_VERSION_HEX >= 0x30C0000
-            #define CURRENT_DEPTH(x) ((x)->py_recursion_limit - (x)->py_recursion_remaining)
+            bool IncreaseRecursionLimitIfNeeded(PyThreadState *x) {
+                if (x->c_recursion_remaining < 1000) {
+                    x->c_recursion_remaining += 1000;
+                    return true;
+                }
+                return false;
+            }
             #elif PY_VERSION_HEX >= 0x30B00A4
-            #define CURRENT_DEPTH(x) ((x)->recursion_limit - (x)->recursion_remaining)
+            bool IncreaseRecursionLimitIfNeeded(PyThreadState *x) {
+                int current_limit = Py_GetRecursionLimit();
+                int current_depth = x->recursion_limit - x->recursion_remaining;
+                if (current_limit - current_depth < 500) {
+                    Py_SetRecursionLimit(current_limit * 2);
+                    return true;
+                }
+                return false;
+            }
             #else
-            #define CURRENT_DEPTH(x) ((x)->recursion_depth)
+            bool IncreaseRecursionLimitIfNeeded(PyThreadState *x) {
+                int current_limit = Py_GetRecursionLimit();
+                if (current_limit - x->recursion_depth < 500) {
+                    Py_SetRecursionLimit(current_limit * 2);
+                    return true;
+                }
+                return false;
+            }
             #endif
             """
-            int CURRENT_DEPTH(CPyThreadState *x)
+            c_bool IncreaseRecursionLimitIfNeeded(CPyThreadState *x)
+
+        CPyThreadState * s = <CPyThreadState *> PyThreadState_Get()
+        c_bool increased_recursion_limit = IncreaseRecursionLimitIfNeeded(s)
 
-        int current_depth = CURRENT_DEPTH(s)
-    if current_limit - current_depth < 500:
-        Py_SetRecursionLimit(new_limit)
-        logger.debug("Increasing Python recursion limit to {} "
-                     "current recursion depth is {}.".format(
-                        new_limit, current_depth))
+    if increased_recursion_limit:
+        logger.debug("Increased Python recursion limit")
 
 
 cdef CObjectLocationPtrToDict(CObjectLocation* c_object_location):
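For intuition, here is a rough pure-Python analogue of the 3.11 branch above. The helper name and frame-walking are illustrative only; the real check runs in C against PyThreadState fields:

import sys

def increase_recursion_limit_if_needed(threshold: int = 500) -> bool:
    """Double the recursion limit when the current stack depth is within
    `threshold` frames of it (pure-Python restatement of the 3.11 branch)."""
    depth = 0
    frame = sys._getframe()
    while frame is not None:  # count frames back to the interpreter entry
        depth += 1
        frame = frame.f_back
    limit = sys.getrecursionlimit()
    if limit - depth < threshold:
        sys.setrecursionlimit(limit * 2)
        return True
    return False

print(increase_recursion_limit_if_needed())  # False at a shallow depth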
@@ -2462,6 +2496,10 @@ cdef CRayStatus task_execution_handler(
             if hasattr(e, "unexpected_error_traceback"):
                 msg += (f" {e.unexpected_error_traceback}")
             return CRayStatus.UnexpectedSystemExit(msg)
+    except Exception as e:
+        msg = "Unexpected exception raised in task execution handler: {}".format(e)
+        logger.error(msg)
+        return CRayStatus.UnexpectedSystemExit(msg)
 
     return CRayStatus.OK()
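The new catch-all follows a common pattern at language boundaries: convert any uncaught Python exception into a status value rather than letting it unwind into the C++ caller. A minimal sketch of the idea, with hypothetical names:

import logging

logger = logging.getLogger(__name__)

def run_task_handler(handler, *args) -> str:
    """Run `handler`, converting any uncaught exception into a status value."""
    try:
        handler(*args)
    except Exception as e:
        # Mirrors the new catch-all: log, then report as a status rather
        # than letting the exception unwind into the (non-Python) caller.
        msg = f"Unexpected exception raised in task execution handler: {e}"
        logger.error(msg)
        return f"UnexpectedSystemExit: {msg}"
    return "OK"

print(run_task_handler(lambda: 1 / 0))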

python/ray/experimental/collective/collective_tensor_transport.py

Lines changed: 4 additions & 0 deletions

@@ -176,3 +176,7 @@ def send_multiple_tensors(
             communicator_metadata.dst_rank,
             communicator_metadata.communicator_name,
         )
+
+    @staticmethod
+    def garbage_collect(tensor_transport_meta: CollectiveTransportMetadata):
+        pass

python/ray/experimental/collective/nixl_tensor_transport.py

Lines changed: 14 additions & 2 deletions

@@ -54,7 +54,9 @@ def extract_tensor_transport_metadata(
         device = None
         tensor_meta = []
         if gpu_object:
-            serialized_descs, agent_meta = nixl_backend.get_nixl_metadata(gpu_object)
+            reg_descs, serialized_descs, agent_meta = nixl_backend.get_nixl_metadata(
+                gpu_object
+            )
             # We assume all tensors in one GPU object have the same device type.
             device = gpu_object[0].device
             for t in gpu_object:

@@ -64,10 +66,11 @@ def extract_tensor_transport_metadata(
             )
             tensor_meta.append((t.shape, t.dtype))
         else:
-            serialized_descs, agent_meta = None, None
+            reg_descs, serialized_descs, agent_meta = None, None, None
         return NixlTransportMetadata(
             tensor_meta=tensor_meta,
             tensor_device=device,
+            nixl_reg_descs=reg_descs,
             nixl_serialized_descs=serialized_descs,
             nixl_agent_meta=agent_meta,
         )

@@ -150,3 +153,12 @@ def send_multiple_tensors(
         raise NotImplementedError(
             "NIXL transport does not support send_multiple_tensors, since it is a one-sided transport."
         )
+
+    @staticmethod
+    def garbage_collect(tensor_transport_meta: NixlTransportMetadata):
+        from ray.util.collective.collective import get_group_handle
+        from ray.util.collective.collective_group.nixl_backend import NixlBackend
+
+        descs = tensor_transport_meta.nixl_reg_descs
+        nixl_backend: NixlBackend = get_group_handle(NIXL_GROUP_NAME)
+        nixl_backend.deregister_memory(descs)
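The intended pairing is that extract_tensor_transport_metadata registers memory and stashes the handle in the metadata, and garbage_collect later hands it back for deregistration. A toy sketch of that lifecycle (ToyNixlBackend is a stand-in, not the real NixlBackend):

class ToyNixlBackend:
    """Stand-in for NixlBackend that only tracks outstanding registrations."""

    def __init__(self):
        self._next_handle = 0
        self.registered = set()

    def register_memory(self, tensors):
        # Returns an opaque handle, playing the role of reg_descs.
        self._next_handle += 1
        self.registered.add(self._next_handle)
        return self._next_handle

    def deregister_memory(self, handle):
        self.registered.discard(handle)

backend = ToyNixlBackend()
reg_descs = backend.register_memory(["t0", "t1"])  # extract_tensor_transport_metadata side
# ... GPU object is alive and transferable here ...
backend.deregister_memory(reg_descs)               # garbage_collect side
assert not backend.registered                      # nothing left pinned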

python/ray/experimental/collective/tensor_transport_manager.py

Lines changed: 10 additions & 0 deletions

@@ -126,3 +126,13 @@ def send_multiple_tensors(
             tensors: The tensors to send.
             communicator_metadata: The communicator metadata for the send/recv operation.
         """
+
+    @staticmethod
+    @abstractmethod
+    def garbage_collect(tensor_transport_meta: TensorTransportMetadata):
+        """
+        Garbage collect for the tensor transport after the GPU object is freed.
+
+        Args:
+            tensor_transport_meta: The tensor transport metadata.
+        """

python/ray/experimental/gpu_object_manager/gpu_object_manager.py

Lines changed: 10 additions & 5 deletions

@@ -531,17 +531,22 @@ def get_gpu_object(
 
     def free_object_primary_copy(self, object_id: str):
         """
-        Free the primary copy of the GPU object.
+        Free the primary copy of the GPU object. Expected to be idempotent when called from
+        free_actor_object_callback because the primary copy holder should always only have one ref
+        in the deque.
         """
         from ray.experimental.gpu_object_manager.gpu_object_store import (
             __ray_free__,
         )
 
         try:
-            src_actor = self.managed_gpu_object_metadata[object_id].src_actor
-            src_actor.__ray_call__.options(
-                concurrency_group="_ray_system", max_task_retries=-1
-            ).remote(__ray_free__, object_id)
+            gpu_object_meta = self.managed_gpu_object_metadata[object_id]
+            src_actor = gpu_object_meta.src_actor
+            tensor_transport_backend = gpu_object_meta.tensor_transport_backend
+            tensor_transport_meta = gpu_object_meta.tensor_transport_meta
+            src_actor.__ray_call__.options(concurrency_group="_ray_system").remote(
+                __ray_free__, object_id, tensor_transport_backend, tensor_transport_meta
+            )
         except Exception as e:
             logger.error(
                 "Something went wrong while freeing the RDT object!", exc_info=e

python/ray/experimental/gpu_object_manager/gpu_object_store.py

Lines changed: 12 additions & 5 deletions

@@ -104,13 +104,20 @@ def __ray_recv__(
     gpu_object_store.add_object(obj_id, tensors)
 
 
-def __ray_free__(self, obj_id: str):
-    """
-    Called on the primary copy holder. Note that the primary copy holder should always only have one ref
-    in the gpu object store.
-    """
+def __ray_free__(
+    self,
+    obj_id: str,
+    tensor_transport_backend: Backend,
+    tensor_transport_meta: TensorTransportMetadata,
+):
     try:
         from ray._private.worker import global_worker
+        from ray.experimental.collective import get_tensor_transport_manager
+
+        tensor_transport_manager = get_tensor_transport_manager(
+            tensor_transport_backend
+        )
+        tensor_transport_manager.garbage_collect(tensor_transport_meta)
 
         gpu_object_store = global_worker.gpu_object_manager.gpu_object_store
         gpu_object_store.pop_object(obj_id)
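Note the ordering in __ray_free__: transport-level cleanup runs before the object is popped from the store. A stripped-down sketch of that free path, with hypothetical names, showing how a pop with a default keeps a repeated free harmless:

def free_gpu_object(store: dict, obj_id: str, garbage_collect) -> None:
    """Transport cleanup first, then drop the store entry."""
    garbage_collect()        # e.g. deregister pinned NIXL memory
    store.pop(obj_id, None)  # default makes a repeated free a no-op

store = {"obj-1": ["tensor"]}
free_gpu_object(store, "obj-1", garbage_collect=lambda: None)
free_gpu_object(store, "obj-1", garbage_collect=lambda: None)  # idempotent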

python/ray/util/collective/collective_group/nixl_backend.py

Lines changed: 9 additions & 3 deletions

@@ -1,5 +1,5 @@
 import time
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, Any, List, Tuple
 
 from nixl._api import nixl_agent, nixl_agent_config
 

@@ -87,9 +87,11 @@ def recv(
                 break
 
         nixl_agent.release_xfer_handle(xfer_handle)
-        nixl_agent.deregister_memory(local_descs)
+        nixl_agent.remove_remote_agent(remote_name)
 
-    def get_nixl_metadata(self, tensors: List["torch.Tensor"]) -> Tuple[bytes, bytes]:
+    def get_nixl_metadata(
+        self, tensors: List["torch.Tensor"]
+    ) -> Tuple[Any, bytes, bytes]:
         """Get NIXL metadata for a set of tensors.
 
         Args:

@@ -104,6 +106,10 @@ def get_nixl_metadata(self, tensors: List["torch.Tensor"]) -> Tuple[bytes, bytes
         reg_descs = nixl_agent.register_memory(tensors)
         xfer_descs = reg_descs.trim()
         return (
+            reg_descs,
             nixl_agent.get_serialized_descs(xfer_descs),
             nixl_agent.get_agent_metadata(),
         )
+
+    def deregister_memory(self, descs: Any):
+        self._nixl_agent.deregister_memory(descs)
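The contract change in get_nixl_metadata is that the live registration handle now travels with the serialized descriptors, so the caller can hand it back to deregister_memory once the GPU object is freed. A condensed restatement of that shape, with `agent` standing in for the wrapped NIXL agent and using only the calls that appear in the diff:

from typing import Any, Tuple

def get_nixl_metadata_sketch(agent, tensors) -> Tuple[Any, bytes, bytes]:
    reg_descs = agent.register_memory(tensors)  # live handle, kept for cleanup
    xfer_descs = reg_descs.trim()               # transfer-only view
    return (
        reg_descs,                              # new: lets the caller deregister later
        agent.get_serialized_descs(xfer_descs),
        agent.get_agent_metadata(),
    )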

python/ray/util/collective/types.py

Lines changed: 2 additions & 1 deletion

@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple
 
 from numpy import int32
 

@@ -80,6 +80,7 @@ class NixlTransportMetadata(TensorTransportMetadata):
         nixl_agent_meta: The additional metadata of the remote NIXL agent.
     """
 
+    nixl_reg_descs: Optional[Any] = None
     nixl_serialized_descs: Optional[bytes] = None
     nixl_agent_meta: Optional[bytes] = None
 
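A trimmed, illustrative stand-in for the updated dataclass, showing that the new field defaults to None so existing constructors keep working; the field comments are assumptions, not from the source:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class NixlMetaSketch:
    nixl_reg_descs: Optional[Any] = None          # opaque registration handle
    nixl_serialized_descs: Optional[bytes] = None
    nixl_agent_meta: Optional[bytes] = None

meta = NixlMetaSketch(nixl_reg_descs=object())
assert meta.nixl_serialized_descs is None  # defaults keep older call sites valid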
