[CB] Support the num_return_sequences argument
#42921
base: main
Changes from all commits: 199b9f4, fd686b7, 0926e37, cd4696b, 0ee16e5, 6963053, a7cc167
@@ -210,7 +210,7 @@ def __init__(
        self.key_cache: list[torch.Tensor] = []
        self.value_cache: list[torch.Tensor] = []
        # We add two extra tokens to the cache to handle padding and generally discard unwanted tokens
        self.cache_shape = (num_blocks * self.block_size + 2, self.num_key_value_heads, self.head_dim)
        self.cache_shape = ((num_blocks + 2) * self.block_size, self.num_key_value_heads, self.head_dim)
        for _ in range(group_size):
            new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
            new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
@@ -388,6 +388,29 @@ def mark_shareable_blocks_as_complete(self, state: RequestState) -> None:
            prompt_ids=(state.initial_tokens + state.generated_tokens),
        )

    def copy_cache(self, source_blocks: list[int], forked_blocks: list[int]) -> None:
        """Copy the cache from the source blocks to the forked blocks."""
        source_blocks = torch.tensor(source_blocks, device=self.device, dtype=torch.int32)
        forked_blocks = torch.tensor(forked_blocks, device=self.device, dtype=torch.int32)
Comment on lines +393 to +394

Collaborator: this is allocating memory + might need a sync (tensor of list -> sync CPU/GPU), we want to avoid that.

Collaborator (Author): I tried playing around with this and I was surprised this is the fastest alternative, which makes no sense to me. Will leave a TODO to deep dive later.
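One way to avoid that host-to-device sync, sketched here purely as a suggestion rather than anything this PR does, is to stage the block indices in pinned CPU memory and issue a non-blocking copy:

```python
import torch

def indices_to_device(block_ids: list[int], device: torch.device) -> torch.Tensor:
    # Sketch only: build the tensor in pinned host memory so the H2D copy can overlap
    # with other work; non_blocking only helps when the source tensor is pinned.
    cpu_ids = torch.tensor(block_ids, dtype=torch.int32, pin_memory=torch.cuda.is_available())
    return cpu_ids.to(device, non_blocking=True)
```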
        for key_cache, value_cache in zip(self.key_cache, self.value_cache):
            key_cache = key_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
            value_cache = value_cache.view(-1, self.block_size, self.num_key_value_heads, self.head_dim)
            key_cache[forked_blocks] = key_cache[source_blocks]
            value_cache[forked_blocks] = value_cache[source_blocks]
        # FIXME: consolidate the cache into a single tensor of shape (group_size, 2, *self.k_or_v_cache_shape)
        # This will allow for better .update and a single copy instead of one per cache tensor

    def fork_request(self, source_request_id: str, destination_request_ids: list[str]) -> tuple[list[int], list[int]]:
        """Fork the cache of a request (state) into the one of a list of requests with the given (dst_request_ids)."""
        # These lists will be the accumulators for the source and destination blocks for the cache copy
        source_blocks, destination_blocks = [], []
        # Main fork loop
        for cm in self.group_cache_managers:
            src_blocks, dst_blocks = cm.fork_blocks(source_request_id, destination_request_ids, self._block_manager)
            source_blocks.extend(src_blocks)
            destination_blocks.extend(dst_blocks)
        return source_blocks, destination_blocks


# TODO: rework computation with the groups and their sizes
class PagedAttentionMemoryHandler:
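For intuition, here is a self-contained toy version of the block-level copy that `copy_cache` performs (shapes and names below are illustrative, not the PR's actual classes): the flat cache is viewed as (num_blocks, block_size, num_kv_heads, head_dim) and whole blocks are duplicated with a single fancy-indexed assignment.

```python
import torch

num_blocks, block_size, num_kv_heads, head_dim = 8, 4, 2, 16
key_cache = torch.randn(num_blocks * block_size, num_kv_heads, head_dim)

# Fork: the child request reuses the content of blocks [0, 3] in its own blocks [5, 6].
source_blocks = torch.tensor([0, 3])
forked_blocks = torch.tensor([5, 6])

blocked = key_cache.view(num_blocks, block_size, num_kv_heads, head_dim)
blocked[forked_blocks] = blocked[source_blocks]  # one indexed copy per cache tensor

assert torch.equal(blocked[5], blocked[0]) and torch.equal(blocked[6], blocked[3])
```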
@@ -572,13 +572,19 @@ def _maybe_send_output(self, state: RequestState) -> None:
    def update_batch(self) -> None:
        """Update request states based on generated tokens."""
        new_tokens = self._get_new_tokens(len(self.requests_in_batch))
        for i, state in enumerate(self.requests_in_batch):
        current_logits_index = 0
        for state in self.requests_in_batch:
            # If the request has no remaining prompt ids, it means prefill has already ended or just finished
            if len(state.remaining_prefill_tokens) == 0:
                self.metrics.record_ttft_metric(state.created_time, state.request_id)
                state.status = RequestStatus.DECODING
                token = new_tokens[i]
                # If there are no generated tokens yet, it means prefill just ended
                if state.generated_len() == 0:
                    self.metrics.record_ttft_metric(state.created_time, state.request_id)
                    state.status = RequestStatus.DECODING

                token = new_tokens[current_logits_index]
                state.tokens_to_process = [token]
                current_logits_index += 1

                # Update the request and stop if it is complete
                is_finished = state.update_and_check_completion(token)
                # We mark the completed blocks as such
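As a rough intuition for why the loop switches from `enumerate` to an explicit `current_logits_index` (this is an interpretation, since the full loop is not shown here): the cursor decouples a request's position in the batch from its position in the freshly sampled tokens, so a request that does not consume a logit on a given step can be skipped without shifting everyone else's token. A toy sketch with a hypothetical `Request` type:

```python
from dataclasses import dataclass, field

@dataclass
class Request:
    request_id: str
    consumes_token: bool                      # e.g. still prefilling -> no sampled token yet
    tokens: list[int] = field(default_factory=list)

def distribute(new_tokens: list[int], batch: list[Request]) -> None:
    cursor = 0
    for req in batch:
        if not req.consumes_token:
            continue                          # position in `batch` != position in `new_tokens`
        req.tokens.append(new_tokens[cursor])
        cursor += 1

batch = [Request("a", True), Request("b", False), Request("c", True)]
distribute([11, 22], batch)
assert [r.tokens for r in batch] == [[11], [], [22]]
```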
@@ -594,6 +600,27 @@ def update_batch(self) -> None:
            else:
                raise ValueError(f"Request {state.request_id} is in an unexpected state: {state.status}")

        # If some requests need to be forked, we do it now
        copy_source, copy_destination = [], []
        while self.scheduler._requests_to_fork:
            # Get the number of children and reset it so it's not forked again
            state = self.scheduler._requests_to_fork.pop()
            num_children = state.num_children
            state.num_children = 0
            # Create the new request
            new_request_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)]
            new_requests = [state.fork(new_request_id) for new_request_id in new_request_ids]
            # Fork the cache
            copy_src, copy_dst = self.cache.fork_request(state.request_id, new_request_ids)
            copy_source.extend(copy_src)
            copy_destination.extend(copy_dst)
            # Add the new requests to the scheduler
            for new_request in new_requests:
                self.scheduler.active_requests[new_request.request_id] = new_request
            # FIXME: if fork cant be done, create a new pending request without forking instead of crashing everything

        # The copy induced by the fork is done in one go
        self.cache.copy_cache(copy_source, copy_destination)
        if self.cache.get_num_free_blocks() == 0:
            raise ValueError("No more free blocks")
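A stripped-down, runnable sketch of the same fork bookkeeping: `ToyState` and `toy_fork_blocks` below are hypothetical stand-ins for `RequestState` and the cache's `fork_request`; the point is that children are registered immediately while the block copies are only accumulated and executed once at the end.

```python
from dataclasses import dataclass, replace

@dataclass
class ToyState:
    request_id: str
    num_children: int = 0

def toy_fork_blocks(source_id: str, child_ids: list[str]) -> tuple[list[int], list[int]]:
    # Pretend the source owns block 0 and every child receives one fresh block.
    return [0] * len(child_ids), list(range(1, 1 + len(child_ids)))

to_fork = [ToyState("req0", num_children=2)]
active: dict[str, ToyState] = {}
copy_source: list[int] = []
copy_destination: list[int] = []

while to_fork:
    state = to_fork.pop()
    num_children, state.num_children = state.num_children, 0  # reset so it is not forked twice
    child_ids = [f"{state.request_id}__child#{i}" for i in range(num_children)]
    for child_id in child_ids:
        active[child_id] = replace(state, request_id=child_id)
    src, dst = toy_fork_blocks(state.request_id, child_ids)
    copy_source.extend(src)
    copy_destination.extend(dst)

# A single batched cache copy would happen here, e.g. cache.copy_cache(copy_source, copy_destination).
assert list(active) == ["req0__child#0", "req0__child#1"] and copy_destination == [1, 2]
```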
@@ -760,29 +787,35 @@ def __init__(
            num_kv_padding_intervals: (optional) Number of intervals used to pad the keys/values dimension
            allow_block_sharing: (optional) Whether to allow block sharing if the model has some full attention layers
        """
        # Reloade paged version if necessary
        # Reload paged version of the attention implementation if necessary
        if "paged|" not in model.config._attn_implementation:
            model.set_attn_implementation(f"paged|{model.config._attn_implementation}")

        # Internal arguments
        self.model = model.eval()
        generation_config = model.generation_config if generation_config is None else generation_config
        self.generation_config = generation_config
        self.manual_eviction = manual_eviction
        self._allow_block_sharing = allow_block_sharing
        self._use_prefix_sharing = allow_block_sharing  # approximation until the cache is created

        self.input_queue = queue.Queue(maxsize=max_queue_size)
        self.output_queue = queue.Queue()
        self.stop_event = threading.Event()
        self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
        self.batch_processor: ContinuousBatchProcessor | None = None
        self._generation_thread = None
        self._request_counter = 0
        self._request_lock = threading.Lock()
        self.model.generation_config.top_p = None

        # Generation config related arguments
        generation_config = model.generation_config if generation_config is None else generation_config
        self.generation_config = generation_config
        self.log_prob_generation = getattr(generation_config, "log_prob_generation", False)
Collaborator: another PR for this maybe?

Collaborator (Author): For the log probs? Sure, I will add it to the todo list.
        self.do_sample = getattr(generation_config, "do_sample", True)
        self.logit_processor = self.model._get_logits_processor(generation_config)
        self.profile = getattr(generation_config, "profile", False)  # TODO: not supported yet
        self.manual_eviction = manual_eviction
        self.batch_processor: ContinuousBatchProcessor | None = None
        self._allow_block_sharing = allow_block_sharing
        self._use_prefix_sharing = allow_block_sharing  # approximation until the cache is created
        self.num_return_sequences = getattr(generation_config, "num_return_sequences", 1)

        # self.model.generation_config.top_p = None NOTE: figure out why this was here

        # Cuda graph behavior is determined below using either user-specified arguments or heuristics
        self.use_cuda_graph = self._decide_use_cuda_graphs(
            use_cuda_graph=getattr(generation_config, "use_cuda_graph", None),
            num_q_padding_intervals=num_q_padding_intervals,

@@ -796,6 +829,7 @@ def __init__(
            num_kv_padding_intervals if num_kv_padding_intervals > 0 else NUM_KV_PADDING_INTERVALS
        )

        # Log probability generation is not supported yet (TODO)
        if self.log_prob_generation:
            raise NotImplementedError("log_prob_generation is not supported yet")
@@ -929,6 +963,7 @@ def add_request(
        state = RequestState(
            request_id=request_id,
            initial_tokens=list(input_ids),
            num_children=self.num_return_sequences - 1,
            record_timestamps=record_timestamps,
            tokens_to_process=list(input_ids),
            max_new_tokens=max_new_tokens,
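As a quick sanity check of the accounting implied by `num_children=self.num_return_sequences - 1` (the values below are made up): each prompt is admitted once as a parent and later forked into `num_return_sequences - 1` children, so the total number of sequences matches `len(inputs) * num_return_sequences`, which is exactly what the progress bar in `generate_batch` counts further down.

```python
num_prompts = 4
num_return_sequences = 3

parents = num_prompts                                # one RequestState per prompt
children = num_prompts * (num_return_sequences - 1)  # num_children forks per parent
assert parents + children == num_prompts * num_return_sequences
```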
@@ -1226,24 +1261,26 @@ def generate_batch(

        # Initialize manager with the batch inputs
        results = {}
        num_requests = len(inputs)
        with (
            self.continuous_batching_context_manager(
                generation_config=generation_config,
                num_q_cuda_graphs=num_q_padding_intervals,
                num_kv_cuda_graphs=num_kv_padding_intervals,
                allow_block_sharing=allow_block_sharing,
                block=True,
                timeout=5,
            ) as manager,
            logging_redirect_tqdm([logger]),
            tqdm(
                total=num_requests,
                disable=(not progress_bar),
                desc=f"Solving {num_requests} requests",
                unit="request",
            ) as pbar,
        ):
        gen_cfg = self.generation_config if generation_config is None else generation_config
        num_requests = len(inputs) * gen_cfg.num_return_sequences
        # Prepare context managers for the main loop
        manager_cm = self.continuous_batching_context_manager(
            generation_config=generation_config,
            num_q_cuda_graphs=num_q_padding_intervals,
            num_kv_cuda_graphs=num_kv_padding_intervals,
            allow_block_sharing=allow_block_sharing,
            block=True,
            timeout=5,
        )
        logging_cm = logging_redirect_tqdm([logger])
        pbar_cm = tqdm(
            total=num_requests,
            disable=(not progress_bar),
            desc=f"Solving {num_requests} requests",
            unit="request",
        )
Comment on lines +1267 to +1281

Collaborator: you can create a get cm func?

Collaborator (Author): not sure I understand this -- what would it return? It seems …
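For what it is worth, one possible reading of the "get cm func" suggestion (purely an interpretation, not something in the PR) is a helper that enters the three context managers through a single `contextlib.ExitStack` and returns the handles the loop needs:

```python
import contextlib

def combined_contexts(manager_cm, logging_cm, pbar_cm):
    """Hypothetical helper: enter all three context managers through one ExitStack
    so the call site keeps a single `with` statement."""
    stack = contextlib.ExitStack()
    manager = stack.enter_context(manager_cm)
    stack.enter_context(logging_cm)
    pbar = stack.enter_context(pbar_cm)
    return stack, manager, pbar

# Usage sketch:
# stack, manager, pbar = combined_contexts(manager_cm, logging_cm, pbar_cm)
# with stack:
#     ...
```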
        # Main loop
        with manager_cm as manager, logging_cm, pbar_cm as pbar:
            try:
                manager.add_requests(
                    inputs=inputs, max_new_tokens=kwargs.get("max_new_tokens"), record_timestamps=record_timestamps
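Assuming `generate_batch` keeps the signature shown above (the model name, tokenization, and the exact shape of the returned results are assumptions here; only `num_return_sequences` and `generate_batch` come from this diff), usage would look roughly like this:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="cuda")

prompts = ["The capital of France is", "Continuous batching works by"]
inputs = [tokenizer(p).input_ids for p in prompts]

# num_return_sequences > 1 makes the manager fork each request after prefill,
# so the batch above should resolve into len(inputs) * 3 = 6 finished sequences.
generation_config = GenerationConfig(do_sample=True, max_new_tokens=32, num_return_sequences=3)
results = model.generate_batch(inputs=inputs, generation_config=generation_config)
```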
Collaborator: very interesting, I would assume we want to delay requests that are getting forked to the next batch to do this async (I might be wrong).

Collaborator (Author): On the contrary, to maximize prefix sharing, you want to schedule those requests ASAP. But there might be something to the idea that we can do much of the CPU side of forking asynchronously. The issue is that there will always be a copy of the cache, hence the GPU intervenes, but maybe it can be done in a side stream.
I think the best compromise is to add the feature now and later, when we get to CPU asynchrony, we can add the FORKING status to let the scheduler know those requests must not be scheduled until the cache has been copied.
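A sketch of what that FORKING gate could look like (hypothetical, based only on the comment above; `RequestStatus` exists in the diff, but the `FORKING` member and this scheduler hook do not):

```python
from enum import Enum, auto

class RequestStatus(Enum):
    PENDING = auto()
    PREFILLING = auto()
    DECODING = auto()
    FORKING = auto()   # waiting for its cache blocks to be copied
    FINISHED = auto()

def schedulable(status: RequestStatus) -> bool:
    # The scheduler would simply skip requests whose fork copy has not landed yet.
    return status not in (RequestStatus.FORKING, RequestStatus.FINISHED)
```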