Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/features/disagg_prefill.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ Now supports 5 types of connectors:
--kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
```

For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as:

```bash
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_buffer_device":"cuda", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'
```

- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):

```bash
Expand Down
2 changes: 1 addition & 1 deletion docs/serving/expert_parallel_deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ For production deployments requiring strict SLA guarantees for time-to-first-tok

1. **Install gdrcopy/ucx/nixl**: For maximum performance, run the [install_gdrcopy.sh](gh-file:tools/install_gdrcopy.sh) script to install `gdrcopy` (e.g., `install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "12.8" "x64"`). You can find available OS versions [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2012.8/). If `gdrcopy` is not installed, things will still work with a plain `pip install nixl`, just with lower performance. `nixl` and `ucx` are installed as dependencies via pip.

2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`
2. **Configure Both Instances**: Add this flag to both prefill and decode instances `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}`. Noted, you may also specify one or multiple NIXL_Backend. Such as: `--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both", "kv_connector_extra_config":{"backend":["UCX", "GDS"]}'`

3. **Client Orchestration**: Use the client-side script below to coordinate prefill/decode operations. We are actively working on routing solutions.

Expand Down
52 changes: 51 additions & 1 deletion tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker, NixlKVConnectorStats)
from vllm.forward_context import ForwardContext
from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(self, agent_name: str, *args, **kwargs):
def get_reg_descs(self, caches_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in caches_data]

def register_memory(self, descs) -> None:
def register_memory(self, descs, backends) -> None:
pass

def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
Expand Down Expand Up @@ -855,3 +856,52 @@ def test_register_kv_caches(dist_init):
assert block_len == expected_block_len, \
f"Block entry {i}: Expected block len {expected_block_len}, " \
f"got {block_len}"


class FakePlatform(Platform):
device_type: str = "oot"

@classmethod
def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
"""
Returns a mapping from device_type to a tuple of supported
kv_buffer_device for nixl.
"""
return {'oot': ('oot', )}

@classmethod
def get_nixl_memory_type(cls) -> Optional[str]:
"""
Returns the nixl memory type for the current platform.
"""
return 'VRAM'


@pytest.mark.parametrize("kv_buffer_device, nixl_memory_type", [
("oot", "VRAM"),
])
def test_kv_buffer_to_nixl_memory_types(dist_init, kv_buffer_device,
nixl_memory_type):
"""
Test that register_kv_caches() passes the correct memory types from the
config to the nixl_wrapper.
"""
vllm_config = create_vllm_config()
# Override the default memory types in the config
vllm_config.kv_transfer_config.kv_buffer_device = kv_buffer_device
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
_NIXL_SUPPORTED_DEVICE)
_NIXL_SUPPORTED_DEVICE.update(FakePlatform.get_nixl_supported_devices())

with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.current_platform", FakePlatform), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector._NIXL_SUPPORTED_DEVICE", _NIXL_SUPPORTED_DEVICE): # noqa: E501

# Create connector and replace its worker with a fake one for isolation
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)

# Verify get_reg_descs was called with the correct memory_type
assert connector.connector_worker.kv_buffer_device == kv_buffer_device
assert connector.connector_worker.nixl_memory_type == nixl_memory_type
33 changes: 26 additions & 7 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,21 @@
logger.warning("NIXL is not available")
NixlWrapper = None

try:
from nixl._api import nixl_agent_config
except ImportError:
nixl_agent_config = None
logger.warning("NIXL agent config is not available")

# Supported platforms and types of kv transfer buffer.
# {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_DEVICE = {
"cuda": ("cuda", ),
"tpu": ("cpu", ),
"xpu": ("cpu", ),
}
# support for oot platform by providing mapping in current_platform
_NIXL_SUPPORTED_DEVICE.update(current_platform.get_nixl_supported_devices())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable _NIXL_SUPPORTED_DEVICE appears to be unused.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



class NixlAgentMetadata(
Expand Down Expand Up @@ -448,8 +456,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size

self.nixl_backends = \
vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"])
# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
config = nixl_agent_config(backends=self.nixl_backends) if len(
non_ucx_backends) > 0 and nixl_agent_config is not None else None

self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)

Expand Down Expand Up @@ -486,11 +501,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
# used when device memory can not be registered under nixl
self.host_xfer_buffers: dict[str, torch.Tensor] = {}
self.use_host_buffer = self.kv_buffer_device == "cpu"
if self.kv_buffer_device == "cuda":
self.nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
self.nixl_memory_type = "DRAM"
else:
# support for oot platform which can't register nixl memory
# type based on kv_buffer_device
self.nixl_memory_type = current_platform.get_nixl_memory_type()
if self.nixl_memory_type is None:
if self.kv_buffer_device == "cuda":
self.nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
self.nixl_memory_type = "DRAM"
if self.nixl_memory_type is None:
raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
"is not supported.")
Expand Down Expand Up @@ -766,7 +785,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
descs = self.nixl_wrapper.get_reg_descs(caches_data,
self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
self.nixl_wrapper.register_memory(descs, backends=self.nixl_backends)
logger.debug("Done registering descs")
self._registered_descs.append(descs)

Expand Down
15 changes: 15 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,21 @@ def _synced_weight_loader(param, *args, **kwargs):

return _synced_weight_loader

@classmethod
def get_nixl_supported_devices(cls) -> dict[str, tuple[str, ...]]:
"""
Returns a mapping from device_type to a tuple of supported
kv_buffer_device for nixl.
"""
return {}

@classmethod
def get_nixl_memory_type(cls) -> Optional[str]:
"""
Returns the nixl memory type for the current platform.
"""
return None


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
Expand Down