|
1 | 1 | import json
|
2 | 2 | import os
|
3 |
| -import pickle |
4 | 3 | from concurrent.futures import ThreadPoolExecutor
|
5 | 4 | from dataclasses import dataclass
|
6 | 5 | from typing import Optional, Union
|
7 | 6 |
|
8 | 7 | import torch
|
9 | 8 | import zmq
|
| 9 | +from safetensors.torch import load as safetensors_load |
| 10 | +from safetensors.torch import save as safetensors_save |
10 | 11 |
|
11 | 12 | from vllm.config import KVTransferConfig
|
12 | 13 | from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
@@ -235,14 +236,13 @@ def tensor_hash(self, tensor: torch.Tensor) -> int:
|
235 | 236 | return hash(tensor.data_ptr())
|
236 | 237 |
|
237 | 238 | def _send_impl(self, tensor: torch.Tensor) -> None:
|
238 |
| - """Implement the tensor sending logic.""" |
239 |
| - value_bytes = pickle.dumps(tensor) |
240 |
| - self.transfer_engine.send_bytes(value_bytes) |
| 239 | + """Implement the tensor sending logic using safetensors.""" |
| 240 | + self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor})) |
241 | 241 |
|
242 | 242 | def _recv_impl(self) -> torch.Tensor:
|
243 |
| - """Implement the tensor receiving logic.""" |
| 243 | + """Implement the tensor receiving logic using safetensors.""" |
244 | 244 | data = self.transfer_engine.recv_bytes()
|
245 |
| - return pickle.loads(data) |
| 245 | + return safetensors_load(data)["tensor"].to(self.device) |
246 | 246 |
|
247 | 247 | def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
248 | 248 | """Send tensor to the target process."""
|
|
0 commit comments