
Commit 209a147

youkaichao authored and Robert Shaw committed
[core][misc] remove logical block (vllm-project#5882)
1 parent c1d4964 commit 209a147

3 files changed: 16 additions and 120 deletions


vllm/block.py

Lines changed: 1 addition & 81 deletions
@@ -1,90 +1,10 @@
 """Token blocks."""
-import weakref
-from collections import defaultdict
-from typing import Dict, List
+from typing import List
 
 from vllm.utils import Device
 
-_BLANK_TOKEN_ID = -1
-
 DEFAULT_LAST_ACCESSED_TIME = -1
 
-TokensBlock = List[int]
-
-
-class BlockPool:
-    """A pool of logical blocks.
-    When requests come, we create a lot of logical blocks;
-    when requests are done, we destroy a lot of logical blocks.
-    It turns out that creating and destroying logical blocks can be expensive,
-    especially for the `token_ids` field, which is a list of integers.
-    To avoid this overhead, we use a pool to manage the logical blocks.
-    When an old request is done and a new request comes, we can reuse the
-    logical blocks from the old request to feed the new request.
-    """
-
-    def __init__(self) -> None:
-        # block size to list of token blocks
-        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)
-
-    def alloc_block(self, block_size: int) -> TokensBlock:
-        if block_size in self.pool and self.pool[block_size]:
-            return self.pool[block_size].pop()
-        return [_BLANK_TOKEN_ID] * block_size
-
-    def del_block(self, block: TokensBlock) -> None:
-        self.pool[len(block)].append(block)
-
-
-_BLOCK_POOL = BlockPool()
-
-
-class LogicalTokenBlock:
-    """A block that stores a contiguous chunk of tokens from left to right.
-
-    Logical blocks are used to represent the states of the corresponding
-    physical blocks in the KV cache.
-    """
-
-    def __init__(
-        self,
-        block_number: int,
-        block_size: int,
-    ) -> None:
-        self.block_number = block_number
-        self.block_size = block_size
-
-        self.token_ids = _BLOCK_POOL.alloc_block(block_size)
-        # this finalizer is used to return the block to the pool when the object is deleted # noqa
-        # NOTE: don't use __del__ because it cannot guarantee the order of finalization, # noqa
-        # i.e. `self.token_ids` may be deleted before `self`, and we lose
-        # the opportunity to return the block to the pool
-        self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
-                                           self.token_ids)
-        self.num_tokens = 0
-
-    def is_empty(self) -> bool:
-        return self.num_tokens == 0
-
-    def get_num_empty_slots(self) -> int:
-        return self.block_size - self.num_tokens
-
-    def is_full(self) -> bool:
-        return self.num_tokens == self.block_size
-
-    def append_tokens(self, token_ids: List[int]) -> None:
-        assert len(token_ids) <= self.get_num_empty_slots()
-        curr_idx = self.num_tokens
-        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
-        self.num_tokens += len(token_ids)
-
-    def get_token_ids(self) -> List[int]:
-        return self.token_ids[:self.num_tokens]
-
-    def get_last_token_id(self) -> int:
-        assert self.num_tokens > 0
-        return self.token_ids[self.num_tokens - 1]
-
 
 class PhysicalTokenBlock:
     """Represents the state of a block in the KV cache."""

vllm/core/block_manager_v1.py

Lines changed: 9 additions & 10 deletions
@@ -262,8 +262,7 @@ def __init__(
         self.cross_block_tables: Dict[str, BlockTable] = {}
 
     def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
-        return 0 if seq is None \
-            else len(seq.logical_token_blocks)
+        return 0 if seq is None else seq.n_blocks
 
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
@@ -298,7 +297,7 @@ def _allocate_sequence(self, \
                            ref_count: int, \
                            is_encoder_decoder: bool = True) -> BlockTable:
         # Allocate new physical token blocks that will store the prompt tokens.
-        num_prompt_blocks = len(seq.logical_token_blocks)
+        num_prompt_blocks = seq.n_blocks
 
         block_table: BlockTable = []
         for logical_idx in range(num_prompt_blocks):
@@ -367,7 +366,7 @@ def _promote_last_block(
 
         # Compute a new hash for the block so that it can be shared by other
         # Sequences
-        new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        new_hash = seq.hash_of_block(seq.n_blocks - 1)
 
         # if new_hash is already in the cached table, then free last_block
         # and return the cached version
@@ -407,10 +406,10 @@ def _allocate_last_physical_block(
         if not self.enable_caching:
            return self.gpu_allocator.allocate()
         block_hash: Optional[int] = None
+        n_blocks = seq.n_blocks
         if (self._is_last_block_full(seq)):
-            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
-            num_hashed_tokens = seq.num_hashed_tokens_of_block(
-                len(seq.logical_token_blocks) - 1)
+            block_hash = seq.hash_of_block(n_blocks - 1)
+            num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1)
 
         # num_hashed_tokens is used to compute future hashes
         # (e.g. in the hashing function, it is used to ask the sequence for
@@ -429,12 +428,12 @@ def append_slots(
         num_lookahead_slots: int = 0,
     ) -> List[Tuple[int, int]]:
         """Allocate a physical slot for a new token."""
-        logical_blocks = seq.logical_token_blocks
+        n_blocks = seq.n_blocks
         block_table = self.block_tables[seq.seq_id]
         # If we need to allocate a new physical block
-        if len(block_table) < len(logical_blocks):
+        if len(block_table) < n_blocks:
             # Currently this code only supports adding one physical block
-            assert len(block_table) == len(logical_blocks) - 1
+            assert len(block_table) == n_blocks - 1
 
             if (self.block_sliding_window
                     and len(block_table) >= self.block_sliding_window):
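
With the logical-block list gone, append_slots above grows the physical block table whenever seq.n_blocks exceeds its current length. A rough standalone sketch of that bookkeeping, assuming only that n_blocks is ceil(sequence length / block size) as the new property defines it; the function and parameter names below are illustrative, not the vLLM API.

import math
from typing import List


def required_blocks(num_tokens: int, block_size: int) -> int:
    # Mirrors the new Sequence.n_blocks property: ceil(len / block_size).
    return math.ceil(num_tokens / block_size)


def maybe_grow_block_table(block_table: List[int], num_tokens: int,
                           block_size: int, next_free_block: int) -> None:
    """Append one physical block when the token count spills into a new block."""
    n_blocks = required_blocks(num_tokens, block_size)
    if len(block_table) < n_blocks:
        # Decoding appends one token at a time, so at most one block is missing.
        assert len(block_table) == n_blocks - 1
        block_table.append(next_free_block)


if __name__ == "__main__":
    table: List[int] = [0]              # one physical block already allocated
    maybe_grow_block_table(table, num_tokens=17, block_size=16,
                           next_free_block=7)
    print(table)                        # [0, 7]: token 17 starts a second block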

vllm/sequence.py

Lines changed: 6 additions & 29 deletions
@@ -1,13 +1,13 @@
 """Sequence and its related classes."""
 import copy
 import enum
+import math
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
 
-from vllm.block import LogicalTokenBlock
 from vllm.inputs import LLMInputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
@@ -236,9 +236,6 @@ def __init__(
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""
 
-        self.logical_token_blocks: List[LogicalTokenBlock] = []
-        # Initialize the logical token blocks with the prompt token ids.
-        self._append_tokens_to_blocks(self.prompt_token_ids)
         self.status = SequenceStatus.WAITING
         self.stop_reason: Union[int, str, None] = None
 
@@ -248,6 +245,10 @@ def __init__(
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
 
+    @property
+    def n_blocks(self) -> int:
+        return math.ceil(self.get_len() / self.block_size)
+
     @property
     def prompt(self) -> Optional[str]:
         return self.inputs.get("prompt")
@@ -287,36 +288,12 @@ def reset_state_for_recompute(self):
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()
 
-    def _append_logical_block(self) -> None:
-        block = LogicalTokenBlock(
-            block_number=len(self.logical_token_blocks),
-            block_size=self.block_size,
-        )
-        self.logical_token_blocks.append(block)
-
-    def _append_tokens_to_blocks(self, token_ids: List[int]) -> None:
-        cursor = 0
-        while cursor < len(token_ids):
-            if not self.logical_token_blocks:
-                self._append_logical_block()
-
-            last_block = self.logical_token_blocks[-1]
-            if last_block.is_full():
-                self._append_logical_block()
-                last_block = self.logical_token_blocks[-1]
-
-            num_empty_slots = last_block.get_num_empty_slots()
-            last_block.append_tokens(token_ids[cursor:cursor +
-                                               num_empty_slots])
-            cursor += num_empty_slots
-
     def append_token_id(
         self,
         token_id: int,
         logprobs: Dict[int, Logprob],
     ) -> None:
         assert token_id in logprobs
-        self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
         self.data.append_token_id(token_id, logprobs[token_id].logprob)
 
@@ -388,7 +365,7 @@ def is_prefill(self) -> bool:
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
-                f"num_blocks={len(self.logical_token_blocks)})")
+                f"num_blocks={self.n_blocks}, ")
 
 
 @dataclass
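
The replacement for all of the deleted machinery is the four-line n_blocks property: the block count is now derived from the sequence length instead of being tracked by appending tokens into LogicalTokenBlock objects. Below is a small self-contained check, with helper names made up for illustration and assuming only what the diff shows, that the ceil formula agrees with what the removed incremental loop would have counted.

import math


def n_blocks(seq_len: int, block_size: int) -> int:
    # Mirrors the new property: math.ceil(self.get_len() / self.block_size)
    return math.ceil(seq_len / block_size)


def n_blocks_incremental(seq_len: int, block_size: int) -> int:
    # Re-creates what the removed _append_tokens_to_blocks loop tracked:
    # open a new block whenever the last one is full.
    blocks = 0
    slots_left = 0
    for _ in range(seq_len):
        if slots_left == 0:
            blocks += 1
            slots_left = block_size
        slots_left -= 1
    return blocks


if __name__ == "__main__":
    for block_size in (1, 4, 16):
        for seq_len in range(0, 100):
            assert n_blocks(seq_len, block_size) == \
                n_blocks_incremental(seq_len, block_size)
    print("ceil(len / block_size) matches the incremental block count")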
