17 changes: 15 additions & 2 deletions examples/pytorch/continuous_batching.py
@@ -16,6 +16,7 @@
import contextlib
import json
import os
import random
import time
from typing import Optional

@@ -29,7 +30,6 @@
from transformers.generation.continuous_batching.requests import logger


# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
SLIDING_WINDOW = 0
MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "meta-llama/Meta-Llama-3-8B"
FORCE_MAX_LENGTH = False # should be False unless you are debugging sliding window features
@@ -193,6 +193,8 @@ def batch_generate(
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")

parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
parser.add_argument("--add-prefix", action="store_true", help="Add a prefix to the samples")

parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
parser.add_argument("--log-level", type=str, default="INFO")
parser.add_argument("--output-file", type=str, default=None)
@@ -242,7 +244,18 @@ def batch_generate(
dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
dataset = dataset.select(range(args.samples))

simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
def random_prefix() -> str:
if not args.add_prefix:
return ""
prefixes = [
"Math and reasoning problems are very important to the world. This is a problem, and then you will find the answer.\n",
"We all know that reasoning can be taught by answering questions, often illustrated with examples. Here is one and its solution, hopefully you will enjoy it!\n",
"Reasoning is a very good metric of intelligence, hence it is regularly trained and tested in both children and AI models like LLMs. This test can look like a math or a logical problem, a riddle or a pattern detection task. For instance, this is one of those tests. You will find it and the associated solution after. Here it goes:\n",
] # fmt: skip
return random.choice(prefixes)

random.seed(0)
simple_batch_inputs = [tokenizer(random_prefix() + item["question"])["input_ids"] for item in dataset]

# Prepare generation config
generation_config = GenerationConfig(
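Aside (not part of the diff): the new --add-prefix flag draws one of a handful of fixed prefixes per sample, so many of the tokenized prompts begin with identical text, which is exactly what the prefix-sharing cache added below can reuse. A minimal standalone sketch of the idea, with made-up prefixes and questions standing in for the real ones:

# Illustration only: count how many prompts end up sharing each leading prefix.
import random
from collections import Counter

random.seed(0)  # same seed as the example script, for reproducibility
prefixes = ["prefix A\n", "prefix B\n", "prefix C\n"]  # stand-ins for the real prefixes
questions = [f"question {i}" for i in range(500)]  # stand-ins for the GSM8K questions

prompts = [random.choice(prefixes) + q for q in questions]
print(Counter(p.split("\n")[0] for p in prompts))  # roughly a third of the prompts fall under each prefix

With the flag, the example can be run as: python examples/pytorch/continuous_batching.py --add-prefix --samples 500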
78 changes: 61 additions & 17 deletions src/transformers/generation/continuous_batching/cache.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
from math import floor, gcd, sqrt
from typing import Optional

@@ -21,8 +20,8 @@
from ...configuration_utils import PreTrainedConfig
from ...generation.configuration_utils import GenerationConfig
from ...utils.metrics import attach_tracer, traced
from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import get_device_and_memory_breakdown, logger
from .cache_manager import BlockManager, CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
from .requests import RequestState, get_device_and_memory_breakdown, logger


def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]], list[str]]:
@@ -32,7 +31,7 @@ def group_layers_by_attn_type(config: PreTrainedConfig) -> tuple[list[list[int]]
- All groups have the same number of layers

For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
We would get two groups: [0, 3] and [1, 2], [4,5], [6,7].
We would get four groups: [0, 3], [1, 2], [4, 5] and [6, 7].
"""
# If the config has no layer_type attribute, it means all layers are the same attention type
layer_types = getattr(config, "layer_types", None)
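Aside (not part of the diff): the corrected docstring example can be reproduced with a small standalone sketch. This is an illustration only; it assumes the group size is the gcd of the per-type layer counts (plausible given the gcd import above, but an assumption, not the actual transformers implementation):

from functools import reduce
from math import gcd

def group_layers(layer_types: list[str]) -> tuple[list[list[int]], list[str]]:
    # Collect the layer indices per attention type, preserving encounter order.
    per_type: dict[str, list[int]] = {}
    for i, t in enumerate(layer_types):
        per_type.setdefault(t, []).append(i)
    # Assumed rule: every group has gcd(per-type counts) layers, so all groups are the same size.
    group_size = reduce(gcd, (len(v) for v in per_type.values()))
    groups, group_types = [], []
    for t, idx in per_type.items():
        for start in range(0, len(idx), group_size):
            groups.append(idx[start : start + group_size])
            group_types.append(t)
    return groups, group_types

layers = ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
print(group_layers(layers))  # ([[0, 3], [1, 2], [4, 5], [6, 7]], ['sliding', 'full', 'full', 'full'])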
@@ -173,10 +172,12 @@ def __init__(
page_size = self.head_dim * self.num_key_value_heads

if "flash" in self.config._attn_implementation:
num_attention_masks = 1 # only used to compute the default meme args
else:
num_attention_masks = 0 # only used to compute the default memory footprint args
elif "sliding_attention" in group_types:
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
num_attention_masks = 2 if "sliding_attention" in group_types else 1
num_attention_masks = 2
else:
num_attention_masks = 1

memory_handler = PagedAttentionMemoryHandler(
block_size=self.block_size,
@@ -216,7 +217,6 @@ def __init__(
logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")

# Block management data structures
self._free_blocks = deque(range(num_blocks))
self.group_cache_managers: list[CacheAllocator] = []
for i, group_type in enumerate(group_types):
if group_type == "full_attention":
@@ -227,13 +227,18 @@
raise ValueError(f"Invalid group type: {group_type}")
self.group_cache_managers.append(cm)

# We only use prefix sharing if the whole model has only full attention layers
self.use_prefix_sharing = (group_types == ["full_attention"])
self._block_manager = BlockManager(num_blocks, self.block_size, self.use_prefix_sharing)
self.blocks_to_complete: dict[str, int] = {}

@traced
def allocate_blocks(self, n_blocks: int, request_id: str) -> int:
def allocate_blocks(self, n_blocks: int, state: RequestState) -> int:
"""Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
max_allocated = 0
for cm in self.group_cache_managers:
allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
allocated = cm.allocate_blocks(n_blocks, state.request_id, self._block_manager)
if allocated is None:
return None
max_allocated = max(max_allocated, allocated)
@@ -244,11 +249,11 @@ def free_blocks(self, request_id: str) -> None:
"""Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
by the cache managers."""
for cm in self.group_cache_managers:
cm.free_blocks(request_id, self._free_blocks)
cm.free_blocks(request_id, self._block_manager)

def get_num_free_blocks(self) -> int:
"""Get the current number of unallocated blocks available for new requests."""
return len(self._free_blocks)
return self._block_manager.num_free_blocks

@traced
def extend_read_indices(
@@ -335,6 +340,38 @@ def update(
# Return the new KV values
return key_states_with_cache, value_states_with_cache

def search_prefix_match(self, request_id: str, prompt_ids: list[int]) -> int:
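"""Search the block hash table for the longest previously computed prefix of `prompt_ids`. Matching blocks are
re-referenced for `request_id`, and the number of prompt tokens covered by the matching blocks is returned."""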
current_hash = None
allocated_blocks = []
for b in range(len(prompt_ids) // self.block_size):
tokens = prompt_ids[b * self.block_size : (b + 1) * self.block_size]
current_hash = self._block_manager.compute_hash(current_hash, tokens)
block_id = self._block_manager._hash_to_id.get(current_hash)
if block_id is not None:
allocated_blocks.append(block_id)
self._block_manager.increase_ref_count(block_id)
else:
break
# If we found a matching prefix, we reference the blocks in the request
if allocated_blocks:
logger.debug(f"Found prefix match for request {request_id} with {len(allocated_blocks)} blocks")
cm = self.group_cache_managers[0]
cm.block_table[request_id] = allocated_blocks
return len(allocated_blocks) * self.block_size

def mark_blocks_as_completed(self, state: RequestState) -> None:
"""Marks the blocks that have been computed in the forward pass as such. If prefix sharing is off, this is a
no-op."""
num_completed_blocks = 0 if not self.use_prefix_sharing else self.blocks_to_complete.pop(state.request_id)
if num_completed_blocks == 0:
return None
cm = self.group_cache_managers[0] # if prefix sharing is on, there is only one group
self._block_manager.mark_blocks_as_computed(
num_completed_blocks=num_completed_blocks,
allocated_blocks=cm.block_table[state.request_id],
prompt_ids=(state.full_prompt_ids + state.static_outputs),
)


# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
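Aside (not part of the diff): `search_prefix_match` above relies on chained block hashes, where each block's hash depends on the hash of the block before it, so a cached block can only be reused when the entire prefix up to that block is identical. A minimal sketch of the idea, with a stand-in `compute_hash` (the real one lives in `BlockManager` and may differ):

# Illustration only: chained hashing makes block reuse prefix-sensitive.
BLOCK_SIZE = 4

def compute_hash(parent_hash, tokens):
    # Stand-in for BlockManager.compute_hash, assumed to chain on the parent hash.
    return hash((parent_hash, tuple(tokens)))

def block_hashes(prompt_ids: list[int]) -> list[int]:
    hashes, parent = [], None
    for b in range(len(prompt_ids) // BLOCK_SIZE):
        parent = compute_hash(parent, prompt_ids[b * BLOCK_SIZE : (b + 1) * BLOCK_SIZE])
        hashes.append(parent)
    return hashes

a = block_hashes([1, 2, 3, 4, 5, 6, 7, 8])
b = block_hashes([1, 2, 3, 4, 9, 9, 9, 9])
assert a[0] == b[0] and a[1] != b[1]  # first block shared, second diverges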
@@ -469,6 +506,8 @@ def compute_num_blocks_and_max_batch_tokens(
2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
])

If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial.
"""
cache_memory = self.get_available_memory(max_memory_percent)
logger.info(f"Cache memory: {cache_memory}")
@@ -480,11 +519,16 @@
c = -cache_memory
logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")

# Compute discriminant and greatest solution
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
# If num_attention_masks is 0, the equation simplifies to a 1st degree polynomial
if self.num_attention_masks == 0:
greatest_solution = -c / b
# Otherwise, we solve the quadratic equation
else:
discriminant = b**2 - 4 * a * c
if discriminant < 0:
raise ValueError(f"Discriminant is negative: {discriminant = }")
greatest_solution = (-b + sqrt(discriminant)) / (2 * a)

if greatest_solution < 0:
raise ValueError(f"Greatest solution is negative: {greatest_solution = }")
