Skip to content

[Core] Enable IPv6 with vllm.utils.make_zmq_socket() #16506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@

import pytest
import torch
import zmq
from vllm_test_utils.monitor import monitor

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, sha256,
make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)

from .utils import create_new_process_for_each_test, error_on_warning
Expand Down Expand Up @@ -662,3 +664,53 @@ def test_sha256(input: tuple, output: int):

# hashing different input, returns different value
assert hash != sha256(input + (1, ))


@pytest.mark.parametrize(
"path,expected",
[
("ipc://some_path", ("ipc", "some_path", "")),
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
("inproc://some_identifier", ("inproc", "some_identifier", "")),
]
)
def test_split_zmq_path(path, expected):
assert split_zmq_path(path) == expected


@pytest.mark.parametrize(
"invalid_path",
[
"invalid_path", # Missing scheme
"tcp://127.0.0.1", # Missing port
"tcp://[::1]", # Missing port for IPv6
"tcp://:5555", # Missing host
]
)
def test_split_zmq_path_invalid(invalid_path):
with pytest.raises(ValueError):
split_zmq_path(invalid_path)


def test_make_zmq_socket_ipv6():
# Check if IPv6 is supported by trying to create an IPv6 socket
try:
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.close()
except socket.error:
pytest.skip("IPv6 is not supported on this system")

ctx = zmq.Context()
ipv6_path = "tcp://[::]:5555" # IPv6 loopback address
socket_type = zmq.REP # Example socket type

# Create the socket
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)

# Verify that the IPV6 option is set
assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"

# Clean up
zsock.close()
ctx.term()
28 changes: 28 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
overload)
from urllib.parse import urlparse
from uuid import uuid4

import cachetools
Expand Down Expand Up @@ -2278,6 +2279,27 @@ def get_exception_traceback():
return err_str


def split_zmq_path(path: str) -> Tuple[str, str, str]:
"""Split a zmq path into its parts."""
parsed = urlparse(path)
if not parsed.scheme:
raise ValueError(f"Invalid zmq path: {path}")

scheme = parsed.scheme
host = parsed.hostname or ""
port = str(parsed.port or "")

if scheme == "tcp" and not all((host, port)):
# The host and port fields are required for tcp
raise ValueError(f"Invalid zmq path: {path}")

if scheme != "tcp" and port:
# port only makes sense with tcp
raise ValueError(f"Invalid zmq path: {path}")

return scheme, host, port


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
Expand Down Expand Up @@ -2317,6 +2339,12 @@ def make_zmq_socket(
if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)

# Determine if the path is a TCP socket with an IPv6 address.
# Enable IPv6 on the zmq socket if so.
scheme, host, _ = split_zmq_path(path)
if scheme == "tcp" and is_valid_ipv6_address(host):
socket.setsockopt(zmq.IPV6, 1)

if bind:
socket.bind(path)
else:
Expand Down