18 changes: 17 additions & 1 deletion examples/pytorch/continuous_batching.py
@@ -177,6 +177,7 @@ def batch_generate(
parser.add_argument("--cuda-graph", "-cg", help="Use cuda graphs", type=str, default=None)
parser.add_argument("--compile", action="store_true", help="Compile the model using torch.compile")
parser.add_argument("--do-sample", action="store_true", help="Activate sampling")
parser.add_argument("--num-return-sequences", type=int, default=1, help="Number of return sequences")

# Benchmark parameters
parser.add_argument("--samples", type=int, default=500, help="Number of samples to generate")
@@ -190,6 +191,7 @@ def batch_generate(
parser.add_argument("--compare", action="store_true", help="Compare CB generation with classic generate")
parser.add_argument("--profile", type=str, default=None)
parser.add_argument("--metrics", action="store_true")
parser.add_argument("--seed", type=int, default=None, help="Random seed")

# Display parameters
parser.add_argument("--displayed", type=int, default=0, help="Number of samples to display")
@@ -210,6 +212,10 @@ def batch_generate(
else:
args.attn = "kernels-community/flash-attn3"

# Set seed
if args.seed is not None:
torch.manual_seed(args.seed)

# Create model
model_id = "google/gemma-2-2b-it" if args.sliding_window > 0 else "meta-llama/Llama-3.1-8B-Instruct"
has_system_role = args.sliding_window == 0
@@ -272,17 +278,27 @@ def batch_generate(
inputs = inputs if isinstance(inputs, list) else inputs["input_ids"]
batched_inputs.append(inputs)

# If num_return_sequences > 1, automatically enable do_sample with a warning
do_sample = args.do_sample
if args.num_return_sequences != 1 and not args.do_sample:
logger.warning(
f"num_return_sequences={args.num_return_sequences} > 1, automatically enabling do_sample=True. "
"Set --do-sample explicitly to suppress this warning."
)
do_sample = True

# Prepare generation config
generation_cfg = GenerationConfig(
max_new_tokens=args.max_new_tokens,
use_cuda_graph=use_cuda_graph,
eos_token_id=tokenizer.pad_token_id if args.force_max_length else tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=args.do_sample,
do_sample=do_sample,
temperature=0.8,
top_p=0.9,
num_blocks=args.num_blocks,
max_batch_tokens=args.max_batch_tokens,
num_return_sequences=args.num_return_sequences,
)

# Add a compile config if requested
25 changes: 24 additions & 1 deletion src/transformers/generation/continuous_batching/cache.py
@@ -210,7 +210,7 @@ def __init__(
self.key_cache: list[torch.Tensor] = []
self.value_cache: list[torch.Tensor] = []
# We add two extra tokens to the cache to handle padding and generally discard unwanted tokens
self.cache_shape = (num_blocks * self.block_size + 2, self.num_key_value_heads, self.head_dim)
self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim)
for _ in range(group_size):
new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
@@ -388,6 +388,29 @@ def mark_shareable_blocks_as_complete(self, state: RequestState) -> None:
prompt_ids=(state.initial_tokens + state.generated_tokens),
)

def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
Collaborator: Very interesting. I would assume we want to delay requests that are getting forked to the next batch so this can be done async (I might be wrong).

Collaborator (author): On the contrary, to maximize prefix sharing you want to schedule those requests ASAP. But there might be something to the idea that much of the CPU side of forking can be done asynchronously. The issue is that there will always be a copy of the cache, hence the GPU intervenes, but maybe it can be done in a side stream.
I think the best compromise is to add the feature now and, later, when we get to CPU asynchronicity, add a FORKING status to let the scheduler know those requests must not be scheduled until the cache has been copied.
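A minimal sketch of the side-stream idea discussed above (an assumption about a possible async variant, not part of this PR). `cache.copy_cache` refers to the method below; the helper name and the stream/event plumbing are hypothetical:

```python
import torch

# Hypothetical async variant of the fork copy: enqueue it on a side CUDA stream and let the
# scheduler hold forked requests (e.g. behind a future FORKING status) until the event fires.
_copy_stream = torch.cuda.Stream()

def copy_cache_on_side_stream(cache, source_blocks: list[int], forked_blocks: list[int]) -> torch.cuda.Event:
    done = torch.cuda.Event()
    # Make sure the side stream sees every write already enqueued on the main stream
    _copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_copy_stream):
        cache.copy_cache(source_blocks, forked_blocks)
    done.record(_copy_stream)
    return done  # the scheduler would check done.query() before scheduling the forked requests
```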

"""Copy the cache from the source blocks to the forked blocks."""
source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
Comment on lines +393 to +394
Collaborator: this is allocating memory and might need a sync (tensor from a list -> CPU/GPU sync); we want to avoid that.
Collaborator (author), @remi-or, Dec 18, 2025: I tried playing around with this and I was surprised this is the fastest alternative, which makes no sense to me. Will leave a TODO to deep dive later.
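One alternative worth benchmarking here (a sketch, assuming the cost is the implicit host-to-device sync; this is not the PR's code): build the index tensor in pinned host memory and issue the copy asynchronously.

```python
import torch

def indices_to_device(block_ids: list[int], device: torch.device) -> torch.Tensor:
    # Pinned CPU tensor, then an async H2D copy on the current stream, instead of letting
    # torch.tensor(block_ids, device="cuda") allocate and synchronize in one shot.
    host = torch.tensor(block_ids, dtype=torch.int32, pin_memory=True)
    return host.to(device, non_blocking=True)
```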

for key_cache, value_cache in zip(self.key_cache, self.value_cache):
key_cache = key_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
key_cache[forked_blocks] = key_cache[source_blocks]
value_cache[forked_blocks] = value_cache[source_blocks]
# FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape)
# This will allow for better .update and a single copy instead of one per cache tensor

def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]:
"""Fork the cache of a request (state) into the one of a list of requests with the given (dst_request_ids)."""
# These lists will be the accumulators for the source and destination blocks for the cache copy
source_blocks, destination_blocks = [], []
# Main fork loop
for cm in self.group_cache_managers:
src_blocks, dst_blocks = cm.fork_blocks(source_request_id, destination_request_ids, self._block_manager)
source_blocks.extend(src_blocks)
destination_blocks.extend(dst_blocks)
return source_blocks, destination_blocks
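A hedged usage sketch of the two methods above; the `cache` instance and request ids are illustrative, and the child-id naming mirrors `update_batch` further down in this diff:

```python
# Fork the blocks of "req-0" into two children, then perform the batched physical copy.
copy_src, copy_dst = cache.fork_request("req-0", ["req-0__child#0", "req-0__child#1"])
cache.copy_cache(copy_src, copy_dst)  # one copy for every block that could not be shared by reference
```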


# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
69 changes: 69 additions & 0 deletions src/transformers/generation/continuous_batching/cache_manager.py
@@ -123,6 +123,43 @@ def get_free_blocks(
# In both cases, we return the allocated block ids
return allocated_block_ids

def fork_blocks(
self, parent_blocks: list[int], num_forks: int, shareable: bool, group_id: int
) -> tuple[list[list[int]], list[int], list[int]]:
"""Fork a given list of (parent_blocks) as many times as (num_forks). If the blocks are (shareable), we use
reference on the blocks that are complete. Otherwise, we allocate new blocks and keep track of their indices to
later copy the physical cache."""
# First phase: reference all complete blocks
forked_by_reference = []
for block_id in parent_blocks:
block = self._id_to_block[block_id]
if shareable and block.is_complete:
forked_by_reference.append(block.id)
block.ref_count += num_forks
else:
break

# Early return if we have forked all blocks by reference
blocks_to_copy = len(parent_blocks) - len(forked_by_reference)
if blocks_to_copy == 0:
return [forked_by_reference[:] for _ in range(num_forks)], [], []

# From now on, each child will have its own list of blocks
forked_blocks_lists = []
copy_src = []
copy_dst = []

# Second phase: allocate new blocks if needed
parent_id = forked_by_reference[-1] if forked_by_reference else None
for _ in range(num_forks):
allocated_block_ids = self.get_free_blocks(blocks_to_copy, parent_id, shareable, group_id)
if allocated_block_ids is None:
return None, [], []
forked_blocks_lists.append(forked_by_reference + allocated_block_ids)
copy_src.extend(parent_blocks[-blocks_to_copy:])
copy_dst.extend(allocated_block_ids)
return forked_blocks_lists, copy_src, copy_dst
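A worked example with hypothetical block ids, tracing the two phases above:

```python
# parent_blocks = [7, 8, 9]; blocks 7 and 8 are complete and shareable, block 9 is not; num_forks = 2.
# Phase 1 forks 7 and 8 by reference (their ref_count grows by 2), leaving blocks_to_copy = 1.
# Phase 2 allocates one fresh block per child, say 11 and 12, so the method returns:
#   forked_blocks_lists = [[7, 8, 11], [7, 8, 12]]
#   copy_src = [9, 9]
#   copy_dst = [11, 12]
```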

def increase_ref_count(self, block_id: int) -> None:
"""Increases the reference count of a given (block_id)."""
block = self._id_to_block[block_id]
@@ -243,6 +280,38 @@ def get_write_indices(self, request_id: str, past_length: int, query_length: int
def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> tuple[str, int]:
"""Returns the attention type of the cache allocator and the key sequence length for the given request_id."""

def fork_blocks(
self, parent_request_id: str, children_request_ids: list[str], block_manager: BlockManager
) -> tuple[list[int], list[int]]:
"""Forks the cache blocks of a (parent_request_id) to a list of (children_request_ids). To manage the blocks,
the (block_manager) is used. When forking, the child's block are either shared with the parent, or they need to
be copied from the parent. Hence we return two lists of blocks that need to be copied: one for the source and
one for the destination."""

# Sanity checks
if parent_request_id not in self.block_table:
raise ValueError(f"No block table found for request {parent_request_id}")
# TODO: this check is good in the current context but it might be too much + slow things down
for children_request_id in children_request_ids:
if children_request_id in self.block_table:
raise ValueError(f"Block table already exists for request {children_request_id}")

# Actual forking
parent_blocks = self.block_table[parent_request_id]
list_forked_blocks, copy_src, copy_dst = block_manager.fork_blocks(
parent_blocks=parent_blocks,
num_forks=len(children_request_ids),
shareable=self.uses_block_sharing,
group_id=self._index,
)
if list_forked_blocks is None:
raise ValueError(f"Failed to fork blocks for request {parent_request_id}")

# Update the block table for all children requests
for children_request_id, forked_blocks in zip(children_request_ids, list_forked_blocks):
self.block_table[children_request_id] = forked_blocks
return copy_src, copy_dst


class FullAttentionCacheAllocator(CacheAllocator):
"""Cache manager for a group of full attention layers."""
101 changes: 69 additions & 32 deletions src/transformers/generation/continuous_batching/continuous_api.py
@@ -572,13 +572,19 @@ def _maybe_send_output(self, state: RequestState) -> None:
def update_batch(self) -> None:
"""Update request states based on generated tokens."""
new_tokens = self._get_new_tokens(len(self.requests_in_batch))
for i, state in enumerate(self.requests_in_batch):
current_logits_index = 0
for state in self.requests_in_batch:
# If the request has no remaining prompt ids, it means prefill has already ended or just finished
if len(state.remaining_prefill_tokens) == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING
token = new_tokens[i]
# If there are no generated tokens yet, it means prefill just ended
if state.generated_len() == 0:
self.metrics.record_ttft_metric(state.created_time, state.request_id)
state.status = RequestStatus.DECODING

token = new_tokens[current_logits_index]
state.tokens_to_process = [token]
current_logits_index += 1

# Update the request and stop if it is complete
is_finished = state.update_and_check_completion(token)
# We mark the completed blocks as such
@@ -594,6 +600,27 @@ def update_batch(self) -> None:
else:
raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")

# If some requests need to be forked, we do it now
copy_source, copy_destination = [], []
while self.scheduler._requests_to_fork:
# Get the number of children and reset it so it's not forked again
state = self.scheduler._requests_to_fork.pop()
num_children = state.num_children
state.num_children = 0
# Create the new request
new_request_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)]
new_requests = [state.fork(new_request_id) for new_request_id in new_request_ids]
# Fork the cache
copy_src, copy_dst = self.cache.fork_request(state.request_id, new_request_ids)
copy_source.extend(copy_src)
copy_destination.extend(copy_dst)
# Add the new requests to the scheduler
for new_request in new_requests:
self.scheduler.active_requests[new_request.request_id] = new_request
            # FIXME: if the fork can't be done, create a new pending request without forking instead of crashing everything

# The copy induced by the fork is done in one go
self.cache.copy_cache(copy_source, copy_destination)
if self.cache.get_num_free_blocks() == 0:
raise ValueError("No more free blocks")

Expand Down Expand Up @@ -760,29 +787,35 @@ def __init__(
num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
allow_block_sharing: (optional) Whether to allow block sharing if the model has some full attention layers
"""
# Reloade paged version if necessary
# Reload paged version of the attention implementation if necessary
if "paged|" not in model.config._attn_implementation:
model.set_attn_implementation(f"paged|{model.config._attn_implementation}")

# Internal arguments
self.model = model.eval()
generation_config = model.generation_config if generation_config is None else generation_config
self.generation_config = generation_config
self.manual_eviction = manual_eviction
self._allow_block_sharing = allow_block_sharing
self._use_prefix_sharing = allow_block_sharing # approximation until the cache is created

self.input_queue = queue.Queue(maxsize=max_queue_size)
self.output_queue = queue.Queue()
self.stop_event = threading.Event()
self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
self.batch_processor: ContinuousBatchProcessor | None = None
self._generation_thread = None
self._request_counter = 0
self._request_lock = threading.Lock()
self.model.generation_config.top_p = None

# Generation config related arguments
generation_config = model.generation_config if generation_config is None else generation_config
self.generation_config = generation_config
self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
Collaborator: another PR for this maybe?
Collaborator (author): For the log probs? Sure, I will add it to the todo list.

self.do_sample = getattr(generation_config, "do_sample", True)
self.logit_processor = self.model._get_logits_processor(generation_config)
self.profile = getattr(generation_config, "profile", False) # TODO: not supported yet
self.manual_eviction = manual_eviction
self.batch_processor: ContinuousBatchProcessor | None = None
self._allow_block_sharing = allow_block_sharing
self._use_prefix_sharing = allow_block_sharing # approximation until the cache is created
self.num_return_sequences = getattr(generation_config, "num_return_sequences", 1)

# self.model.generation_config.top_p = None NOTE: figure out why this was here

# Cuda graph behavior is determined below using either user-specified arguments or heuristics
self.use_cuda_graph = self._decide_use_cuda_graphs(
use_cuda_graph=getattr(generation_config, "use_cuda_graph", None),
num_q_padding_intervals=num_q_padding_intervals,
@@ -796,6 +829,7 @@ def __init__(
num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS
)

# Log probability generation is not supported yet (TODO)
if self.log_prob_generation:
raise NotImplementedError("log_prob_generation is not supported yet")

@@ -929,6 +963,7 @@ def add_request(
state = RequestState(
request_id=request_id,
initial_tokens=list(input_ids),
num_children=self.num_return_sequences - 1,
record_timestamps=record_timestamps,
tokens_to_process=list(input_ids),
max_new_tokens=max_new_tokens,
@@ -1226,24 +1261,26 @@ def generate_batch(

# Initialize manager with the batch inputs
results = {}
num_requests = len(inputs)
with (
self.continuous_batching_context_manager(
generation_config=generation_config,
num_q_cuda_graphs=num_q_padding_intervals,
num_kv_cuda_graphs=num_kv_padding_intervals,
allow_block_sharing=allow_block_sharing,
block=True,
timeout=5,
) as manager,
logging_redirect_tqdm([logger]),
tqdm(
total=num_requests,
disable=(not progress_bar),
desc=f"Solving {num_requests} requests",
unit="request",
) as pbar,
):
gen_cfg = self.generation_config if generation_config is None else generation_config
num_requests = len(inputs) * gen_cfg.num_return_sequences
# Prepare context managers for the main loop
manager_cm = self.continuous_batching_context_manager(
generation_config=generation_config,
num_q_cuda_graphs=num_q_padding_intervals,
num_kv_cuda_graphs=num_kv_padding_intervals,
allow_block_sharing=allow_block_sharing,
block=True,
timeout=5,
)
logging_cm = logging_redirect_tqdm([logger])
pbar_cm = tqdm(
total=num_requests,
disable=(not progress_bar),
desc=f"Solving {num_requests} requests",
unit="request",
)
Comment on lines +1267 to +1281
Collaborator: you can create a get cm func?
Collaborator (author): not sure I understand this -- what would it return? It seems self.continuous_batching_context_manager(...) is the get cm
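One possible reading of the suggestion (a hypothetical helper, not something this PR adds): bundle the three context managers behind a single call so the main loop only has one `with` statement. The parameter and attribute names mirror the surrounding code.

```python
from contextlib import contextmanager

from tqdm import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

@contextmanager
def _generation_contexts(self, generation_config, num_q, num_kv, allow_block_sharing, num_requests, progress_bar):
    # Enter the batching manager, the tqdm logging redirect and the progress bar together,
    # then hand back the two objects the main loop actually uses.
    with (
        self.continuous_batching_context_manager(
            generation_config=generation_config,
            num_q_cuda_graphs=num_q,
            num_kv_cuda_graphs=num_kv,
            allow_block_sharing=allow_block_sharing,
            block=True,
            timeout=5,
        ) as manager,
        logging_redirect_tqdm([logger]),
        tqdm(total=num_requests, disable=not progress_bar, desc=f"Solving {num_requests} requests", unit="request") as pbar,
    ):
        yield manager, pbar
```

The main loop would then read `with self._generation_contexts(...) as (manager, pbar):`.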

# Main loop
with manager_cm as manager, logging_cm, pbar_cm as pbar:
try:
manager.add_requests(
inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps
Expand Down