Skip to content

Conversation

@remi-or
Copy link
Collaborator

@remi-or remi-or commented Dec 17, 2025

Summary

This draft PR relies on #42877 to be merged. It adds the options to fork requests during continuous batching, which duplicates the request and uses as much as possible the existing cache. This is then leveraged to make the num_return_sequences argument available in CB.
This PR enables parallel decoding, which will be useful for RL workflows.

Performance

Draft status, no performance table

Tests

Draft status, have not ran the tests.

Sanity check

Draft status, have not ran the command, but generation looks good right now.

@remi-or remi-or self-assigned this Dec 17, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

prompt_ids=(state.initial_tokens + state.generated_tokens),
)

def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very interesting, I would assume we want to delay requests that are getting forked to the next batch to do this async (I might be wrong).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the opposite, to maximize prefix sharing, you want to schedule those request asap. But there might be something to the idea that we can do much of the cpu-side of forking in the async. The issue is that there will always be a copy of the cache, hence GPU intervenes, but maybe it can be done in a side stream.
I think the best compromise is to add the feature now and later, when we get to CPU asynchronous-ness, we can add the FORKING status to let the scheduler know we need those requests to not be scheduled -- until the cache has been copied.

Comment on lines +391 to +394
source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is allocating memeory + might need a sync (tensor of list -> sync cpu GPU) we wanna avoid that

Copy link
Collaborator Author

@remi-or remi-or Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried playing around with this and I was surprised this is the fastest alternative, which makes no sense to me. Will leave a TODO to deep dive later

key_cache[forked_blocks] = key_cache[source_blocks]
value_cache[forked_blocks] = value_cache[source_blocks]
# FIXME: should be one copy for al CMs with only the changing blocks
# FIXME: even once per fork batch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy should be async as well (async=True)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cf. above

Comment on lines 406 to 407
source_blocks, forked_blocks = cm.fork_blocks(state.request_id, new_state.request_id, self._block_manager)
self.copy_cache(source_blocks, forked_blocks)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here we should "schedule" the copy. Remember we are in Python and the gil is killing us

Comment on lines 127 to 129
"""Fork a given list of (source_blocks) into a new list of forked_blocks. If the blocks are (shareable), we
reference the existing blocks when they are complete. Otherwise, we allocate new blocks if possible. The
(group_id) of the layer group the blocks belong to is also needed."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need do that shows in / out

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you means by this sorry!

Comment on lines 604 to 619
while self.scheduler._requests_to_fork:
state = self.scheduler._requests_to_fork.pop()
num_children = state.num_children
state.num_children = 0
for i in range(num_children):
# FIXME: if fork cant be done, create a new pending request without forking
new_request = self.cache.fork_request(state, f"{state.request_id}__child#{i}")
self.scheduler.active_requests[new_request.request_id] = new_request
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, we should make that async IMO (new status "FORKING" -> wait until forked? IDK but we need to bench a tad bit

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this has changed to be done in batch, which without asynchronous mode is the best we can do for CPU side.

# Generation config related arguments
generation_config = model.generation_config if generation_config is None else generation_config
self.generation_config = generation_config
self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

another PR for this maybe?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the log probs? Sure, I will add it to the todo list

Comment on lines +1256 to +1281
manager_cm = self.continuous_batching_context_manager(
generation_config=generation_config,
num_q_cuda_graphs=num_q_padding_intervals,
num_kv_cuda_graphs=num_kv_padding_intervals,
allow_block_sharing=allow_block_sharing,
block=True,
timeout=5,
)
logging_cm = logging_redirect_tqdm([logger])
pbar_cm = tqdm(
total=num_requests,
disable=(not progress_bar),
desc=f"Solving {num_requests} requests",
unit="request",
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can create a get cm func?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure I understand this -- what would it return? It seems self.continuous_batching_context_manager(...) is the get cm


def fork(self, new_request_id: str) -> "RequestState":
"""Fork the request into a new request with the same state expect for request_id, created_time and lifespan."""
return RequestState(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there has to be a better way to do this no? deepcopy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will macro benchmark the two and it the deepcopy is on par / better we will move to that, agreed it's better looking

Base automatically changed from cb-block-sharing to main December 18, 2025 11:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants