[Hardware][TPU] Support parallel sampling & Swapping #5855


Merged
16 commits merged on Jun 26, 2024
30 changes: 22 additions & 8 deletions vllm/attention/backends/pallas.py
@@ -28,21 +28,35 @@ def get_kv_cache_shape(
) -> Tuple[int, ...]:
return (num_kv_heads, num_blocks, block_size, head_size)

@torch.compile(backend="openxla")
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
src_to_dst: Tuple[torch.Tensor, torch.Tensor],
) -> None:
raise NotImplementedError("swap_blocks is not implemented.")
src_k_cache, src_v_cache = src_kv_cache
dst_k_cache, dst_v_cache = dst_kv_cache
torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)

device = dst_k_cache.device
src_indices, dst_indices = src_to_dst
dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)

@torch.compile(backend="openxla")
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
) -> None:
# TODO(woosuk): Implement this.
raise NotImplementedError("copy_blocks is not implemented.")
src_indices, dst_indices = src_to_dists
for k_cache, v_cache in kv_caches:
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
k_cache[:, dst_indices] = k_cache[:, src_indices]
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]


@dataclass
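As an illustration of the indexing convention the new swap_blocks/copy_blocks rely on, here is a minimal standalone sketch with toy shapes; the shapes, block indices, and tensors are assumptions for illustration, and the dynamo_set_buffer_donor_ call is omitted because it only matters on the XLA device. The block axis is dim 1 of the (num_kv_heads, num_blocks, block_size, head_size) layout, which is why both caches are indexed as cache[:, indices].

import torch

# Toy KV-cache shapes following get_kv_cache_shape():
# (num_kv_heads, num_blocks, block_size, head_size); the block axis is dim 1.
num_kv_heads, num_blocks, block_size, head_size = 2, 8, 16, 64
shape = (num_kv_heads, num_blocks, block_size, head_size)

src_k_cache, src_v_cache = torch.randn(shape), torch.randn(shape)  # e.g. CPU cache
dst_k_cache, dst_v_cache = torch.zeros(shape), torch.zeros(shape)  # e.g. device cache

# src_to_dst is now a pair of index tensors instead of a Dict[int, int]:
# copy source blocks 0, 3, 5 into destination blocks 2, 4, 6.
src_indices = torch.tensor([0, 3, 5], dtype=torch.int64)
dst_indices = torch.tensor([2, 4, 6], dtype=torch.int64)

device = dst_k_cache.device
dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)

assert torch.equal(dst_k_cache[:, 2], src_k_cache[:, 0])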
76 changes: 50 additions & 26 deletions vllm/worker/tpu_model_runner.py
@@ -22,6 +22,9 @@
_PAD_SLOT_ID = 0 # FIXME(woosuk)
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128


class TPUModelRunner:
@@ -143,8 +146,9 @@ def _dummy_run(
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)

# Dummy run.
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
self.model(token_ids, position_ids, kv_caches, attn_metadata,
input_lens, t, p)
input_lens, t, p, num_samples)

def warmup_model(
self,
@@ -268,14 +272,11 @@ def _prepare_decode(
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
num_seq_groups = len(seq_group_metadata_list)
batch_size = _get_padded_batch_size(num_seq_groups)

for i, seq_group_metadata in enumerate(seq_group_metadata_list):
batch_idx = 0
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt

seq_ids = list(seq_group_metadata.seq_data.keys())

for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
@@ -288,14 +289,16 @@

assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
self.block_tables[i, :len(block_table)] = block_table
self.block_tables[batch_idx, :len(block_table)] = block_table
batch_idx += 1

block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])

num_paddings = batch_size - num_seq_groups
batch_size = _get_padded_batch_size(batch_idx)
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
@@ -333,14 +336,13 @@ def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
padded_batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
assert len(seq_group_metadata_list) > 0
t = []
p = []
best_of = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.sampling_params is not None
sampling_params = seq_group_metadata.sampling_params

# NOTE(woosuk): Here we mimic argmax sampling by applying a very
# low temperature. This is not accurate.
t.append(sampling_params.temperature
@@ -354,10 +356,11 @@
raise NotImplementedError(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues.")
if sampling_params.best_of > 1:
if sampling_params.best_of > _MAX_NUM_SAMPLES:
raise NotImplementedError(
"best_of > 1 is not currently supported by the TPU "
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
best_of.append(sampling_params.best_of)
if sampling_params.use_beam_search:
raise NotImplementedError(
"Beam search is not supported by the TPU backend.")
@@ -369,13 +372,19 @@
"prompt_logprobs is not currently supported by the TPU "
"backend.")

num_paddings = padded_batch_size - len(seq_group_metadata_list)
# Repeat the sampling params if the seq group has multiple seqs.
num_seqs = len(seq_group_metadata.seq_data)
t += [t[-1]] * (num_seqs - 1)
p += [p[-1]] * (num_seqs - 1)
best_of += [best_of[-1]] * (num_seqs - 1)

num_paddings = padded_batch_size - len(t)
t += [1.0] * num_paddings
p += [1.0] * num_paddings

t = torch.tensor(t, dtype=torch.float32, device=self.device)
p = torch.tensor(p, dtype=torch.float32, device=self.device)
return t, p
return t, p, best_of

def _execute_model(
self,
@@ -392,28 +401,41 @@
else:
inputs = self._prepare_decode(seq_group_metadata_list)
padded_batch_size = inputs[0].shape[0]
t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
t, p, best_of = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

# Execute the model.
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
*inputs[2:], t, p)
*inputs[2:], t, p, num_samples)
# Retrieve the outputs to CPU.
next_token_ids = next_token_ids.cpu().tolist()

# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support the advanced sampling parameters such as logprobs.
i = 0
zero_logprob = Logprob(0.0)
batch_idx = 0
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
seq_outputs = []
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
next_token_id = next_token_ids[i]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: Logprob(0.0)}))
i += 1
if is_prompt:
assert len(seq_ids) == 1
seq_id = seq_ids[0]
for i in range(best_of[batch_idx]):
next_token_id = next_token_ids[batch_idx][i]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
else:
for seq_id in seq_ids:
next_token_id = next_token_ids[batch_idx][0]
seq_outputs.append(
SequenceOutput(seq_id, next_token_id,
{next_token_id: zero_logprob}))
batch_idx += 1
sampler_outputs.append(
CompletionSequenceGroupOutput(seq_outputs, None))
return sampler_outputs
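To make the output fan-out above concrete, a minimal sketch with assumed toy values, using plain tuples in place of SequenceOutput/Logprob: for a prompt group, the single parent sequence yields best_of candidate outputs taken from its row of sampled ids, while a decode group would consume only column 0.

# next_token_ids as returned for a prompt step: one row per sequence group,
# _MAX_NUM_SAMPLES candidates per row (8 here for brevity).
next_token_ids = [[11, 42, 42, 7, 3, 3, 19, 5],
                  [99, 12, 12, 12, 4, 4, 8, 8]]
best_of = [4, 1]          # group 0 asked for n=4, group 1 for n=1

sampler_outputs = []
for batch_idx, row in enumerate(next_token_ids):
    # Prompt case: the single parent sequence fans out into best_of candidates.
    seq_outputs = [(f"seq-{batch_idx}", token_id)
                   for token_id in row[:best_of[batch_idx]]]
    sampler_outputs.append(seq_outputs)

print(sampler_outputs[0])  # 4 candidate continuations for group 0
print(sampler_outputs[1])  # 1 continuation for group 1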
@@ -458,6 +480,7 @@ def forward(
input_lens: torch.Tensor,
t: torch.Tensor,
p: torch.Tensor,
num_samples: int,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.

@@ -520,8 +543,9 @@ def forward(
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
# FIXME(woosuk): best_of > 1 is not supported.
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
next_token_ids = torch.multinomial(probs,
num_samples,
replacement=True)
return next_token_ids


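A minimal standalone sketch of the parallel-sampling change in forward(): all candidates for n/best_of are drawn in one torch.multinomial call with replacement during the prompt step, while decode draws a single sample per sequence. The batch size, vocabulary size, temperature, and the value 8 standing in for _MAX_NUM_SAMPLES are assumptions for illustration.

import torch

# Pretend logits for a padded batch of 4 sequences over a 32-token vocabulary.
logits = torch.randn(4, 32)
probs = torch.softmax(logits / 0.7, dim=-1)   # temperature applied as in forward()

# Prompt step: draw several candidates per sequence in a single call, so n > 1
# needs no extra forward passes (8 stands in for _MAX_NUM_SAMPLES here).
prompt_samples = torch.multinomial(probs, 8, replacement=True)
print(prompt_samples.shape)   # torch.Size([4, 8]); row i holds seq i's candidates

# Decode step: num_samples == 1, and only column 0 is consumed per sequence.
decode_samples = torch.multinomial(probs, 1, replacement=True)
print(decode_samples.shape)   # torch.Size([4, 1])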
97 changes: 75 additions & 22 deletions vllm/worker/tpu_worker.py
@@ -1,5 +1,5 @@
import os
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
import torch_xla.core.xla_model as xm
@@ -117,19 +117,26 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
# Synchronize before measuring the memory usage.
xm.wait_device_ops()

dtype_bytes = get_dtype_size(self.cache_dtype)
block_size = self.cache_config.block_size
block_size_bytes = (dtype_bytes * block_size * num_layers * 2 *
head_size * num_kv_heads)

# Calculate the TPU KV cache size based on profiling.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
usable_memory_size = int(total_memory_size *
self.cache_config.gpu_memory_utilization)
profiled = m["bytes_used"] # Weights + intermediate activations.
kv_cache_bytes = max(usable_memory_size - profiled, 0)
dtype_bytes = get_dtype_size(self.cache_dtype)
block_size = self.cache_config.block_size
num_tpu_blocks = (kv_cache_bytes //
(dtype_bytes * block_size * num_layers * 2 *
head_size * num_kv_heads))
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, 0

# Calculate the CPU KV cache size based on the config.
num_cpu_blocks = (self.cache_config.swap_space_bytes //
block_size_bytes)
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, num_cpu_blocks
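To make the sizing arithmetic above concrete, a back-of-the-envelope sketch; the model dimensions and the 4 GiB swap space are illustrative assumptions, not values taken from this PR.

# Illustrative numbers only (roughly a 7B model in bf16); not from this PR.
dtype_bytes = 2          # bf16
block_size = 16
num_layers = 32
num_kv_heads = 32
head_size = 128

# Bytes needed per KV-cache block across all layers (K and V, hence the 2x).
block_size_bytes = (dtype_bytes * block_size * num_layers * 2 *
                    head_size * num_kv_heads)
print(block_size_bytes)  # 8388608 bytes, i.e. 8 MiB per block

# The CPU-side swap space translates into CPU blocks the same way,
# rounded down to a multiple of 8 like the TPU block count.
swap_space_bytes = 4 * 1024**3       # assume 4 GiB of swap space
num_cpu_blocks = (swap_space_bytes // block_size_bytes) // 8 * 8
print(num_cpu_blocks)                # 512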

def initialize_cache(
self,
@@ -145,15 +152,19 @@ def initialize_cache(
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
head_size = self.model_config.get_head_size()

self.cpu_cache = []
self.tpu_cache = []
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
for _ in range(num_layers):
key_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
value_cache = torch.zeros_like(key_cache)
self.tpu_cache.append((key_cache, value_cache))
tpu_k_cache = torch.zeros(tpu_cache_shape,
dtype=dtype,
device=self.device)
tpu_v_cache = torch.zeros_like(tpu_k_cache)
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
self._warmup_model()

def _warmup_model(self) -> None:
@@ -187,26 +198,68 @@ def execute_model(
if not self.is_driver_worker:
self._execute_model_non_driver()
return []

assert execute_model_req is not None
# Currently, TPUWorker does not support swapping.
# TODO(woosuk): Support block copying.
assert len(execute_model_req.blocks_to_swap_in) == 0, (
"Swapping is not supported for the TPU backend.")
assert len(execute_model_req.blocks_to_swap_out) == 0, (
"Swapping is not supported for the TPU backend.")
assert len(execute_model_req.blocks_to_copy) == 0

# Issue cache operations.
self.cache_swap(
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy,
)
# Run the model.
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
assert len(seq_group_metadata_list) > 0
output = self.model_runner.execute_model(seq_group_metadata_list,
self.tpu_cache)
return [output]

def cache_swap(
self,
blocks_to_swap_in: List[Tuple[int, int]],
blocks_to_swap_out: List[Tuple[int, int]],
blocks_to_copy: List[Tuple[int, int]],
) -> None:
attn_backend = self.model_runner.attn_backend
num_layers = self.model_config.get_num_layers(self.parallel_config)

if blocks_to_swap_in:
# Swap from CPU to TPU.
src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
self.device)
for i in range(num_layers):
attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
src_to_dst)
if blocks_to_swap_out:
# Swap from TPU to CPU.
src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
"cpu")
for i in range(num_layers):
attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
src_to_dst)
if blocks_to_copy:
src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
self.device)
attn_backend.copy_blocks(self.tpu_cache, src_to_dst)

def start_worker_execution_loop(self) -> None:
while self._execute_model_non_driver():
pass

def _execute_model_non_driver(self) -> bool:
self.model_runner.execute_model(None, self.tpu_cache)
return True


def _make_src_to_dst(
mapping: List[Tuple[int, int]],
src_device: Union[torch.device, str],
dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
src_indices = [i for i, _ in mapping]
dst_indices = [i for _, i in mapping]
src_indices = torch.tensor(src_indices,
device=src_device,
dtype=torch.int64)
dst_indices = torch.tensor(dst_indices,
device=dst_device,
dtype=torch.int64)
return src_indices, dst_indices
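For reference, a standalone sketch of how scheduler-style (src_block, dst_block) pairs are split into the two index tensors that swap_blocks/copy_blocks now expect; split_block_mapping is a hypothetical stand-in that mirrors _make_src_to_dst but keeps both tensors on CPU, since this is only an illustration.

import torch
from typing import List, Tuple

def split_block_mapping(
    mapping: List[Tuple[int, int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # (src_block, dst_block) pairs become two parallel int64 index tensors.
    src_indices = torch.tensor([src for src, _ in mapping], dtype=torch.int64)
    dst_indices = torch.tensor([dst for _, dst in mapping], dtype=torch.int64)
    return src_indices, dst_indices

# e.g. blocks_to_swap_in from the scheduler: CPU block 0 -> TPU block 5, etc.
blocks_to_swap_in = [(0, 5), (1, 6), (2, 7)]
src, dst = split_block_mapping(blocks_to_swap_in)
print(src.tolist(), dst.tolist())   # [0, 1, 2] [5, 6, 7]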