
Commit daa6d8d

Qiaolin-Yu authored and dayshah committed
[core][RDT] Fix nixl garbage collection after the object is freed (ray-project#57138)
Signed-off-by: dayshah <dhyey2019@gmail.com>
Co-authored-by: dayshah <dhyey2019@gmail.com>
Signed-off-by: Josh Kodi <joshkodi@gmail.com>
1 parent 6e29165 commit daa6d8d

7 files changed: +61 -16 lines changed


python/ray/experimental/collective/collective_tensor_transport.py

Lines changed: 4 additions & 0 deletions
@@ -176,3 +176,7 @@ def send_multiple_tensors(
             communicator_metadata.dst_rank,
             communicator_metadata.communicator_name,
         )
+
+    @staticmethod
+    def garbage_collect(tensor_transport_meta: CollectiveTransportMetadata):
+        pass

python/ray/experimental/collective/nixl_tensor_transport.py

Lines changed: 14 additions & 2 deletions
@@ -54,7 +54,9 @@ def extract_tensor_transport_metadata(
         device = None
         tensor_meta = []
         if gpu_object:
-            serialized_descs, agent_meta = nixl_backend.get_nixl_metadata(gpu_object)
+            reg_descs, serialized_descs, agent_meta = nixl_backend.get_nixl_metadata(
+                gpu_object
+            )
             # We assume all tensors in one GPU object have the same device type.
             device = gpu_object[0].device
             for t in gpu_object:
@@ -64,10 +66,11 @@ def extract_tensor_transport_metadata(
                 )
             tensor_meta.append((t.shape, t.dtype))
         else:
-            serialized_descs, agent_meta = None, None
+            reg_descs, serialized_descs, agent_meta = None, None, None
         return NixlTransportMetadata(
             tensor_meta=tensor_meta,
             tensor_device=device,
+            nixl_reg_descs=reg_descs,
             nixl_serialized_descs=serialized_descs,
             nixl_agent_meta=agent_meta,
         )
@@ -150,3 +153,12 @@ def send_multiple_tensors(
         raise NotImplementedError(
             "NIXL transport does not support send_multiple_tensors, since it is a one-sided transport."
         )
+
+    @staticmethod
+    def garbage_collect(tensor_transport_meta: NixlTransportMetadata):
+        from ray.util.collective.collective import get_group_handle
+        from ray.util.collective.collective_group.nixl_backend import NixlBackend
+
+        descs = tensor_transport_meta.nixl_reg_descs
+        nixl_backend: NixlBackend = get_group_handle(NIXL_GROUP_NAME)
+        nixl_backend.deregister_memory(descs)
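
Note: a minimal sketch of the descriptor lifecycle this diff introduces, using only the APIs shown above. It assumes an already-initialized NIXL collective group; the NIXL_GROUP_NAME value and the CUDA tensor are stand-ins.

import torch

from ray.util.collective.collective import get_group_handle
from ray.util.collective.collective_group.nixl_backend import NixlBackend
from ray.util.collective.types import NixlTransportMetadata

NIXL_GROUP_NAME = "nixl"  # stand-in value; the real constant lives in this module

# Registration now also yields the live registration descriptors...
nixl_backend: NixlBackend = get_group_handle(NIXL_GROUP_NAME)
tensors = [torch.zeros(4, device="cuda")]
reg_descs, serialized_descs, agent_meta = nixl_backend.get_nixl_metadata(tensors)

# ...which travel inside the transport metadata...
meta = NixlTransportMetadata(
    tensor_meta=[(t.shape, t.dtype) for t in tensors],
    tensor_device=tensors[0].device,
    nixl_reg_descs=reg_descs,
    nixl_serialized_descs=serialized_descs,
    nixl_agent_meta=agent_meta,
)

# ...so the owner can deregister the memory once the GPU object is freed,
# which is exactly what garbage_collect() above does.
nixl_backend.deregister_memory(meta.nixl_reg_descs)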

python/ray/experimental/collective/tensor_transport_manager.py

Lines changed: 10 additions & 0 deletions
@@ -126,3 +126,13 @@ def send_multiple_tensors(
             tensors: The tensors to send.
             communicator_metadata: The communicator metadata for the send/recv operation.
         """
+
+    @staticmethod
+    @abstractmethod
+    def garbage_collect(tensor_transport_meta: TensorTransportMetadata):
+        """
+        Garbage collect for the tensor transport after the GPU object is freed.
+
+        Args:
+            tensor_transport_meta: The tensor transport metadata.
+        """

python/ray/experimental/gpu_object_manager/gpu_object_manager.py

Lines changed: 10 additions & 5 deletions
@@ -531,17 +531,22 @@ def get_gpu_object(

     def free_object_primary_copy(self, object_id: str):
         """
-        Free the primary copy of the GPU object.
+        Free the primary copy of the GPU object. Expected to be idempotent when called from
+        free_actor_object_callback because the primary copy holder should always only have one ref
+        in the deque.
         """
         from ray.experimental.gpu_object_manager.gpu_object_store import (
             __ray_free__,
         )

         try:
-            src_actor = self.managed_gpu_object_metadata[object_id].src_actor
-            src_actor.__ray_call__.options(
-                concurrency_group="_ray_system", max_task_retries=-1
-            ).remote(__ray_free__, object_id)
+            gpu_object_meta = self.managed_gpu_object_metadata[object_id]
+            src_actor = gpu_object_meta.src_actor
+            tensor_transport_backend = gpu_object_meta.tensor_transport_backend
+            tensor_transport_meta = gpu_object_meta.tensor_transport_meta
+            src_actor.__ray_call__.options(concurrency_group="_ray_system").remote(
+                __ray_free__, object_id, tensor_transport_backend, tensor_transport_meta
+            )
         except Exception as e:
             logger.error(
                 "Something went wrong while freeing the RDT object!", exc_info=e

python/ray/experimental/gpu_object_manager/gpu_object_store.py

Lines changed: 12 additions & 5 deletions
@@ -104,13 +104,20 @@ def __ray_recv__(
     gpu_object_store.add_object(obj_id, tensors)


-def __ray_free__(self, obj_id: str):
-    """
-    Called on the primary copy holder. Note that the primary copy holder should always only have one ref
-    in the gpu object store.
-    """
+def __ray_free__(
+    self,
+    obj_id: str,
+    tensor_transport_backend: Backend,
+    tensor_transport_meta: TensorTransportMetadata,
+):
     try:
         from ray._private.worker import global_worker
+        from ray.experimental.collective import get_tensor_transport_manager
+
+        tensor_transport_manager = get_tensor_transport_manager(
+            tensor_transport_backend
+        )
+        tensor_transport_manager.garbage_collect(tensor_transport_meta)

         gpu_object_store = global_worker.gpu_object_manager.gpu_object_store
         gpu_object_store.pop_object(obj_id)
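
Note: the free path is easiest to see end to end. Below is a self-contained sketch of the dispatch pattern __ray_free__ now follows; FakeStore, GC_HOOKS, and ray_free are simplified stand-ins for GPUObjectStore, get_tensor_transport_manager, and the real __ray_free__, not Ray's API.

from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class FakeStore:
    objects: Dict[str, Any] = field(default_factory=dict)

    def pop_object(self, obj_id: str) -> Any:
        return self.objects.pop(obj_id)


# backend name -> cleanup hook, mirroring get_tensor_transport_manager(...).garbage_collect
GC_HOOKS: Dict[str, Callable[[Any], None]] = {
    "collective": lambda meta: None,                     # no-op
    "nixl": lambda meta: print(f"deregister {meta!r}"),  # deregister memory
}


def ray_free(store: FakeStore, obj_id: str, backend: str, transport_meta: Any):
    # Transport-level cleanup runs before the tensors are dropped, so the
    # registered memory is released exactly once, by the primary copy holder.
    GC_HOOKS[backend](transport_meta)
    store.pop_object(obj_id)


store = FakeStore(objects={"obj-1": ["tensor"]})
ray_free(store, "obj-1", "nixl", transport_meta="reg_descs")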

python/ray/util/collective/collective_group/nixl_backend.py

Lines changed: 9 additions & 3 deletions
@@ -1,5 +1,5 @@
 import time
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, Any, List, Tuple

 from nixl._api import nixl_agent, nixl_agent_config

@@ -87,9 +87,11 @@ def recv(
                 break

         nixl_agent.release_xfer_handle(xfer_handle)
-        nixl_agent.deregister_memory(local_descs)
+        nixl_agent.remove_remote_agent(remote_name)

-    def get_nixl_metadata(self, tensors: List["torch.Tensor"]) -> Tuple[bytes, bytes]:
+    def get_nixl_metadata(
+        self, tensors: List["torch.Tensor"]
+    ) -> Tuple[Any, bytes, bytes]:
         """Get NIXL metadata for a set of tensors.

         Args:
@@ -104,6 +106,10 @@ def get_nixl_metadata(self, tensors: List["torch.Tensor"]) -> Tuple[bytes, bytes
         reg_descs = nixl_agent.register_memory(tensors)
         xfer_descs = reg_descs.trim()
         return (
+            reg_descs,
             nixl_agent.get_serialized_descs(xfer_descs),
             nixl_agent.get_agent_metadata(),
         )
+
+    def deregister_memory(self, descs: Any):
+        self._nixl_agent.deregister_memory(descs)
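
Note: with this change the backend hands the registration descriptors back to the caller, and recv() no longer deregisters local memory; deregistration becomes the owner's job. A hedged usage sketch follows (release_after_free is a hypothetical wrapper; the two backend calls are the ones added above):

from typing import List

import torch

from ray.util.collective.collective_group.nixl_backend import NixlBackend


def release_after_free(backend: NixlBackend, tensors: List["torch.Tensor"]) -> None:
    # get_nixl_metadata() now leads with the live registration descriptors,
    # so the caller can hold them for the lifetime of the GPU object.
    reg_descs, serialized_descs, agent_meta = backend.get_nixl_metadata(tensors)

    # ... one-sided transfer happens; recv() only detaches the remote agent ...

    # Once the object is freed, the owner releases the registration explicitly.
    backend.deregister_memory(reg_descs)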

python/ray/util/collective/types.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple

 from numpy import int32

@@ -80,6 +80,7 @@ class NixlTransportMetadata(TensorTransportMetadata):
         nixl_agent_meta: The additional metadata of the remote NIXL agent.
     """

+    nixl_reg_descs: Optional[Any] = None
     nixl_serialized_descs: Optional[bytes] = None
     nixl_agent_meta: Optional[bytes] = None
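
Note: because the new field defaults to None, existing call sites that never pass nixl_reg_descs keep constructing the dataclass unchanged; only the sender path populates it. A small sketch (tensor_meta and tensor_device are the base-class fields seen in the construction earlier in this diff):

from ray.util.collective.types import NixlTransportMetadata

meta = NixlTransportMetadata(tensor_meta=[], tensor_device=None)
assert meta.nixl_reg_descs is None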
