@@ -176,3 +176,7 @@ def send_multiple_tensors(
communicator_metadata.dst_rank,
communicator_metadata.communicator_name,
)

@staticmethod
def garbage_collect(tensor_transport_meta: CollectiveTransportMetadata):
pass
16 changes: 14 additions & 2 deletions python/ray/experimental/collective/nixl_tensor_transport.py
@@ -54,7 +54,9 @@ def extract_tensor_transport_metadata(
device = None
tensor_meta = []
if gpu_object:
serialized_descs, agent_meta = nixl_backend.get_nixl_metadata(gpu_object)
reg_descs, serialized_descs, agent_meta = nixl_backend.get_nixl_metadata(
gpu_object
)
# We assume all tensors in one GPU object have the same device type.
device = gpu_object[0].device
for t in gpu_object:
@@ -64,10 +66,11 @@
)
tensor_meta.append((t.shape, t.dtype))
else:
serialized_descs, agent_meta = None, None
reg_descs, serialized_descs, agent_meta = None, None, None
return NixlTransportMetadata(
tensor_meta=tensor_meta,
tensor_device=device,
nixl_reg_descs=reg_descs,
nixl_serialized_descs=serialized_descs,
nixl_agent_meta=agent_meta,
)
@@ -150,3 +153,12 @@ def send_multiple_tensors(
raise NotImplementedError(
"NIXL transport does not support send_multiple_tensors, since it is a one-sided transport."
)

@staticmethod
def garbage_collect(tensor_transport_meta: NixlTransportMetadata):
from ray.util.collective.collective import get_group_handle
from ray.util.collective.collective_group.nixl_backend import NixlBackend

descs = tensor_transport_meta.nixl_reg_descs
nixl_backend: NixlBackend = get_group_handle(NIXL_GROUP_NAME)
nixl_backend.deregister_memory(descs)
10 changes: 10 additions & 0 deletions python/ray/experimental/collective/tensor_transport_manager.py
@@ -126,3 +126,13 @@ def send_multiple_tensors(
tensors: The tensors to send.
communicator_metadata: The communicator metadata for the send/recv operation.
"""

@staticmethod
@abstractmethod
def garbage_collect(tensor_transport_meta: TensorTransportMetadata):
"""
Garbage collect for the tensor transport after the GPU object is freed.
Args:
tensor_transport_meta: The tensor transport metadata.
"""
@@ -531,17 +531,22 @@ def get_gpu_object(

def free_object_primary_copy(self, object_id: str):
"""
Free the primary copy of the GPU object.
Free the primary copy of the GPU object. Expected to be idempotent when called from
free_actor_object_callback because the primary copy holder should always only have one ref
in the deque.
"""
from ray.experimental.gpu_object_manager.gpu_object_store import (
__ray_free__,
)

try:
src_actor = self.managed_gpu_object_metadata[object_id].src_actor
src_actor.__ray_call__.options(
concurrency_group="_ray_system", max_task_retries=-1
).remote(__ray_free__, object_id)
gpu_object_meta = self.managed_gpu_object_metadata[object_id]
src_actor = gpu_object_meta.src_actor
tensor_transport_backend = gpu_object_meta.tensor_transport_backend
tensor_transport_meta = gpu_object_meta.tensor_transport_meta
src_actor.__ray_call__.options(concurrency_group="_ray_system").remote(
__ray_free__, object_id, tensor_transport_backend, tensor_transport_meta
)
except Exception as e:
logger.error(
"Something went wrong while freeing the RDT object!", exc_info=e
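As a standalone illustration of the idempotency expectation above (toy names only, not the real GPUObjectStore API):

from collections import defaultdict, deque

# The primary copy holder keeps a single ref per object in its deque, so a
# repeated free simply finds nothing left to pop.
refs = defaultdict(deque)
refs["obj-1"].append("tensor-ref")

def toy_free(obj_id: str) -> None:
    if refs[obj_id]:
        refs[obj_id].popleft()

toy_free("obj-1")  # removes the only ref
toy_free("obj-1")  # safe no-op on the second call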
17 changes: 12 additions & 5 deletions python/ray/experimental/gpu_object_manager/gpu_object_store.py
@@ -104,13 +104,20 @@ def __ray_recv__(
gpu_object_store.add_object(obj_id, tensors)


def __ray_free__(self, obj_id: str):
"""
Called on the primary copy holder. Note that the primary copy holder should always only have one ref
in the gpu object store.
"""
def __ray_free__(
self,
obj_id: str,
tensor_transport_backend: Backend,
tensor_transport_meta: TensorTransportMetadata,
):
try:
from ray._private.worker import global_worker
from ray.experimental.collective import get_tensor_transport_manager

tensor_transport_manager = get_tensor_transport_manager(
tensor_transport_backend
)
tensor_transport_manager.garbage_collect(tensor_transport_meta)

gpu_object_store = global_worker.gpu_object_manager.gpu_object_store
gpu_object_store.pop_object(obj_id)
12 changes: 9 additions & 3 deletions python/ray/util/collective/collective_group/nixl_backend.py
@@ -1,5 +1,5 @@
import time
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, Any, List, Tuple

from nixl._api import nixl_agent, nixl_agent_config

@@ -87,9 +87,11 @@ def recv(
break

nixl_agent.release_xfer_handle(xfer_handle)
nixl_agent.deregister_memory(local_descs)
nixl_agent.remove_remote_agent(remote_name)

Contributor:
I'm a little confused why we would remove the agent but not deregister the memory if the send is sync.
We'll re-register the memory on every send anyways.

Member Author (@Qiaolin-Yu, Oct 6, 2025):
deregister_memory should be called by the same agent that calls register_memory. In our case, that is the sender, so I added the call in the gc function.
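
For illustration, a hedged sender-side sketch of that pairing, using only the NixlBackend methods shown in this diff (assume nixl_backend is the NIXL group handle from get_group_handle(NIXL_GROUP_NAME) and gpu_object is the list of tensors being sent):

# Registration happens once, on the sender, when the transport metadata for the
# GPU object is extracted ...
reg_descs, serialized_descs, agent_meta = nixl_backend.get_nixl_metadata(gpu_object)

# ... and is undone later by that same agent through the new gc hook, once the
# primary copy is freed, instead of by the receiver after every recv.
nixl_backend.deregister_memory(reg_descs)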


def get_nixl_metadata(self, tensors: List["torch.Tensor"]) -> Tuple[bytes, bytes]:
def get_nixl_metadata(
self, tensors: List["torch.Tensor"]
) -> Tuple[Any, bytes, bytes]:
"""Get NIXL metadata for a set of tensors.

Args:
@@ -104,6 +106,10 @@ def get_nixl_metadata(self, tensors: List["torch.Tensor"]) -> Tuple[bytes, bytes
reg_descs = nixl_agent.register_memory(tensors)
xfer_descs = reg_descs.trim()
return (
reg_descs,
nixl_agent.get_serialized_descs(xfer_descs),
nixl_agent.get_agent_metadata(),
)

def deregister_memory(self, descs: Any):
self._nixl_agent.deregister_memory(descs)
3 changes: 2 additions & 1 deletion python/ray/util/collective/types.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

from numpy import int32

@@ -80,6 +80,7 @@ class NixlTransportMetadata(TensorTransportMetadata):
nixl_agent_meta: The additional metadata of the remote NIXL agent.
"""

nixl_reg_descs: Optional[Any] = None
nixl_serialized_descs: Optional[bytes] = None
nixl_agent_meta: Optional[bytes] = None

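A hedged sketch of how the new field travels through this change (reg_descs, serialized_descs, agent_meta, and gpu_object are assumed to be the values produced by NixlBackend.get_nixl_metadata above):

meta = NixlTransportMetadata(
    tensor_meta=[(t.shape, t.dtype) for t in gpu_object],
    tensor_device=gpu_object[0].device,
    nixl_reg_descs=reg_descs,  # kept so the sender can deregister later
    nixl_serialized_descs=serialized_descs,
    nixl_agent_meta=agent_meta,
)
# When the primary copy is freed, NixlTensorTransport.garbage_collect(meta)
# hands meta.nixl_reg_descs back to NixlBackend.deregister_memory().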