This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 1653293

WoosukKwon authored and Robert Shaw committed
[Hardware][TPU] Support parallel sampling & Swapping (vllm-project#5855)
1 parent 5095252 · commit 1653293

File tree

3 files changed: +147 -56 lines changed


vllm/attention/backends/pallas.py

Lines changed: 22 additions & 8 deletions
@@ -28,21 +28,35 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         return (num_kv_heads, num_blocks, block_size, head_size)
 
+    @torch.compile(backend="openxla")
     @staticmethod
     def swap_blocks(
-        src_kv_cache: torch.Tensor,
-        dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
+        src_to_dst: Tuple[torch.Tensor, torch.Tensor],
     ) -> None:
-        raise NotImplementedError("swap_blocks is not implemented.")
+        src_k_cache, src_v_cache = src_kv_cache
+        dst_k_cache, dst_v_cache = dst_kv_cache
+        torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
+        torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
 
+        device = dst_k_cache.device
+        src_indices, dst_indices = src_to_dst
+        dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
+        dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
+
+    @torch.compile(backend="openxla")
     @staticmethod
     def copy_blocks(
-        kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
     ) -> None:
-        # TODO(woosuk): Implement this.
-        raise NotImplementedError("copy_blocks is not implemented.")
+        src_indices, dst_indices = src_to_dists
+        for k_cache, v_cache in kv_caches:
+            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
+            k_cache[:, dst_indices] = k_cache[:, src_indices]
+            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
+            v_cache[:, dst_indices] = v_cache[:, src_indices]
 
 
 @dataclass
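The two methods above reduce swapping and copying to indexed assignments along the block dimension of the (num_kv_heads, num_blocks, block_size, head_size) cache. Below is a minimal standalone sketch of that pattern, assuming plain CPU torch tensors in place of the CPU and TPU caches and omitting the XLA buffer-donation and openxla compilation steps that only apply on a real device; the shapes and block mapping are made up for illustration.

import torch

# KV cache layout used by the Pallas backend:
# (num_kv_heads, num_blocks, block_size, head_size).
num_kv_heads, num_blocks, block_size, head_size = 2, 16, 8, 4
src_k = torch.randn(num_kv_heads, num_blocks, block_size, head_size)  # stand-in for the CPU k-cache
dst_k = torch.zeros(num_kv_heads, num_blocks, block_size, head_size)  # stand-in for the TPU k-cache

# Block-number pairs as the scheduler would hand them over (made up here),
# converted to index tensors the way the new worker helper does.
mapping = [(3, 0), (7, 1), (9, 2)]
src_indices = torch.tensor([s for s, _ in mapping], dtype=torch.int64)
dst_indices = torch.tensor([d for _, d in mapping], dtype=torch.int64)

# Swap: copy whole blocks along the block dimension onto the destination device.
dst_k[:, dst_indices] = src_k[:, src_indices].to(dst_k.device)
assert torch.equal(dst_k[:, 1], src_k[:, 7])

# Copy: the same pattern within one cache (what copy_blocks does per layer and per k/v tensor).
dst_k[:, torch.tensor([5])] = dst_k[:, torch.tensor([0])]
assert torch.equal(dst_k[:, 5], src_k[:, 3])

In the real code, marking the destination tensors as buffer donors lets XLA reuse their memory for the updated cache rather than materializing a second copy of each cache tensor.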

vllm/worker/tpu_model_runner.py

Lines changed: 50 additions & 26 deletions
@@ -22,6 +22,9 @@
 _PAD_SLOT_ID = 0  # FIXME(woosuk)
 # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
 _ENABLE_TOP_P = False
+# FIXME(woosuk): A temporary hack to support `n > 1`.
+# This can significantly affect the performance if too large.
+_MAX_NUM_SAMPLES = 128
 
 
 class TPUModelRunner:
@@ -143,8 +146,9 @@ def _dummy_run(
         p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
 
         # Dummy run.
+        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, t, p)
+                   input_lens, t, p, num_samples)
 
     def warmup_model(
         self,
@@ -268,14 +272,11 @@ def _prepare_decode(
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
-        num_seq_groups = len(seq_group_metadata_list)
-        batch_size = _get_padded_batch_size(num_seq_groups)
 
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+        batch_idx = 0
+        for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
-
             seq_ids = list(seq_group_metadata.seq_data.keys())
-
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -288,14 +289,16 @@ def _prepare_decode(
 
             assert seq_group_metadata.block_tables is not None
             block_table = seq_group_metadata.block_tables[seq_id]
-            self.block_tables[i, :len(block_table)] = block_table
+            self.block_tables[batch_idx, :len(block_table)] = block_table
+            batch_idx += 1
 
             block_number = block_table[position // self.block_size]
             block_offset = position % self.block_size
             slot = block_number * self.block_size + block_offset
             slot_mapping.append([slot])
 
-        num_paddings = batch_size - num_seq_groups
+        batch_size = _get_padded_batch_size(batch_idx)
+        num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
         input_positions = input_positions + [[0]] * num_paddings
         slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
@@ -333,14 +336,13 @@ def _prepare_sample(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         padded_batch_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
+        best_of = []
        for seq_group_metadata in seq_group_metadata_list:
-            assert seq_group_metadata.sampling_params is not None
            sampling_params = seq_group_metadata.sampling_params
-
            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
            # low temperature. This is not accurate.
            t.append(sampling_params.temperature
@@ -354,10 +356,11 @@ def _prepare_sample(
                 raise NotImplementedError(
                     "Top-k sampling is currently disabled for the TPU backend "
                     "due to performance issues.")
-            if sampling_params.best_of > 1:
+            if sampling_params.best_of > _MAX_NUM_SAMPLES:
                 raise NotImplementedError(
-                    "best_of > 1 is not currently supported by the TPU "
+                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                     "backend.")
+            best_of.append(sampling_params.best_of)
             if sampling_params.use_beam_search:
                 raise NotImplementedError(
                     "Beam search is not supported by the TPU backend.")
@@ -369,13 +372,19 @@ def _prepare_sample(
                     "prompt_logprobs is not currently supported by the TPU "
                     "backend.")
 
-        num_paddings = padded_batch_size - len(seq_group_metadata_list)
+            # Repeat the sampling params if the seq group has multiple seqs.
+            num_seqs = len(seq_group_metadata.seq_data)
+            t += [t[-1]] * (num_seqs - 1)
+            p += [p[-1]] * (num_seqs - 1)
+            best_of += [best_of[-1]] * (num_seqs - 1)
+
+        num_paddings = padded_batch_size - len(t)
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
 
         t = torch.tensor(t, dtype=torch.float32, device=self.device)
         p = torch.tensor(p, dtype=torch.float32, device=self.device)
-        return t, p
+        return t, p, best_of
 
     def _execute_model(
         self,
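Read together, the _prepare_sample changes keep t, p, and best_of aligned with the flattened batch: each sequence in a seq group gets a copy of its group's sampling params, and t and p are then padded up to the padded batch size. A small sketch with made-up groups and sizes follows; the (temperature, top_p, best_of, num_seqs) tuples are illustrative stand-ins, not vLLM types.

import torch

# Hypothetical seq groups: (temperature, top_p, best_of, num_seqs).
groups = [(0.7, 1.0, 4, 1), (1.0, 1.0, 1, 2)]
padded_batch_size = 8

t, p, best_of = [], [], []
for temperature, top_p, n_best, num_seqs in groups:
    t.append(temperature)
    p.append(top_p)
    best_of.append(n_best)
    # Repeat the sampling params if the seq group has multiple seqs.
    t += [t[-1]] * (num_seqs - 1)
    p += [p[-1]] * (num_seqs - 1)
    best_of += [best_of[-1]] * (num_seqs - 1)

# Pad t and p so the tensors match the padded model inputs.
num_paddings = padded_batch_size - len(t)
t += [1.0] * num_paddings
p += [1.0] * num_paddings

t = torch.tensor(t, dtype=torch.float32)
p = torch.tensor(p, dtype=torch.float32)
print(t.shape, p.shape, best_of)  # torch.Size([8]) torch.Size([8]) [4, 1, 1]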
@@ -392,28 +401,41 @@ def _execute_model(
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
         padded_batch_size = inputs[0].shape[0]
-        t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
+        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
+                                             padded_batch_size)
+        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
 
         # Execute the model.
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:], t, p)
+                                    *inputs[2:], t, p, num_samples)
         # Retrieve the outputs to CPU.
         next_token_ids = next_token_ids.cpu().tolist()
 
         # NOTE(woosuk): Minimal code to construct the sampler outputs.
         # The TPU backend does not reuse the sampler, since the TPU backend
         # does not support the advanced sampling parameters such as logprobs.
-        i = 0
+        zero_logprob = Logprob(0.0)
+        batch_idx = 0
         sampler_outputs = []
         for seq_group_metadata in seq_group_metadata_list:
             seq_outputs = []
             seq_ids = list(seq_group_metadata.seq_data.keys())
-            for seq_id in seq_ids:
-                next_token_id = next_token_ids[i]
-                seq_outputs.append(
-                    SequenceOutput(seq_id, next_token_id,
-                                   {next_token_id: Logprob(0.0)}))
-                i += 1
+            if is_prompt:
+                assert len(seq_ids) == 1
+                seq_id = seq_ids[0]
+                for i in range(best_of[batch_idx]):
+                    next_token_id = next_token_ids[batch_idx][i]
+                    seq_outputs.append(
+                        SequenceOutput(seq_id, next_token_id,
+                                       {next_token_id: zero_logprob}))
+                batch_idx += 1
+            else:
+                for seq_id in seq_ids:
+                    next_token_id = next_token_ids[batch_idx][0]
+                    seq_outputs.append(
+                        SequenceOutput(seq_id, next_token_id,
+                                       {next_token_id: zero_logprob}))
+                    batch_idx += 1
             sampler_outputs.append(
                 CompletionSequenceGroupOutput(seq_outputs, None))
         return sampler_outputs
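The output construction above consumes next_token_ids as a nested list with one row per batch entry and num_samples columns. For prompts, a seq group owns a single row and keeps its first best_of entries; for decodes, every sequence owns its own row and only column 0 is used. Here is a stand-alone sketch of the prompt case, with plain tuples standing in for vLLM's SequenceOutput objects and made-up row contents.

# [padded_batch_size][num_samples] token ids after .cpu().tolist(); values are invented.
next_token_ids = [
    [101, 102, 103, 104],  # row for prompt group 0
    [201, 202, 203, 204],  # row for prompt group 1
]
best_of = [3, 1]

batch_idx = 0
sampler_outputs = []
for seq_id in ["seq-a", "seq-b"]:  # one sequence per prompt group
    samples = [(seq_id, next_token_ids[batch_idx][i])
               for i in range(best_of[batch_idx])]
    sampler_outputs.append(samples)
    batch_idx += 1  # a prompt group occupies exactly one batch row

print(sampler_outputs)
# [[('seq-a', 101), ('seq-a', 102), ('seq-a', 103)], [('seq-b', 201)]]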
@@ -458,6 +480,7 @@ def forward(
         input_lens: torch.Tensor,
         t: torch.Tensor,
         p: torch.Tensor,
+        num_samples: int,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.
 
@@ -520,8 +543,9 @@ def forward(
         if _ENABLE_TOP_P:
             logits = _apply_top_p(logits, p.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        # FIXME(woosuk): best_of > 1 is not supported.
-        next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
+        next_token_ids = torch.multinomial(probs,
+                                           num_samples,
+                                           replacement=True)
         return next_token_ids
 
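The forward() change is what produces the multiple candidates in one shot: a single batched multinomial draw with replacement yields a [batch_size, num_samples] tensor, with num_samples set to _MAX_NUM_SAMPLES for prompts and 1 for decodes, the same values the dummy warmup run compiles for. A minimal sketch with arbitrary batch and vocabulary sizes, and a temperature division that mirrors the surrounding sampling code as an assumption:

import torch

torch.manual_seed(0)
batch_size, vocab_size, num_samples = 2, 32, 4  # illustrative sizes

logits = torch.randn(batch_size, vocab_size)
t = torch.tensor([0.7, 1.0]).unsqueeze(dim=1)  # per-sequence temperatures
probs = torch.softmax(logits / t, dim=-1, dtype=torch.float32)

# One draw per column; replacement=True allows duplicate tokens, which is
# what makes best_of / n > 1 a single batched op instead of a Python loop.
next_token_ids = torch.multinomial(probs, num_samples, replacement=True)
print(next_token_ids.shape)  # torch.Size([2, 4])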

vllm/worker/tpu_worker.py

Lines changed: 75 additions & 22 deletions
@@ -1,5 +1,5 @@
 import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch_xla.core.xla_model as xm
@@ -117,19 +117,26 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
 
+        dtype_btyes = get_dtype_size(self.cache_dtype)
+        block_size = self.cache_config.block_size
+        block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
+                            head_size * num_kv_heads)
+
+        # Calculate the TPU KV cache size based on profiling.
         m = xm.get_memory_info(self.device)
         total_memory_size = m["bytes_limit"]
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
         profiled = m["bytes_used"]  # Weights + intermediate activations.
-        kv_cache_bytes = max(usable_memory_size - profiled, 0)
-        dtype_btyes = get_dtype_size(self.cache_dtype)
-        block_size = self.cache_config.block_size
-        num_tpu_blocks = (kv_cache_bytes //
-                          (dtype_btyes * block_size * num_layers * 2 *
-                           head_size * num_kv_heads))
+        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
-        return num_tpu_blocks, 0
+
+        # Calculate the CPU KV cache size based on the config.
+        num_cpu_blocks = (self.cache_config.swap_space_bytes //
+                          block_size_bytes)
+        num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
+        return num_tpu_blocks, num_cpu_blocks
 
     def initialize_cache(
         self,
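Both cache sizes now come from a single per-block byte count: key plus value (the factor of 2) for every layer, head, and token slot in a block. A worked example with assumed values only (fp16 cache, 16-token blocks, a 32-layer model, 8 GiB of swap space, 20 GiB of HBM left after profiling):

# Assumed values for illustration.
dtype_size = 2                      # bytes per element for a float16 KV cache
block_size = 16                     # tokens per block
num_layers = 32
num_kv_heads = 8
head_size = 128
swap_space_bytes = 8 * 1024**3      # cache_config.swap_space_bytes
tpu_kv_cache_bytes = 20 * 1024**3   # usable HBM left after profiling

# Bytes per block: key and value (hence the factor of 2) for every layer.
block_size_bytes = dtype_size * block_size * num_layers * 2 * head_size * num_kv_heads
print(block_size_bytes)             # 2097152 bytes = 2 MiB per block

num_tpu_blocks = (tpu_kv_cache_bytes // block_size_bytes // 8) * 8  # round down to a multiple of 8
num_cpu_blocks = (swap_space_bytes // block_size_bytes // 8) * 8
print(num_tpu_blocks, num_cpu_blocks)  # 10240 4096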
@@ -145,15 +152,19 @@ def initialize_cache(
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
+        self.cpu_cache = []
         self.tpu_cache = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         for _ in range(num_layers):
-            key_cache = torch.zeros(tpu_cache_shape,
-                                    dtype=dtype,
-                                    device=self.device)
-            value_cache = torch.zeros_like(key_cache)
-            self.tpu_cache.append((key_cache, value_cache))
+            tpu_k_cache = torch.zeros(tpu_cache_shape,
+                                      dtype=dtype,
+                                      device=self.device)
+            tpu_v_cache = torch.zeros_like(tpu_k_cache)
+            self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
+            cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
+            cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
+            self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
         self._warmup_model()
 
     def _warmup_model(self) -> None:
@@ -187,26 +198,68 @@ def execute_model(
         if not self.is_driver_worker:
             self._execute_model_non_driver()
             return []
-
         assert execute_model_req is not None
-        # Currently, TPUWorker does not support swapping.
-        # TODO(woosuk): Support block copying.
-        assert len(execute_model_req.blocks_to_swap_in) == 0, (
-            "Swapping is not supported for the TPU backend.")
-        assert len(execute_model_req.blocks_to_swap_out) == 0, (
-            "Swapping is not supported for the TPU backend.")
-        assert len(execute_model_req.blocks_to_copy) == 0
-
+        # Issue cache operations.
+        self.cache_swap(
+            execute_model_req.blocks_to_swap_in,
+            execute_model_req.blocks_to_swap_out,
+            execute_model_req.blocks_to_copy,
+        )
+        # Run the model.
         seq_group_metadata_list = execute_model_req.seq_group_metadata_list
         assert len(seq_group_metadata_list) > 0
         output = self.model_runner.execute_model(seq_group_metadata_list,
                                                  self.tpu_cache)
         return [output]
 
+    def cache_swap(
+        self,
+        blocks_to_swap_in: List[Tuple[int, int]],
+        blocks_to_swap_out: List[Tuple[int, int]],
+        blocks_to_copy: List[Tuple[int, int]],
+    ) -> None:
+        attn_backend = self.model_runner.attn_backend
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+
+        if blocks_to_swap_in:
+            # Swap from CPU to TPU.
+            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
+                                          self.device)
+            for i in range(num_layers):
+                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
+                                         src_to_dst)
+        if blocks_to_swap_out:
+            # Swap from TPU to CPU.
+            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
+                                          "cpu")
+            for i in range(num_layers):
+                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
+                                         src_to_dst)
+        if blocks_to_copy:
+            src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
+                                          self.device)
+            attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
+
     def start_worker_execution_loop(self) -> None:
         while self._execute_model_non_driver():
             pass
 
     def _execute_model_non_driver(self) -> bool:
         self.model_runner.execute_model(None, self.tpu_cache)
         return True
+
+
+def _make_src_to_dst(
+    mapping: List[Tuple[int, int]],
+    src_device: Union[torch.device, str],
+    dst_device: Union[torch.device, str],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    src_indices = [i for i, _ in mapping]
+    dst_indices = [i for _, i in mapping]
+    src_indices = torch.tensor(src_indices,
+                               device=src_device,
+                               dtype=torch.int64)
+    dst_indices = torch.tensor(dst_indices,
+                               device=dst_device,
+                               dtype=torch.int64)
+    return src_indices, dst_indices
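One detail worth noting in the worker-side plumbing: _make_src_to_dst places the source indices on the source device and the destination indices on the destination device, so each side of the indexed copy in swap_blocks uses a device-local index tensor. A small sketch of an equivalent helper follows; make_src_to_dst is a local stand-in for the private function, and "cpu" is used for both ends so it runs without a TPU.

import torch
from typing import List, Tuple, Union

def make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Split the (src_block, dst_block) pairs into two int64 index tensors,
    # each living on the device whose cache it will index into.
    src_indices = torch.tensor([s for s, _ in mapping], device=src_device, dtype=torch.int64)
    dst_indices = torch.tensor([d for _, d in mapping], device=dst_device, dtype=torch.int64)
    return src_indices, dst_indices

# Swap-out example: device blocks 4 and 5 go to CPU blocks 0 and 1.
src_idx, dst_idx = make_src_to_dst([(4, 0), (5, 1)], "cpu", "cpu")
print(src_idx.tolist(), dst_idx.tolist())  # [4, 5] [0, 1]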
