Conversation

@benhg benhg commented Nov 11, 2025

Posting this here just for discussion and out of my own interest. This PR migrates from custom C/Python bindings to using nvshmem4py where it's straightforward to do so. I'll leave inline comments with questions and notes on specific areas of the code.

uid_bytes = nvshmem_comm_cuda.NVSHMEMCommWrapper.get_unique_id_bytes()
uid_gpu = uid_bytes.to(device)
dist.broadcast(uid_gpu, src=0)
# Set device current
@benhg (Author) commented:

This should really be a helper function, since it's used both in the benchmarks and in the library itself. I couldn't decide on the best place to put it.
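
A minimal sketch of what that helper could look like, reusing the Device/set_current() calls from the hunks below and assuming get_unique_id_bytes() returns a CPU byte tensor (as its use with .to(device) suggests); the function name and placement are hypothetical:

import torch
import torch.distributed as dist
from cuda.core.experimental import Device

import nvshmem_comm_cuda  # the project's existing extension module


def broadcast_nvshmem_uid(local_device: int) -> torch.Tensor:
    """Hypothetical helper consolidating the UID-broadcast/device-setup
    boilerplate used by both the benchmarks and the library."""
    device = torch.device(f"cuda:{local_device}")
    uid_bytes = nvshmem_comm_cuda.NVSHMEMCommWrapper.get_unique_id_bytes()
    uid_gpu = uid_bytes.to(device)
    dist.broadcast(uid_gpu, src=0)  # every rank now holds rank 0's unique ID
    # Make the device current for subsequent nvshmem4py calls; per the
    # comment in the diff below, set_current() is expected to be idempotent.
    Device(local_device).set_current()
    return uid_gpu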

if comm_wrapper is not None:
nvrar_tensor, nvrar_tensor_id = comm_wrapper.allocate_tensor(num_elems, dtype, device, nvshmem_comm_cuda.Protocol.LL8)
# Allocate symmetric tensor via nvshmem4py and register with wrapper
nvrar_tensor = nvshmem.tensor((num_elems,), dtype=dtype)
@benhg (Author) commented:

This is the first major difference. I couldn't think of a good way to handle the tensor_id bookkeeping purely in Python, so what I did is (see the sketch after this list):

  • Replace tensor allocation with the nvshmem.core wrapper
  • Keep the other parts of the process in your C code (and rename it to register_tensor instead of allocate_tensor)
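
A rough sketch of the resulting flow, reusing the names from this diff (the import alias and the exact Python signature of the renamed binding are assumptions on my part):

import nvshmem.core as nvshmem  # assumes this is how the PR imports nvshmem4py

# Allocate the symmetric buffer via nvshmem4py, then hand it to the
# existing wrapper, which assigns and tracks the tensor id as before.
nvrar_tensor = nvshmem.tensor((num_elems,), dtype=dtype)
nvrar_tensor_id = comm_wrapper.register_external_tensor(nvrar_tensor)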

# This should be idempotent
cuda_dev.set_current()
stream = torch.cuda.current_stream()

@benhg (Author) commented:

Here's the same boilerplate
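
With the hypothetical broadcast_nvshmem_uid helper sketched earlier, this setup could collapse to something like:

broadcast_nvshmem_uid(local_device)  # hypothetical helper from the earlier comment
stream = torch.cuda.current_stream()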

# Allow user override via -DCUDA_CCCL_INCLUDE_DIR
set(CUDA_CCCL_INCLUDE_DIR "" CACHE PATH "Path to CUDA CCCL include directory (contains cuda/std)")
set(_CUDA_ROOT "")
if(DEFINED ENV{CUDA_HOME})
@benhg (Author) commented:

This is hacky and terrible and there is a better way to do it. In NVSHMEM's source, we handle it like this: https://github.com/NVIDIA/nvshmem/blob/2d7d25f0816235e3c2b51779571ec032606ea0dd/src/device/CMakeLists.txt#L188

virtual void free_tensor(uint64_t id) = 0;
// Register an externally-allocated symmetric tensor (e.g., via nvshmem4py)
// Returns a newly assigned tensor id
virtual uint64_t register_external_tensor(torch::Tensor& t) = 0;
@benhg (Author) commented:

Here's the renaming I mentioned above (it ended up as register_external_tensor).

throw std::runtime_error("Failed to allocate signal memory");
uint64_t* seq_num_signal = nullptr;
// TODO:
if (steps_inter_ > 0) {
@benhg (Author) commented:

This is just here so my tests would pass on one node. On a single node, steps_inter_ is 0, so without this check the calloc is a zero-size allocation that can return NULL, and the null check right after it would throw even though nothing actually failed.


void RecursiveLL8Coll::deregister_tensor(uint64_t id) {
// TODO: Implement
// TODO: Adding this so that I can test on 1-node. Is this valuable?
@benhg (Author) commented:

Same here

if (!chunk_signal_) {
throw std::runtime_error("Failed to allocate chunk signal memory");
// TODO: Adding this so that I can test on 1-node. Is this valuable?
if (steps_inter_ > 0) {
@benhg (Author) commented:

Same here

uid_gpu = uid_bytes.to(f"cuda:{local_device}")
dist.broadcast(uid_gpu, src=0)
# Initialize NVSHMEM via nvshmem4py using UID method
cuda_dev = Device(local_device)
@benhg (Author) commented:

Same boilerplate; another call site for the helper discussed above.

@prajwal1210 (Collaborator) commented:

Oh, somehow I missed the notification for this PR last week. I will look over the comments and changes and respond to them as soon as possible.
