[CB] Support the num_return_sequences argument #42921
base: main
Conversation
```python
            prompt_ids=(state.initial_tokens + state.generated_tokens),
        )

    def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
```
Very interesting. I would assume we want to delay requests that are getting forked to the next batch, to do this async (I might be wrong).
On the contrary, to maximize prefix sharing you want to schedule those requests ASAP. But there might be something to the idea that we can do much of the CPU side of forking asynchronously. The issue is that there will always be a copy of the cache, hence the GPU intervenes, but maybe it can be done in a side stream.
I think the best compromise is to add the feature now and, later, when we get to CPU asynchronicity, add a FORKING status to let the scheduler know those requests must not be scheduled until the cache has been copied.
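For illustration, a minimal sketch of what such a status could look like; the `RequestStatus.FORKING` value and the `schedulable` helper are hypothetical names, not part of this PR:

```python
# Hypothetical sketch: a FORKING status keeps a request out of the schedule
# until its cache blocks have been copied (names are illustrative only).
from enum import Enum, auto


class RequestStatus(Enum):
    PENDING = auto()
    RUNNING = auto()
    FORKING = auto()   # cache copy still in flight, do not schedule yet
    FINISHED = auto()


def schedulable(state) -> bool:
    """A request can be scheduled only once any pending fork copy has landed."""
    return state.status is not RequestStatus.FORKING
```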
```python
        source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
        forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
```
This is allocating memory and might need a sync (building a tensor from a Python list -> CPU/GPU sync); we want to avoid that.
I tried playing around with this and was surprised that this is the fastest alternative, which makes no sense to me. Will leave a TODO to deep-dive later.
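One alternative worth benchmarking (a sketch, not what the PR does, and assuming a CUDA device): build the index tensors on pinned host memory and move them with `non_blocking=True`, so the host-to-device copy does not force a synchronization:

```python
import torch


def to_device_indices(blocks: list[int], device: torch.device) -> torch.Tensor:
    # Pinned host memory lets the H2D copy run asynchronously with respect to the CPU.
    cpu_idx = torch.tensor(blocks, dtype=torch.int32, pin_memory=True)
    return cpu_idx.to(device, non_blocking=True)
```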
```python
        key_cache[forked_blocks] = key_cache[source_blocks]
        value_cache[forked_blocks] = value_cache[source_blocks]
        # FIXME: should be one copy for all CMs with only the changing blocks
        # FIXME: even once per fork batch
```
copy should be async as well (async=True)
cf. above
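A sketch of the side-stream idea discussed above (assuming CUDA; `copy_cache_async` and its event handling are illustrative, not this PR's code):

```python
import torch


def copy_cache_async(key_cache, value_cache, source_idx, forked_idx, copy_stream):
    """Sketch: launch the block copy on a side stream so decoding on the main
    stream is not blocked; the caller waits on the returned event before
    scheduling the forked request."""
    copy_stream.wait_stream(torch.cuda.current_stream())  # order after producer writes
    with torch.cuda.stream(copy_stream):
        key_cache[forked_idx] = key_cache[source_idx]
        value_cache[forked_idx] = value_cache[source_idx]
    done = torch.cuda.Event()
    done.record(copy_stream)
    return done
```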
```python
            source_blocks, forked_blocks = cm.fork_blocks(state.request_id, new_state.request_id, self._block_manager)
            self.copy_cache(source_blocks, forked_blocks)
```
Same here, we should "schedule" the copy. Remember we are in Python and the GIL is killing us.
| """Fork a given list of (source_blocks) into a new list of forked_blocks. If the blocks are (shareable), we | ||
| reference the existing blocks when they are complete. Otherwise, we allocate new blocks if possible. The | ||
| (group_id) of the layer group the blocks belong to is also needed.""" |
need do that shows in / out
Not sure what you mean by this, sorry!
```python
        while self.scheduler._requests_to_fork:
            state = self.scheduler._requests_to_fork.pop()
            num_children = state.num_children
            state.num_children = 0
            for i in range(num_children):
                # FIXME: if fork can't be done, create a new pending request without forking
                new_request = self.cache.fork_request(state, f"{state.request_id}__child#{i}")
                self.scheduler.active_requests[new_request.request_id] = new_request
```
Same here, we should make that async IMO (new status "FORKING" -> wait until forked? IDK, but we need to bench a tad bit).
This has changed to be done in batch, which, without asynchronous mode, is the best we can do on the CPU side.
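For reference, the batched variant could look roughly like the sketch below: the (source, forked) block indices from all forks in the batch are gathered first, so each cache tensor is touched by a single indexed copy (names are illustrative, not this PR's code):

```python
import torch


def copy_forked_blocks(key_cache, value_cache, pairs, device):
    """Sketch: `pairs` holds (source_block, forked_block) indices collected over
    the whole fork batch, so only one indexed copy per cache tensor is issued."""
    if not pairs:
        return
    source, forked = zip(*pairs)
    source_idx = torch.tensor(source, dtype=torch.int32, device=device)
    forked_idx = torch.tensor(forked, dtype=torch.int32, device=device)
    key_cache[forked_idx] = key_cache[source_idx]
    value_cache[forked_idx] = value_cache[source_idx]
```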
```python
        # Generation config related arguments
        generation_config = model.generation_config if generation_config is None else generation_config
        self.generation_config = generation_config
        self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
```
another PR for this maybe?
For the log probs? Sure, I will add it to the todo list
```python
        manager_cm = self.continuous_batching_context_manager(
            generation_config=generation_config,
            num_q_cuda_graphs=num_q_padding_intervals,
            num_kv_cuda_graphs=num_kv_padding_intervals,
            allow_block_sharing=allow_block_sharing,
            block=True,
            timeout=5,
        )
        logging_cm = logging_redirect_tqdm([logger])
        pbar_cm = tqdm(
            total=num_requests,
            disable=(not progress_bar),
            desc=f"Solving {num_requests} requests",
            unit="request",
        )
```
you can create a get cm func?
not sure I understand this -- what would it return? It seems self.continuous_batching_context_manager(...) is the get cm
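If the suggestion is about bundling the three context managers, one possible reading is a small helper built on `contextlib.ExitStack`; this is only a sketch, and `enter_generation_cms` is a hypothetical name, not this PR's code:

```python
from contextlib import ExitStack


def enter_generation_cms(manager_cm, logging_cm, pbar_cm):
    """Hypothetical helper: enter the manager, logging-redirect and progress-bar
    context managers together and return one stack that closes them all."""
    stack = ExitStack()
    manager = stack.enter_context(manager_cm)
    stack.enter_context(logging_cm)
    pbar = stack.enter_context(pbar_cm)
    return stack, manager, pbar
```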
```python
    def fork(self, new_request_id: str) -> "RequestState":
        """Fork the request into a new request with the same state except for request_id, created_time and lifespan."""
        return RequestState(
```
there has to be a better way to do this no? deepcopy?
I will macro-benchmark the two, and if deepcopy is on par or better we will move to that; agreed, it's better looking.
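For reference, a rough micro-benchmark comparing the two approaches could look like this (a sketch only; it assumes `state.fork(...)` as in this PR and that `request_id` is a plain attribute):

```python
import copy
import timeit


def bench_fork(state):
    """Rough comparison: explicit fork() vs. deepcopy-then-patch."""
    manual = timeit.timeit(lambda: state.fork("child"), number=10_000)

    def via_deepcopy():
        new = copy.deepcopy(state)
        new.request_id = "child"
        return new

    deep = timeit.timeit(via_deepcopy, number=10_000)
    print(f"manual fork: {manual:.3f}s  deepcopy: {deep:.3f}s")
```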
Summary
This draft PR relies on #42877 being merged. It adds the option to fork requests during continuous batching, which duplicates a request and reuses the existing cache as much as possible. This is then leveraged to make the `num_return_sequences` argument available in CB. This PR enables parallel decoding, which will be useful for RL workflows.
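Once merged, usage could look roughly like the sketch below; `num_return_sequences` is the standard `GenerationConfig` field, while `generate_batch`, `model` and `batched_token_ids` stand in for the continuous-batching entry point and its inputs and are assumptions here:

```python
from transformers import GenerationConfig

# With this PR, each request would be forked into num_return_sequences children
# during continuous batching, sharing as many cache blocks as possible.
generation_config = GenerationConfig(
    max_new_tokens=128,
    do_sample=True,
    num_return_sequences=4,
)

# `generate_batch` is assumed to be the continuous-batching entry point.
outputs = model.generate_batch(inputs=batched_token_ids, generation_config=generation_config)
```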
Performance
Draft status; no performance table yet.
Tests
Draft status; the tests have not been run yet.
Sanity check
Draft status; the command has not been run yet, but generation looks good right now.