Commit
[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)

* Using flash decoding
  Conditional flashdecoding.
  Fix max_q.
  Working kvcache
  Working version with flash decoding.
  Make it work for mistral.
  Fix after rebase..
  Less intrusive.
  REvert changes in modeling.
  Speedup flashdecoding.
  HHachweew
  Hack to make other models work.
  Fixing non flash decoding llama path.
  Router logic knows about page size.
  Missing 2 models.
  Missing cohere.
  Fixing cohere flash decoding.
  Revamped all this architecture.
  Fix cohere.
  Fixing falcon.
  Enabling custom block size schedule.
  Update router/src/infer.rs
  Not sending preallocated output.
* Making it work on non flash decoding.
* Fix Cohere.
* Fix non decoding paths.
* Rebased.
* No need for cache_manager anymore.
* Update?
* "ipex" -> "cpu"
* These do not belong.
* Factoring cu_seqlen_qk for better abstracting over every model.
* Fixing non flash tests/imports.
* Changing return everywhere.
* Update mistral past.
* Fixing Mi{s,x}tral (non functional in Flash Decoding mode though).
* Fixup mistral clamping (had issues with cuda graphs).
* No need to recreate anything actually.
Showing 24 changed files with 222 additions and 74 deletions.
@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+from text_generation_server.models.globals import FLASH_DECODING
+import torch
+from typing import Optional
+
+
+if FLASH_DECODING:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+        cu_seqlen_q: Optional[torch.Tensor]
+        cu_seqlen_k: Optional[torch.Tensor]
+
+        def __init__(self, input_lengths):
+            self.input_lengths = input_lengths
+            device = self.input_lengths.device
+            shape = self.input_lengths.shape
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+            # cuda graphs don't like this and this is necessary to clamp within mistral
+            # Although FA2 might not want the clamping
+            # cu_seqlen_k[0] = 0
+            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
+
+            self.cu_seqlen_q = cu_seqlen_q
+            self.cu_seqlen_k = cu_seqlen_k
+
+        def clamp(self, max):
+            # Flash decoding doesn't need to clamp
+            return self
+
+else:
+
+    @dataclass
+    class Seqlen:
+        input_lengths: torch.Tensor
+
+        def clamp(self, max):
+            return Seqlen(torch.clamp(self.input_lengths, max=max))
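To make the two Seqlen variants above easier to follow, here is a minimal sketch of the values they compute for a small batch. It is not the code from this commit: it swaps torch tensors for plain Python lists so it runs without a GPU or torch, and the names SeqlenSketch, build_flash_decoding_seqlen and clamp_lengths are hypothetical, introduced only for illustration.

```python
from dataclasses import dataclass
from itertools import accumulate
from typing import List


@dataclass
class SeqlenSketch:
    """Hypothetical list-based stand-in for the Seqlen class in the diff."""

    input_lengths: List[int]
    cu_seqlen_q: List[int]
    cu_seqlen_k: List[int]


def build_flash_decoding_seqlen(input_lengths: List[int]) -> SeqlenSketch:
    """Mirror the FLASH_DECODING branch using lists instead of tensors."""
    batch_size = len(input_lengths)
    # Counterpart of torch.arange(batch_size + 1): offsets 0, 1, ..., batch_size,
    # which corresponds to one query position per sequence.
    cu_seqlen_q = list(range(batch_size + 1))
    # Counterpart of torch.zeros(batch_size + 1) followed by a cumsum written
    # into [1:]: a leading 0, then the running total of the key lengths.
    cu_seqlen_k = [0] + list(accumulate(input_lengths))
    return SeqlenSketch(list(input_lengths), cu_seqlen_q, cu_seqlen_k)


def clamp_lengths(input_lengths: List[int], max_length: int) -> List[int]:
    """Mirror torch.clamp(input_lengths, max=max_length) from the else branch."""
    return [min(length, max_length) for length in input_lengths]


seqlen = build_flash_decoding_seqlen([3, 5, 2])
print(seqlen.cu_seqlen_q)  # [0, 1, 2, 3]
print(seqlen.cu_seqlen_k)  # [0, 3, 8, 10]
print(clamp_lengths([3, 5, 2], max_length=4))  # [3, 4, 2]
```

In this sketch, sequence i owns the key range cu_seqlen_k[i] to cu_seqlen_k[i + 1], so the cumulative form lets a kernel locate each sequence inside a flattened batch. The non-flash branch keeps only the per-sequence lengths and caps them in clamp, while the flash decoding branch returns itself unchanged.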