Skip to content

Serialize using safetensors for KV caches #14228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import json
import os
import pickle
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Union

import torch
import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
Expand Down Expand Up @@ -237,14 +238,13 @@ def tensor_hash(self, tensor: torch.Tensor) -> int:
return hash(tensor.data_ptr())

def _send_impl(self, tensor: torch.Tensor) -> None:
"""Implement the tensor sending logic."""
value_bytes = pickle.dumps(tensor)
self.transfer_engine.send_bytes(value_bytes)
"""Implement the tensor sending logic using safetensors."""
self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))

def _recv_impl(self) -> torch.Tensor:
"""Implement the tensor receiving logic."""
"""Implement the tensor receiving logic using safetensors."""
data = self.transfer_engine.recv_bytes()
return pickle.loads(data)
return safetensors_load(data)["tensor"].to(self.device)

def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""Send tensor to the target process."""
Expand Down