Skip to content

Commit 94a5e77

Browse files
KuntaiDuliu-shaojun
authored andcommitted
[Security] Serialize using safetensors instead of pickle in Mooncake Pipe (vllm-project#14228)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
1 parent 2417e62 commit 94a5e77

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import json
22
import os
3-
import pickle
43
from concurrent.futures import ThreadPoolExecutor
54
from dataclasses import dataclass
65
from typing import Optional, Union
76

87
import torch
98
import zmq
9+
from safetensors.torch import load as safetensors_load
10+
from safetensors.torch import save as safetensors_save
1011

1112
from vllm.config import KVTransferConfig
1213
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
@@ -235,14 +236,13 @@ def tensor_hash(self, tensor: torch.Tensor) -> int:
235236
return hash(tensor.data_ptr())
236237

237238
def _send_impl(self, tensor: torch.Tensor) -> None:
238-
"""Implement the tensor sending logic."""
239-
value_bytes = pickle.dumps(tensor)
240-
self.transfer_engine.send_bytes(value_bytes)
239+
"""Implement the tensor sending logic using safetensors."""
240+
self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))
241241

242242
def _recv_impl(self) -> torch.Tensor:
243-
"""Implement the tensor receiving logic."""
243+
"""Implement the tensor receiving logic using safetensors."""
244244
data = self.transfer_engine.recv_bytes()
245-
return pickle.loads(data)
245+
return safetensors_load(data)["tensor"].to(self.device)
246246

247247
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
248248
"""Send tensor to the target process."""

0 commit comments

Comments
 (0)