From e415b690a68d7a0e149c996e46def41c867ff421 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 29 Aug 2024 16:29:01 +0200 Subject: [PATCH] Lots of improvements (Still 2 allocators) (#2449) * Making prefix/flashinfer the default and testing the full release tests. * Include flashinfer in the docker. * Using prebuilt. * Allowing window_left_size (dummy version). * Disabling flashinfer/prefix caching on odd head_dim * Disable prefix caching for lora. * More specific codes. * Update lock * Updating integration tests with new values with FI/FD. Remove paged as a default too, and using FD everywhere. * Update cargo lock ? * Upgrade to 1.80 because of bitstream... * Everywhere 1.80 * Forgot last default place. * Apply suggestions from code review Co-authored-by: drbh * Updated flake lock * Tmp * Upgrade resolution system for less errors in resolution. * Remove lambda for cleaner function. * Handling debugger. * OVerride the env in server tests. * Is this enough to make it work ? * This seems to be working. * Downgrade some logs. * Fixing the default for vlm. * Don't enable prefix caching on VLM just yet. * Change `add_special_tokens` in order to have the correct tokens for chat input and not (since it's super important with the prefixing now) * Fixing prefix caching for flashdecoding. * Update all models. * Fixed flashinfer version. * add_special_tokens is internal only * Fixing seqlen with the new vlms. * Fixing the issue with `add_special_tokens` not being passed around. * Fixing the test. * Removing encoder_decoder (seq2seq). * Update the chat test. * Fixing the batching tokenization in flash causal lm. * Truncating left for radix purposes. * Oops this doesn't belong here. * Put back default pure shell. * Update server tests - Default to throughput test in k6 - Use TGI_WIGGLE_ROOM to adjust wiggle room * Only n_heads / process_group.size() are necessary. * Revert the integrationt tests change (seem linked to head_size modification). * Adding error message when assert is violated. * Fixing the free algorithm to handle times where the common prefix is smaller. * Apply suggestions from code review Co-authored-by: OlivierDehaene * Update server/text_generation_server/layers/attention/common.py Co-authored-by: OlivierDehaene * Fix disabling prefix caching - Fix windowing checks. * Revert the Cohere tokenizer change (for now using a revision instead). * Fmt. 
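For context on the startup contract this patch introduces in backends/v3/src/backend.rs (see the hunk below): both USE_PREFIX_CACHING and ATTENTION must now be set (presumably exported by the launcher), and the block size is derived from the attention kind via Attention::block_size() instead of being matched inline. The sketch below is not code from the tree: the enum variants and FromStr spellings are assumptions, and the block_size() body is reconstructed from the branch this patch removes (FlashDecoding -> 256, FlashInfer -> 1, Paged -> 16).

use std::str::FromStr;

#[derive(Debug, Clone, Copy, PartialEq)]
enum Attention {
    Paged,
    FlashDecoding,
    FlashInfer,
}

impl Attention {
    // Mirrors the block sizes hard-coded in the branch removed by this patch.
    fn block_size(&self) -> u32 {
        match self {
            Attention::FlashDecoding => 256,
            Attention::FlashInfer => 1,
            Attention::Paged => 16,
        }
    }
}

impl FromStr for Attention {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Hypothetical spellings; the real parser lives outside this diff.
        match s {
            "paged" => Ok(Attention::Paged),
            "flashdecoding" => Ok(Attention::FlashDecoding),
            "flashinfer" => Ok(Attention::FlashInfer),
            other => Err(format!("Invalid attention was specified :`{other}`")),
        }
    }
}

fn main() {
    // Both env vars are now mandatory; the backend panics if either is missing.
    let prefix_caching = std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var");
    let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1");
    let attention: Attention = std::env::var("ATTENTION")
        .expect("attention env var")
        .parse()
        .unwrap_or_else(|e| panic!("{e}"));
    println!(
        "prefix_caching={prefix_caching} attention={attention:?} block_size={}",
        attention.block_size()
    );
}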
--------- Co-authored-by: drbh Co-authored-by: OlivierDehaene --- .github/workflows/tests.yaml | 2 +- Cargo.lock | 437 +++++++++--------- Dockerfile | 9 +- Dockerfile_amd | 2 +- Dockerfile_intel | 2 +- backends/client/src/v3/client.rs | 2 + backends/client/src/v3/sharded_client.rs | 1 + backends/v3/src/backend.rs | 30 +- backends/v3/src/block_allocator.rs | 5 +- backends/v3/src/client/grpc_client.rs | 1 + backends/v3/src/client/sharded_client.rs | 1 + backends/v3/src/queue.rs | 2 + backends/v3/src/radix.rs | 206 ++++++--- benchmark/src/generation.rs | 1 + flake.lock | 6 +- .../test_flash_llama_simple.json | 12 +- ..._llama_completion_many_prompts_stream.json | 172 +++---- .../test_flash_deepseek_v2.json | 44 +- .../test_flash_deepseek_v2_load.json | 152 +++--- .../test_flash_llama_fp8_all_params.json | 58 +-- .../test_flash_starcoder2_default_params.json | 16 +- .../test_flash_idefics2_next_all_params.json | 8 +- integration-tests/models/test_chat_llama.py | 2 +- launcher/src/main.rs | 267 ++++++++--- load_tests/common.js | 26 +- proto/v3/generate.proto | 2 + router/src/infer/mod.rs | 3 +- router/src/lib.rs | 21 + router/src/server.rs | 4 + router/src/validation.rs | 49 +- rust-toolchain.toml | 2 +- server/Makefile | 1 + server/Makefile-flashinfer | 2 + server/tests/conftest.py | 5 +- .../layers/attention/common.py | 39 +- .../layers/attention/cuda.py | 28 +- .../text_generation_server/models/__init__.py | 15 +- .../custom_modeling/flash_cohere_modeling.py | 23 +- .../custom_modeling/flash_dbrx_modeling.py | 27 +- .../flash_deepseek_v2_modeling.py | 24 +- .../custom_modeling/flash_gemma2_modeling.py | 23 +- .../custom_modeling/flash_gemma_modeling.py | 23 +- .../custom_modeling/flash_gpt2_modeling.py | 23 +- .../custom_modeling/flash_gptj_modeling.py | 25 +- .../custom_modeling/flash_llama_modeling.py | 23 +- .../custom_modeling/flash_mistral_modeling.py | 25 +- .../custom_modeling/flash_mixtral_modeling.py | 25 +- .../custom_modeling/flash_neox_modeling.py | 25 +- .../flash_pali_gemma_modeling.py | 5 +- .../custom_modeling/flash_phi_modeling.py | 23 +- .../custom_modeling/flash_qwen2_modeling.py | 25 +- .../custom_modeling/flash_rw_modeling.py | 33 +- .../flash_santacoder_modeling.py | 23 +- .../flash_starcoder2_modeling.py | 25 +- .../models/custom_modeling/idefics2.py | 5 +- .../models/custom_modeling/llava_next.py | 5 +- .../models/flash_causal_lm.py | 105 +++-- .../text_generation_server/models/globals.py | 9 +- .../models/vlm_causal_lm.py | 11 +- 59 files changed, 1235 insertions(+), 935 deletions(-) create mode 100644 server/Makefile-flashinfer diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index f983b6ed85a..6faabe3b030 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -35,7 +35,7 @@ jobs: with: # Released on: 02 May, 2024 # https://releases.rs/docs/1.78.0/ - toolchain: 1.79.0 + toolchain: 1.80.0 override: true components: rustfmt, clippy - name: Install Protoc diff --git a/Cargo.lock b/Cargo.lock index 02e91bc1833..00c7f005923 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "ahash" version = "0.8.11" @@ -28,7 +34,7 @@ dependencies = [ 
"once_cell", "serde", "version_check", - "zerocopy 0.7.35", + "zerocopy", ] [[package]] @@ -121,14 +127,14 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "arrayvec" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "async-rustls" @@ -160,7 +166,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -171,7 +177,7 @@ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -257,9 +263,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e89b6941c2d1a7045538884d6e760ccfffdf8e1ffc2613d8efa74305e1f3752" +checksum = "0f0e249228c6ad2d240c2dc94b714d711629d52bad946075d8e9b2f5391f0703" dependencies = [ "bindgen", "cc", @@ -402,7 +408,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.4", "object", "rustc-demangle", ] @@ -444,7 +450,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.72", + "syn 2.0.76", "which", ] @@ -483,9 +489,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bitstream-io" -version = "2.5.0" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dcde5f311c85b8ca30c2e4198d4326bc342c76541590106f5fa4a50946ea499" +checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452" [[package]] name = "block-buffer" @@ -516,9 +522,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" [[package]] name = "byteorder" @@ -534,15 +540,15 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.1" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" +checksum = "8318a53db07bb3f8dca91a600466bdb3f2eaadeedfdbcf02e1accbad9271ba50" [[package]] name = "camino" -version = "1.1.7" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0ec6b951b160caa93cc0c7b209e5a3bff7aae9062213451ac99493cd844c239" +checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" dependencies = [ "serde", ] @@ -584,12 +590,13 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.7" +version = "1.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a5c3fd7bfa1ce3897a3a3501d362b2d87b7f2583ebcb4a949ec25911025cbc" +checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6" dependencies = [ "jobserver", "libc", + 
"shlex", ] [[package]] @@ -623,6 +630,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "clang-sys" version = "1.8.1" @@ -647,9 +660,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.11" +version = "4.5.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35723e6a11662c2afb578bcf0b88bf6ea8e21282a953428f240574fcc3a2b5b3" +checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019" dependencies = [ "clap_builder", "clap_derive", @@ -657,9 +670,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.11" +version = "4.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49eb96cbfa7cfa35017b7cd548c75b14c3118c98b423041d70562665e07fb0fa" +checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6" dependencies = [ "anstream", "anstyle", @@ -669,14 +682,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.11" +version = "4.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d029b67f89d30bbb547c89fd5161293c0aec155fc691d7924b64550662db93e" +checksum = "501d359d5f3dcaf6ecdeee48833ae73ec6e42723a1e52419c79abf9507eec0a0" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -687,9 +700,9 @@ checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "cmake" -version = "0.1.50" +version = "0.1.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" +checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" dependencies = [ "cc", ] @@ -741,15 +754,15 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "51e852e6dc9a5bed1fae92dd2375037bf2b768725bf3be87811edee3249d09ad" dependencies = [ "libc", ] @@ -897,19 +910,19 @@ dependencies = [ [[package]] name = "ctrlc" -version = "3.4.4" +version = "3.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345" +checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3" dependencies = [ - "nix", - "windows-sys 0.52.0", + "nix 0.29.0", + "windows-sys 0.59.0", ] [[package]] name = "cxx" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "273dcfd3acd4e1e276af13ed2a43eea7001318823e7a726a6b3ed39b4acc0b82" +checksum = "3c4eae4b7fc8dcb0032eb3b1beee46b38d371cdeaf2d0c64b9944f6f69ad7755" dependencies = [ "cc", "cxxbridge-flags", @@ -919,9 +932,9 @@ dependencies = [ [[package]] name = "cxx-build" -version = 
"1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8b2766fbd92be34e9ed143898fce6c572dc009de39506ed6903e5a05b68914e" +checksum = "6c822bf7fb755d97328d6c337120b6f843678178751cba33c9da25cf522272e0" dependencies = [ "cc", "codespan-reporting", @@ -929,24 +942,24 @@ dependencies = [ "proc-macro2", "quote", "scratch", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "cxxbridge-flags" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "839fcd5e43464614ffaa989eaf1c139ef1f0c51672a1ed08023307fa1b909ccd" +checksum = "719d6197dc016c88744aff3c0d0340a01ecce12e8939fc282e7c8f583ee64bc6" [[package]] name = "cxxbridge-macro" -version = "1.0.124" +version = "1.0.126" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b2c1c1776b986979be68bb2285da855f8d8a35851a769fca8740df7c3d07877" +checksum = "35de3b547387863c8f82013c4f79f1c2162edee956383e4089e1d04c18c4f16c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -970,7 +983,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -981,7 +994,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1011,7 +1024,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1021,7 +1034,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" dependencies = [ "derive_builder_core", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1057,9 +1070,9 @@ dependencies = [ [[package]] name = "dunce" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56ce8c6da7551ec6c462cbaf3bfbc75131ebbfa1c944aeaa9dab51ca1c5f0c3b" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "easy-cast" @@ -1126,7 +1139,7 @@ dependencies = [ "flume", "half 2.4.1", "lebe", - "miniz_oxide", + "miniz_oxide 0.7.4", "rayon-core", "smallvec", "zune-inflate", @@ -1144,9 +1157,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fdeflate" @@ -1165,12 +1178,12 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.0.30" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" dependencies = [ "crc32fast", - "miniz_oxide", + "miniz_oxide 0.8.0", ] [[package]] @@ -1296,7 +1309,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1405,7 +1418,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.2.6", + "indexmap 2.4.0", "slab", "tokio", "tokio-util", @@ -1414,9 +1427,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.5" +version 
= "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" dependencies = [ "atomic-waker", "bytes", @@ -1424,7 +1437,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.2.6", + "indexmap 2.4.0", "slab", "tokio", "tokio-util", @@ -1631,7 +1644,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.5", + "h2 0.4.6", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -1689,9 +1702,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-channel", @@ -1774,9 +1787,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.6" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c" dependencies = [ "equivalent", "hashbrown 0.14.5", @@ -1832,7 +1845,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -1915,9 +1928,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" dependencies = [ "wasm-bindgen", ] @@ -1932,7 +1945,7 @@ dependencies = [ "anyhow", "base64 0.21.7", "bytecount", - "clap 4.5.11", + "clap 4.5.16", "fancy-regex", "fraction", "getrandom", @@ -1972,9 +1985,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" [[package]] name = "libfuzzer-sys" @@ -2126,7 +2139,7 @@ dependencies = [ "hyper 1.4.1", "hyper-rustls", "hyper-util", - "indexmap 2.2.6", + "indexmap 2.4.0", "ipnet", "metrics", "metrics-util", @@ -2179,9 +2192,9 @@ dependencies = [ [[package]] name = "minijinja-contrib" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6853ef2340c668281c5ea86b04da2ebb2fc9e98a7185a887591de4cac945d5b5" +checksum = "744a2b84dbd22398e347594ed2aef9d3f1b948934e3e6e94ef69ecd39d597f4b" dependencies = [ "minijinja", "serde", @@ -2203,6 +2216,15 @@ dependencies = [ "simd-adler32", ] +[[package]] +name = "miniz_oxide" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +dependencies = [ + "adler2", +] + [[package]] name = "mio" version = "0.8.11" @@ -2217,9 +2239,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.1" +version = "1.0.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ "hermit-abi 0.3.9", "libc", @@ -2251,7 +2273,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2341,7 +2363,19 @@ checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ "bitflags 2.6.0", "cfg-if", - "cfg_aliases", + "cfg_aliases 0.1.1", + "libc", +] + +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "cfg_aliases 0.2.1", "libc", ] @@ -2439,7 +2473,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2510,9 +2544,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "object" -version = "0.36.2" +version = "0.36.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f203fa8daa7bb185f760ae12bd8e097f63d17041dcdcaf675ac54cdf863170e" +checksum = "27b64972346851a39438c60b341ebc01bba47464ae329e55cf343eb93964efd9" dependencies = [ "memchr", ] @@ -2574,7 +2608,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2613,7 +2647,7 @@ checksum = "1e32339a5dc40459130b3bd269e9892439f55b33e772d2a9d402a789baaf4e8a" dependencies = [ "futures-core", "futures-sink", - "indexmap 2.2.6", + "indexmap 2.4.0", "js-sys", "once_cell", "pin-project-lite", @@ -2837,7 +2871,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.2.6", + "indexmap 2.4.0", ] [[package]] @@ -2857,7 +2891,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2916,7 +2950,7 @@ dependencies = [ "crc32fast", "fdeflate", "flate2", - "miniz_oxide", + "miniz_oxide 0.7.4", ] [[package]] @@ -2933,21 +2967,21 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.18" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dee4364d9f3b902ef14fab8a1ddffb783a1cb6b4bba3bfc1fa3922732c7de97f" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy 0.6.6", + "zerocopy", ] [[package]] name = "prettyplease" -version = "0.2.20" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" dependencies = [ "proc-macro2", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -2999,7 +3033,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" dependencies = [ "quote", - "syn 2.0.72", + "syn 
2.0.76", ] [[package]] @@ -3039,7 +3073,7 @@ dependencies = [ "prost 0.12.6", "prost-types", "regex", - "syn 2.0.72", + "syn 2.0.76", "tempfile", ] @@ -3066,7 +3100,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -3110,9 +3144,9 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -3201,9 +3235,9 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.9" +version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5797d09f9bd33604689e87e8380df4951d4912f01b63f71205e2abd4ae25e6b6" +checksum = "a8f0bfd976333248de2078d350bfdf182ff96e168a24d23d2436cef320dd4bdd" dependencies = [ "avif-serialize", "imgref", @@ -3264,9 +3298,9 @@ dependencies = [ [[package]] name = "redox_users" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", @@ -3275,9 +3309,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.5" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", @@ -3359,9 +3393,9 @@ dependencies = [ [[package]] name = "rgb" -version = "0.8.45" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade4539f42266ded9e755c605bdddf546242b2c961b03b06a7375260788a0523" +checksum = "0f86ae463694029097b846d8f99fd5536740602ae00022c0c50c5600720b2f71" dependencies = [ "bytemuck", ] @@ -3416,7 +3450,7 @@ dependencies = [ "proc-macro2", "quote", "rust-embed-utils", - "syn 2.0.72", + "syn 2.0.76", "walkdir", ] @@ -3507,12 +3541,12 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" +checksum = "04182dffc9091a404e0fc069ea5cd60e5b866c3adf881eff99a32d048242dffa" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.2", + "rustls-pemfile 2.1.3", "rustls-pki-types", "schannel", "security-framework", @@ -3529,9 +3563,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.2" +version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" dependencies = [ "base64 0.22.1", "rustls-pki-types", @@ -3539,15 +3573,15 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = 
"rustls-webpki" -version = "0.102.6" +version = "0.102.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" +checksum = "84678086bd54edf2b415183ed7a94d0efb049f1b646a33e22a36f3794be6ae56" dependencies = [ "aws-lc-rs", "ring 0.17.8", @@ -3641,9 +3675,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" dependencies = [ "serde_derive", ] @@ -3660,20 +3694,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] name = "serde_json" -version = "1.0.121" +version = "1.0.127" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab380d7d9f22ef3f21ad3e6c1ebe8e4fc7a2000ccba2e4d71fc96f15b2cb609" +checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" dependencies = [ "itoa", "memchr", @@ -3875,7 +3909,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -3897,9 +3931,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.72" +version = "2.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" +checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525" dependencies = [ "proc-macro2", "quote", @@ -3993,20 +4027,21 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.15" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4024,7 +4059,7 @@ version = "2.2.1-dev0" dependencies = [ "async-stream", "async-trait", - "clap 4.5.11", + "clap 4.5.16", "cmake", "cxx", "cxx-build", @@ -4046,7 +4081,7 @@ name = "text-generation-benchmark" version = "2.2.1-dev0" dependencies = [ "average", - "clap 4.5.11", + "clap 4.5.16", "crossterm", "float-ord", "hf-hub", @@ -4084,11 +4119,11 @@ dependencies = [ name = "text-generation-launcher" version = "2.2.1-dev0" dependencies = [ - "clap 4.5.11", + "clap 4.5.16", "ctrlc", "float_eq", "hf-hub", - "nix", + "nix 0.28.0", "once_cell", "reqwest", "serde", @@ -4108,7 +4143,7 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap 4.5.11", + "clap 4.5.16", "csv", "futures", "futures-util", @@ -4156,7 +4191,7 @@ dependencies = [ "axum 0.7.5", "axum-tracing-opentelemetry", "base64 0.22.1", - "clap 4.5.11", + 
"clap 4.5.16", "criterion", "futures", "futures-util", @@ -4224,7 +4259,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4341,14 +4376,14 @@ dependencies = [ [[package]] name = "tokio" -version = "1.39.2" +version = "1.39.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5" dependencies = [ "backtrace", "bytes", "libc", - "mio 1.0.1", + "mio 1.0.2", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -4375,7 +4410,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4437,9 +4472,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.16" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81967dd0dd2c1ab0bc3468bd7caecc32b8a4aa47d0c8c695d8c2b2108168d62c" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" dependencies = [ "serde", "serde_spanned", @@ -4449,20 +4484,20 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8fb9f64314842840f1d940ac544da178732128f1c78c21772e876579e0da1db" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.17" +version = "0.22.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d9f8729f5aea9562aac1cc0441f5d6de3cff1ee0c5d67293eeca5eb36ee7c16" +checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.4.0", "serde", "serde_spanned", "toml_datetime", @@ -4534,7 +4569,7 @@ dependencies = [ "proc-macro2", "prost-build", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4575,15 +4610,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -4605,7 +4640,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4865,7 +4900,7 @@ version = "4.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23" dependencies = [ - "indexmap 2.2.6", + "indexmap 2.4.0", "serde", "serde_json", "utoipa-gen", @@ -4881,7 +4916,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -4919,7 +4954,7 @@ checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + 
"syn 2.0.76", ] [[package]] @@ -5000,34 +5035,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" dependencies = [ "cfg-if", "js-sys", @@ -5037,9 +5073,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5047,28 +5083,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" dependencies = [ "js-sys", "wasm-bindgen", @@ -5149,11 +5185,11 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d4cc384e1e73b93bafa6fb4f1df8c41695c8a91cf9c4c64358067d15a7b6c6b" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -5208,6 +5244,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + 
"windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.42.2" @@ -5388,9 +5433,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.16" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b480ae9340fc261e6be3e95a1ba86d54ae3f9171132a73ce8d4bbaf68339507c" +checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" dependencies = [ "memchr", ] @@ -5405,34 +5450,14 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "zerocopy" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854e949ac82d619ee9a14c66a1b674ac730422372ccb759ce0c39cabcf2bf8e6" -dependencies = [ - "byteorder", - "zerocopy-derive 0.6.6", -] - [[package]] name = "zerocopy" version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ - "zerocopy-derive 0.7.35", -] - -[[package]] -name = "zerocopy-derive" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "125139de3f6b9d625c39e2efdd73d41bdac468ccd556556440e322be0e1bbd91" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.72", + "byteorder", + "zerocopy-derive", ] [[package]] @@ -5443,7 +5468,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] @@ -5463,7 +5488,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.76", ] [[package]] diff --git a/Dockerfile b/Dockerfile index 4c64a643526..0d0e89b1b85 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse @@ -184,6 +184,12 @@ WORKDIR /usr/src COPY server/Makefile-selective-scan Makefile RUN make build-all +# Build flashinfer +FROM kernel-builder AS flashinfer-builder +WORKDIR /usr/src +COPY server/Makefile-flashinfer Makefile +RUN make install-flashinfer + # Text Generation Inference base image FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base @@ -236,6 +242,7 @@ COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/c # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages +COPY --from=flashinfer-builder /opt/conda/lib/python3.10/site-packages/flashinfer/ /opt/conda/lib/python3.10/site-packages/flashinfer/ # Install flash-attention dependencies RUN pip install einops --no-cache-dir diff --git a/Dockerfile_amd b/Dockerfile_amd index cdad0d28b52..8cb699ddbb3 100644 --- a/Dockerfile_amd +++ b/Dockerfile_amd @@ -1,5 +1,5 @@ # Rust builder -FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/Dockerfile_intel b/Dockerfile_intel index 12480c70ff8..9af6422c84c 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,6 +1,6 @@ ARG PLATFORM=xpu -FROM 
lukemathwalker/cargo-chef:latest-rust-1.79 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.80 AS chef WORKDIR /usr/src ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse diff --git a/backends/client/src/v3/client.rs b/backends/client/src/v3/client.rs index b321278c10f..479d31bf290 100644 --- a/backends/client/src/v3/client.rs +++ b/backends/client/src/v3/client.rs @@ -153,6 +153,8 @@ impl Client { }), // We truncate the input on the server side to be sure that it has the correct size truncate, + // Most request will have that + add_special_tokens: true, // Blocks and slots will be set on the server side if we use paged attention blocks: vec![], slots: vec![], diff --git a/backends/client/src/v3/sharded_client.rs b/backends/client/src/v3/sharded_client.rs index 1cc173e3301..645c076a26b 100644 --- a/backends/client/src/v3/sharded_client.rs +++ b/backends/client/src/v3/sharded_client.rs @@ -221,6 +221,7 @@ impl Health for ShardedClient { chunks: vec![Chunk::Text("liveness".into()).into()], }), truncate: 10, + add_special_tokens: true, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, diff --git a/backends/v3/src/backend.rs b/backends/v3/src/backend.rs index cbcbff72aca..05a26370579 100644 --- a/backends/v3/src/backend.rs +++ b/backends/v3/src/backend.rs @@ -35,27 +35,15 @@ impl BackendV3 { window_size: Option, speculate: u32, ) -> Self { - let prefix_caching = if let Ok(prefix_caching) = std::env::var("USE_PREFIX_CACHING") { - matches!(prefix_caching.as_str(), "true" | "1") - } else { - false - }; - let attention = if let Ok(attention) = std::env::var("ATTENTION") { - attention - .parse() - .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")) - } else if prefix_caching { - Attention::FlashInfer - } else { - Attention::Paged - }; - let block_size = if attention == Attention::FlashDecoding { - 256 - } else if attention == Attention::FlashInfer { - 1 - } else { - 16 - }; + let prefix_caching = + std::env::var("USE_PREFIX_CACHING").expect("Expect prefix caching env var"); + let prefix_caching = matches!(prefix_caching.as_str(), "true" | "1"); + let attention: String = std::env::var("ATTENTION").expect("attention env var"); + + let attention: Attention = attention + .parse() + .unwrap_or_else(|_| panic!("Invalid attention was specified :`{attention}`")); + let block_size = attention.block_size(); let queue = Queue::new( requires_padding, diff --git a/backends/v3/src/block_allocator.rs b/backends/v3/src/block_allocator.rs index c5503b8c6ce..4fea172b65a 100644 --- a/backends/v3/src/block_allocator.rs +++ b/backends/v3/src/block_allocator.rs @@ -1,4 +1,4 @@ -use std::{cmp::min, sync::Arc}; +use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; use crate::radix::RadixAllocator; @@ -137,7 +137,6 @@ pub trait Allocator { fn free(&mut self, blocks: Vec, allocation_id: u64); } - pub struct SimpleAllocator { free_blocks: Vec, block_size: u32, @@ -167,7 +166,7 @@ impl Allocator for SimpleAllocator { None => (tokens, 1), Some(window_size) => { let repeats = (tokens + window_size - 1) / window_size; - let tokens = min(tokens, window_size); + let tokens = core::cmp::min(tokens, window_size); (tokens, repeats as usize) } }; diff --git a/backends/v3/src/client/grpc_client.rs b/backends/v3/src/client/grpc_client.rs index 6282759e87c..648662db39b 100644 --- a/backends/v3/src/client/grpc_client.rs +++ b/backends/v3/src/client/grpc_client.rs @@ -149,6 +149,7 @@ impl Client { requests.push(Request { id: 0, inputs, + add_special_tokens: true, 
input_chunks: Some(Input { chunks: input_chunks, }), diff --git a/backends/v3/src/client/sharded_client.rs b/backends/v3/src/client/sharded_client.rs index 2f78da0349e..ea77a696648 100644 --- a/backends/v3/src/client/sharded_client.rs +++ b/backends/v3/src/client/sharded_client.rs @@ -222,6 +222,7 @@ impl Health for ShardedClient { chunks: vec![Chunk::Text("liveness".into()).into()], }), truncate: 10, + add_special_tokens: true, prefill_logprobs: false, parameters: Some(NextTokenChooserParameters { temperature: 1.0, diff --git a/backends/v3/src/queue.rs b/backends/v3/src/queue.rs index faa57c11362..2a8c4c53660 100644 --- a/backends/v3/src/queue.rs +++ b/backends/v3/src/queue.rs @@ -383,6 +383,7 @@ impl State { }), inputs: entry.request.inputs.chunks_to_string(), truncate: entry.request.truncate, + add_special_tokens: entry.request.add_special_tokens, parameters: Some(NextTokenChooserParameters::from( entry.request.parameters.clone(), )), @@ -517,6 +518,7 @@ mod tests { inputs: vec![], input_ids: Some(Arc::new(vec![])), input_length: 0, + add_special_tokens: true, truncate: 0, decoder_input_details: false, parameters: ValidParameters { diff --git a/backends/v3/src/radix.rs b/backends/v3/src/radix.rs index 5bac1a31f95..b85be00bb16 100644 --- a/backends/v3/src/radix.rs +++ b/backends/v3/src/radix.rs @@ -1,12 +1,10 @@ +use crate::block_allocator::{Allocator, BlockAllocation}; +use slotmap::{DefaultKey, SlotMap}; use std::{ collections::{BTreeSet, HashMap}, sync::Arc, }; -use slotmap::{DefaultKey, SlotMap}; - -use crate::block_allocator::{Allocator, BlockAllocation}; - pub struct RadixAllocator { allocation_id: u64, @@ -16,26 +14,26 @@ pub struct RadixAllocator { /// Blocks that are immediately available for allocation. free_blocks: Vec, + + #[allow(dead_code)] + // This isn't used because the prefix need to match without the windowing + // mecanism. This at worst is overallocating, not necessarily being wrong. + window_size: Option, + + block_size: u32, } impl RadixAllocator { pub fn new(block_size: u32, n_blocks: u32, window_size: Option) -> Self { - assert_eq!( - block_size, 1, - "Radix tree allocator only works with block_size=1, was: {}", - block_size - ); - if window_size.is_some() { - unimplemented!("Window size not supported in the prefix-caching block allocator yet"); - } - RadixAllocator { allocation_id: 0, allocations: HashMap::new(), - cache_blocks: RadixTrie::new(), + cache_blocks: RadixTrie::new(block_size as usize), // Block 0 is reserved for health checks. free_blocks: (1..n_blocks).collect(), + window_size, + block_size, } } @@ -63,6 +61,7 @@ impl RadixAllocator { } } +// Allocator trait impl Allocator for RadixAllocator { fn allocate( &mut self, @@ -86,10 +85,12 @@ impl Allocator for RadixAllocator { .incref(prefix_node) .expect("Failed to increment refcount"); - let prefix_len = blocks.len(); + let prefix_len = blocks.len() * self.block_size as usize; let suffix_len = tokens - prefix_len as u32; - match self.alloc_or_reclaim(suffix_len as usize) { + let suffix_blocks = (suffix_len + self.block_size - 1) / self.block_size; + + match self.alloc_or_reclaim(suffix_blocks as usize) { Some(suffix_blocks) => blocks.extend(suffix_blocks), None => { self.cache_blocks @@ -100,7 +101,20 @@ impl Allocator for RadixAllocator { } // 1:1 mapping of blocks and slots. 
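    // With block_size == 1 the mapping stays 1:1; for larger blocks every block id
    // expands to block_size consecutive slots, truncated at `tokens`. For example,
    // with block_size = 2 and tokens = 8, blocks [8, 9, 10, 11] map to slots
    // [16, 17, 18, 19, 20, 21, 22, 23] (cf. the allocator_block_size test below).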
- let slots = blocks.clone(); + let slots = if self.block_size == 1 { + blocks.clone() + } else { + let mut slots = Vec::with_capacity(blocks.len() * self.block_size as usize); + 'slots: for block_id in &blocks { + for s in (block_id * self.block_size)..((block_id + 1) * self.block_size) { + slots.push(s); + if slots.len() as u32 == tokens { + break 'slots; + } + } + } + slots + }; let allocation = RadixAllocation { prefix_node, @@ -108,6 +122,8 @@ impl Allocator for RadixAllocator { prefill_tokens: prefill_tokens.clone(), }; + tracing::debug!("Blocks {blocks:?}"); + self.allocation_id += 1; self.allocations.insert(self.allocation_id, allocation); @@ -136,27 +152,38 @@ impl Allocator for RadixAllocator { // If there are prefill tokens that did not come from the cache, // add them to the cache. if prefill_tokens.len() > allocation.cached_prefix_len { - let prefix_len = self - .cache_blocks - .insert(prefill_tokens, &blocks[..prefill_tokens.len()]) - // Unwrap, failing is a programming error. - .expect("Failed to store prefill tokens"); - - // We can have a prefill with the following structure: - // - // |---| From the prefix cache. - // A B C D E F G - //|--------| Found in the trie during insertion. - // - // This means that while processing this request there was a - // partially overlapping request that had A..=E in its - // prefill. In this case we need to free the blocks D E. - self.free_blocks - .extend(&blocks[allocation.cached_prefix_len..prefix_len]); + let aligned = + (prefill_tokens.len() / self.block_size as usize) * self.block_size as usize; + if aligned > 0 { + let prefix_len = self + .cache_blocks + .insert( + &prefill_tokens[..aligned], + &blocks[..aligned / self.block_size as usize], + ) + // Unwrap, failing is a programming error. + .expect("Failed to store prefill tokens"); + // We can have a prefill with the following structure: + // + // |---| From the prefix cache. + // A B C D E F G + //|--------| Found in the trie during insertion. + // + // This means that while processing this request there was a + // partially overlapping request that had A..=E in its + // prefill. In this case we need to free the blocks D E. + if prefix_len > allocation.cached_prefix_len { + self.free_blocks.extend( + &blocks[allocation.cached_prefix_len / self.block_size as usize + ..prefix_len / self.block_size as usize], + ); + } + } } // Free non-prefill blocks. - self.free_blocks.extend(&blocks[prefill_tokens.len()..]); + self.free_blocks + .extend(&blocks[prefill_tokens.len() / self.block_size as usize..]); } else { self.free_blocks.extend(blocks); } @@ -204,17 +231,14 @@ pub struct RadixTrie { /// Time as a monotonically increating counter to avoid the system /// call that a real time lookup would require. time: u64, -} -impl Default for RadixTrie { - fn default() -> Self { - Self::new() - } + /// All blocks need to be aligned with this + block_size: usize, } impl RadixTrie { /// Construct a new radix trie. - pub fn new() -> Self { + pub fn new(block_size: usize) -> Self { let root = TrieNode::new(vec![], vec![], 0, None); let mut nodes = SlotMap::new(); let root = nodes.insert(root); @@ -223,13 +247,14 @@ impl RadixTrie { nodes, root, time: 0, + block_size, } } /// Find the prefix of the given tokens. /// /// The blocks corresponding to the part of the prefix that could be found - /// are writteng to `blocks`. The number of blocks is in `0..=tokens.len()`. + /// are written to `blocks`. The number of blocks is in `0..=tokens.len()`. 
/// Returns the identifier of the trie node that contains the longest /// prefix. The node identifier can be used by callers to e.g. increase its /// reference count. @@ -247,8 +272,9 @@ impl RadixTrie { if let Some(&child_id) = node.children.get(&key[0]) { self.update_access_time(child_id); let child = self.nodes.get(child_id).expect("Invalid child identifier"); - let shared_prefix_len = child.key.shared_prefix_len(key); - blocks.extend(&child.blocks[..shared_prefix_len]); + let shared_prefix_len = shared_prefix(&child.key, key, self.block_size); + assert_eq!(shared_prefix_len % self.block_size, 0); + blocks.extend(&child.blocks[..shared_prefix_len / self.block_size]); let key = &key[shared_prefix_len..]; if !key.is_empty() { @@ -349,7 +375,8 @@ impl RadixTrie { /// the first 10 elements of the tree **the blocks are not updated**. pub fn insert(&mut self, tokens: &[u32], blocks: &[u32]) -> Result { self.time += 1; - self.insert_(self.root, tokens, blocks) + let common = self.insert_(self.root, tokens, blocks)?; + Ok(common) } /// Insertion worker. @@ -363,7 +390,7 @@ impl RadixTrie { // the part of the prefix that is already in the trie to detect // mismatches. - if tokens.len() != blocks.len() { + if tokens.len() != blocks.len() * self.block_size { return Err(TrieError::BlockTokenCountMismatch); } @@ -374,10 +401,10 @@ impl RadixTrie { .get_mut(child_id) // Unwrap here, since failure is a bug. .expect("Child node does not exist"); - let shared_prefix_len = child.key.shared_prefix_len(tokens); + let shared_prefix_len = shared_prefix(&child.key, tokens, self.block_size); // We are done, the prefix is already in the trie. - if shared_prefix_len == tokens.len() { + if shared_prefix_len == tokens.len() || shared_prefix_len == 0 { return Ok(shared_prefix_len); } @@ -387,7 +414,7 @@ impl RadixTrie { + self.insert_( child_id, &tokens[shared_prefix_len..], - &blocks[shared_prefix_len..], + &blocks[shared_prefix_len / self.block_size..], )?); } @@ -396,7 +423,7 @@ impl RadixTrie { // remainder of the prefix into the node again let child_id = self.split_node(child_id, shared_prefix_len); let key = &tokens[shared_prefix_len..]; - let blocks = &blocks[shared_prefix_len..]; + let blocks = &blocks[shared_prefix_len / self.block_size..]; Ok(shared_prefix_len + self.insert_(child_id, key, blocks)?) } else { self.add_node(node_id, tokens, blocks); @@ -550,34 +577,53 @@ impl TrieNode { } } -/// Helper trait to get the length of the shared prefix of two sequences. 
-trait SharedPrefixLen { - fn shared_prefix_len(&self, other: &Self) -> usize; -} - -impl SharedPrefixLen for [T] -where - T: PartialEq, -{ - fn shared_prefix_len(&self, other: &Self) -> usize { - self.iter().zip(other).take_while(|(a, b)| a == b).count() - } +fn shared_prefix(left: &[u32], right: &[u32], block_size: usize) -> usize { + let full = left.iter().zip(right).take_while(|(a, b)| a == b).count(); + (full / block_size) * block_size } #[cfg(test)] mod tests { use std::sync::Arc; - use crate::block_allocator::Allocator; + use super::*; - use super::RadixAllocator; + #[test] + fn allocator_block_size() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22, 23]); + assert_eq!(allocation.prefix_len, 4); + } + + #[test] + fn allocator_block_size_non_aligned() { + let mut cache = RadixAllocator::new(2, 12, None); + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 0); + cache.free(allocation.blocks.clone(), allocation.allocation_id); + + let allocation = cache.allocate(7, Some(Arc::new(vec![0, 1, 2]))).unwrap(); + assert_eq!(allocation.blocks, vec![8, 9, 10, 11]); + assert_eq!(allocation.slots, vec![16, 17, 18, 19, 20, 21, 22]); + assert_eq!(allocation.prefix_len, 2); + } #[test] fn allocator_reuses_prefixes() { let mut cache = RadixAllocator::new(1, 12, None); let allocation = cache.allocate(8, Some(Arc::new(vec![0, 1, 2, 3]))).unwrap(); assert_eq!(allocation.blocks, vec![4, 5, 6, 7, 8, 9, 10, 11]); - assert_eq!(allocation.slots, allocation.slots); + assert_eq!(allocation.blocks, allocation.slots); assert_eq!(allocation.prefix_len, 0); cache.free(allocation.blocks.clone(), allocation.allocation_id); @@ -666,7 +712,7 @@ mod tests { #[test] fn trie_insertions_have_correct_prefix_len() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); assert_eq!(trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(), 0); @@ -687,9 +733,33 @@ mod tests { ); } + #[test] + fn trie_insertions_block_size() { + let mut trie = RadixTrie::new(2); + + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 0); + + // Already exists. + // But needs to be block_size aligned + assert_eq!(trie.insert(&[0, 1, 2, 3], &[0, 1]).unwrap(), 4); + + // Completely new at root-level + assert_eq!(trie.insert(&[1, 2, 3, 4], &[1, 2]).unwrap(), 0); + + // Contains full prefix, but longer. + assert_eq!(trie.insert(&[0, 1, 2, 3, 4, 5], &[0, 1, 2]).unwrap(), 4); + + // Shares partial prefix, we need a split. 
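        // Raw overlap with the existing [0, 1, 2, 3] chain is [0, 1] (2 tokens), which
        // is already aligned to block_size = 2, so that node is split after its first
        // block and the remaining tokens become a new child; only 2 prefix tokens are
        // reused.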
+ assert_eq!( + trie.insert(&[0, 1, 3, 4, 5, 6, 7, 8], &[0, 1, 2, 3]) + .unwrap(), + 2 + ); + } + #[test] fn trie_get_returns_correct_blocks() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); trie.insert(&[1, 2, 3], &[1, 2, 3]).unwrap(); trie.insert(&[0, 1, 2, 3, 4], &[0, 1, 2, 3, 4]).unwrap(); @@ -723,7 +793,7 @@ mod tests { #[test] fn trie_evict_removes_correct_blocks() { - let mut trie = super::RadixTrie::new(); + let mut trie = RadixTrie::new(1); trie.insert(&[0, 1, 2], &[0, 1, 2]).unwrap(); trie.insert(&[0, 1, 2, 3, 5, 6, 7], &[0, 1, 2, 3, 5, 6, 7]) .unwrap(); diff --git a/benchmark/src/generation.rs b/benchmark/src/generation.rs index 7494d5b5d8c..789c7b514fc 100644 --- a/benchmark/src/generation.rs +++ b/benchmark/src/generation.rs @@ -148,6 +148,7 @@ async fn prefill( }), inputs: sequence.clone(), truncate: sequence_length, + add_special_tokens: true, parameters: Some(parameters.clone()), stopping_parameters: Some(StoppingCriteriaParameters { max_new_tokens: decode_length, diff --git a/flake.lock b/flake.lock index 14011768491..c0a696b1646 100644 --- a/flake.lock +++ b/flake.lock @@ -835,11 +835,11 @@ ] }, "locked": { - "lastModified": 1724206841, - "narHash": "sha256-L8dKaX4T3k+TR2fEHCfGbH4UXdspovz/pj87iai9qmc=", + "lastModified": 1724638882, + "narHash": "sha256-ap2jIQi/FuUHR6HCht6ASWhoz8EiB99XmI8Esot38VE=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "45e98fbd62c32e5927e952d2833fa1ba4fb35a61", + "rev": "19b70f147b9c67a759e35824b241f1ed92e46694", "type": "github" }, "original": { diff --git a/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json index 8631c076041..5553e17dd80 100644 --- a/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json +++ b/integration-tests/models/__snapshots__/test_chat_llama/test_flash_llama_simple.json @@ -5,7 +5,7 @@ "index": 0, "logprobs": null, "message": { - "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas", + "content": "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. 
The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C", "name": null, "role": "assistant", "tool_calls": null @@ -13,14 +13,14 @@ "usage": null } ], - "created": 1716553098, + "created": 1724792495, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", - "object": "text_completion", - "system_fingerprint": "2.0.5-dev0-native", + "object": "chat.completion", + "system_fingerprint": "2.2.1-dev0-native", "usage": { "completion_tokens": 100, - "prompt_tokens": 62, - "total_tokens": 162 + "prompt_tokens": 61, + "total_tokens": 161 } } diff --git a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json index d87071cfac5..e7fb5740030 100644 --- a/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json +++ b/integration-tests/models/__snapshots__/test_completion_prompts/test_flash_llama_completion_many_prompts_stream.json @@ -8,11 +8,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -23,11 +23,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -38,11 +38,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -53,11 +53,11 @@ "text": "hd" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -68,11 +68,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -83,11 +83,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -98,11 +98,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -113,11 +113,11 @@ "text": "aho" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -128,11 +128,11 @@ "text": 
"2" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -143,11 +143,11 @@ "text": "2" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -158,11 +158,11 @@ "text": "2" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -173,11 +173,11 @@ "text": "ima" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -188,11 +188,11 @@ "text": "." } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -203,11 +203,11 @@ "text": "." } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -218,11 +218,11 @@ "text": "." } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -233,11 +233,11 @@ "text": "\n" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -248,11 +248,11 @@ "text": " Sarah" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -263,11 +263,11 @@ "text": " Yes" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -278,11 +278,11 @@ "text": " And" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -293,11 +293,11 @@ "text": "i" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -308,11 +308,11 @@ "text": "'" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" 
}, { "choices": [ @@ -323,11 +323,11 @@ "text": "," } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -338,11 +338,11 @@ "text": " what" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -353,11 +353,11 @@ "text": "'" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -368,11 +368,11 @@ "text": "s" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -383,11 +383,11 @@ "text": " Moh" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -398,11 +398,11 @@ "text": " is" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -413,11 +413,11 @@ "text": "m" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -428,11 +428,11 @@ "text": " Room" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -443,11 +443,11 @@ "text": "s" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -458,11 +458,11 @@ "text": " the" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -473,11 +473,11 @@ "text": " tired" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -488,11 +488,11 @@ "text": ":" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -503,11 +503,11 @@ "text": "'" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": 
"2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -518,11 +518,11 @@ "text": " capital" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ @@ -530,73 +530,73 @@ "finish_reason": "", "index": 3, "logprobs": null, - "text": " of" + "text": "," } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ { - "finish_reason": "", + "finish_reason": "length", "index": 0, "logprobs": null, "text": " She" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ { - "finish_reason": "", + "finish_reason": "length", "index": 1, "logprobs": null, "text": " scale" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ { - "finish_reason": "", + "finish_reason": "length", "index": 2, "logprobs": null, "text": " of" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" }, { "choices": [ { - "finish_reason": "", + "finish_reason": "length", "index": 3, "logprobs": null, - "text": " being" + "text": " its" } ], - "created": 1713284431, + "created": 1724833943, "id": "", "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "object": "text_completion", - "system_fingerprint": "2.0.1-native" + "system_fingerprint": "2.2.1-dev0-native" } ] diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json index 03f90367280..732b0c499d5 100644 --- a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json @@ -16,7 +16,7 @@ }, { "id": 3102, - "logprob": -11.1875, + "logprob": -11.25, "text": " request" } ], @@ -24,66 +24,66 @@ "tokens": [ { "id": 185, - "logprob": -1.5546875, + "logprob": -1.546875, "special": false, "text": "\n" }, { "id": 549, - "logprob": -2.84375, + "logprob": -2.859375, "special": false, "text": "The" }, { "id": 1727, - "logprob": -2.34375, + "logprob": -2.484375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.8359375, + "logprob": -0.83203125, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.0859375, + "logprob": -1.1484375, "special": false, "text": " is" }, { - "id": 254, - "logprob": -1.5390625, + "id": 245, + "logprob": -1.578125, "special": false, - "text": " the" + "text": " a" }, { - "id": 1022, - "logprob": -1.1875, + "id": 3412, + "logprob": -2.578125, "special": false, - "text": " first" + "text": " document" }, { - "id": 3458, - "logprob": -0.35546875, + "id": 344, + "logprob": -1.125, "special": false, - "text": " step" + "text": " that" }, { - "id": 279, - 
"logprob": -0.8828125, + "id": 317, + "logprob": -1.6953125, "special": false, - "text": " in" + "text": " is" }, { - "id": 254, - "logprob": -0.71484375, + "id": 1222, + "logprob": -1.71875, "special": false, - "text": " the" + "text": " used" } ], "top_tokens": null }, - "generated_text": "\nThe test request is the first step in the" + "generated_text": "\nThe test request is a document that is used" } diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json index e365829a2a2..f1eeab25ca9 100644 --- a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json @@ -37,56 +37,56 @@ }, { "id": 1727, - "logprob": -2.359375, + "logprob": -2.4375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.83203125, + "logprob": -0.83984375, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.125, + "logprob": -1.1328125, "special": false, "text": " is" }, { - "id": 245, - "logprob": -1.5703125, + "id": 254, + "logprob": -1.515625, "special": false, - "text": " a" + "text": " the" }, { - "id": 3412, - "logprob": -2.578125, + "id": 1022, + "logprob": -1.15625, "special": false, - "text": " document" + "text": " first" }, { - "id": 344, - "logprob": -1.125, + "id": 3458, + "logprob": -0.3671875, "special": false, - "text": " that" + "text": " step" }, { - "id": 317, - "logprob": -1.6953125, + "id": 279, + "logprob": -0.88671875, "special": false, - "text": " is" + "text": " in" }, { - "id": 1222, - "logprob": -1.75, + "id": 254, + "logprob": -0.69140625, "special": false, - "text": " used" + "text": " the" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a document that is used" + "generated_text": "\nThe test request is the first step in the" }, { "details": { @@ -126,56 +126,56 @@ }, { "id": 1727, - "logprob": -2.359375, + "logprob": -2.4375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.83203125, + "logprob": -0.83984375, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.125, + "logprob": -1.1328125, "special": false, "text": " is" }, { - "id": 245, - "logprob": -1.5703125, + "id": 254, + "logprob": -1.515625, "special": false, - "text": " a" + "text": " the" }, { - "id": 3412, - "logprob": -2.578125, + "id": 1022, + "logprob": -1.15625, "special": false, - "text": " document" + "text": " first" }, { - "id": 344, - "logprob": -1.125, + "id": 3458, + "logprob": -0.3671875, "special": false, - "text": " that" + "text": " step" }, { - "id": 317, - "logprob": -1.6953125, + "id": 279, + "logprob": -0.88671875, "special": false, - "text": " is" + "text": " in" }, { - "id": 1222, - "logprob": -1.75, + "id": 254, + "logprob": -0.69140625, "special": false, - "text": " used" + "text": " the" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a document that is used" + "generated_text": "\nThe test request is the first step in the" }, { "details": { @@ -215,56 +215,56 @@ }, { "id": 1727, - "logprob": -2.359375, + "logprob": -2.4375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.83203125, + "logprob": -0.83984375, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.125, + "logprob": -1.1328125, "special": false, "text": " is" }, { - "id": 245, - "logprob": -1.5703125, + "id": 254, + 
"logprob": -1.515625, "special": false, - "text": " a" + "text": " the" }, { - "id": 3412, - "logprob": -2.578125, + "id": 1022, + "logprob": -1.15625, "special": false, - "text": " document" + "text": " first" }, { - "id": 344, - "logprob": -1.125, + "id": 3458, + "logprob": -0.3671875, "special": false, - "text": " that" + "text": " step" }, { - "id": 317, - "logprob": -1.6953125, + "id": 279, + "logprob": -0.88671875, "special": false, - "text": " is" + "text": " in" }, { - "id": 1222, - "logprob": -1.75, + "id": 254, + "logprob": -0.69140625, "special": false, - "text": " used" + "text": " the" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a document that is used" + "generated_text": "\nThe test request is the first step in the" }, { "details": { @@ -304,55 +304,55 @@ }, { "id": 1727, - "logprob": -2.359375, + "logprob": -2.4375, "special": false, "text": " test" }, { "id": 3102, - "logprob": -0.83203125, + "logprob": -0.83984375, "special": false, "text": " request" }, { "id": 317, - "logprob": -1.125, + "logprob": -1.1328125, "special": false, "text": " is" }, { - "id": 245, - "logprob": -1.5703125, + "id": 254, + "logprob": -1.515625, "special": false, - "text": " a" + "text": " the" }, { - "id": 3412, - "logprob": -2.578125, + "id": 1022, + "logprob": -1.15625, "special": false, - "text": " document" + "text": " first" }, { - "id": 344, - "logprob": -1.125, + "id": 3458, + "logprob": -0.3671875, "special": false, - "text": " that" + "text": " step" }, { - "id": 317, - "logprob": -1.6953125, + "id": 279, + "logprob": -0.88671875, "special": false, - "text": " is" + "text": " in" }, { - "id": 1222, - "logprob": -1.75, + "id": 254, + "logprob": -0.69140625, "special": false, - "text": " used" + "text": " the" } ], "top_tokens": null }, - "generated_text": "\nThe test request is a document that is used" + "generated_text": "\nThe test request is the first step in the" } ] diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json index bf981e4f169..e39829ece3b 100644 --- a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json @@ -1,8 +1,8 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 10, + "finish_reason": "stop_sequence", + "generated_tokens": 5, "prefill": [ { "id": 128000, @@ -16,7 +16,7 @@ }, { "id": 1715, - "logprob": -10.375, + "logprob": -10.4375, "text": " request" } ], @@ -29,61 +29,31 @@ "text": ":" }, { - "id": 2209, - "logprob": -2.78125, + "id": 923, + "logprob": -2.84375, "special": false, - "text": " Is" + "text": " add" }, { - "id": 279, - "logprob": -0.6328125, + "id": 264, + "logprob": 0.0, "special": false, - "text": " the" - }, - { - "id": 734, - "logprob": -2.703125, - "special": false, - "text": " function" + "text": " a" }, { "id": 330, - "logprob": -0.34179688, + "logprob": -0.31640625, "special": false, "text": " \"" }, { - "id": 4110, - "logprob": -2.359375, - "special": false, - "text": "Create" - }, - { - "id": 7575, - "logprob": -2.1875, - "special": false, - "text": "Process" - }, - { - "id": 1, - "logprob": -0.07910156, - "special": false, - "text": "\"" - }, - { - "id": 304, - "logprob": -0.83203125, - "special": false, - "text": " in" - }, - { - "id": 12468, - "logprob": -1.8203125, + "id": 
1985, + "logprob": 0.0, "special": false, - "text": " Win" + "text": "test" } ], "top_tokens": null }, - "generated_text": "Test request: Is the function \"CreateProcess\" in Win" + "generated_text": "Test request: add a \"test" } diff --git a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json index d882b82ac47..412b19b49b9 100644 --- a/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json +++ b/integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json @@ -16,7 +16,7 @@ }, { "id": 100, - "logprob": -0.38549805, + "logprob": -0.38305664, "text": "_" }, { @@ -29,7 +29,7 @@ "tokens": [ { "id": 2284, - "logprob": -0.31323242, + "logprob": -0.296875, "special": false, "text": "():" }, @@ -59,19 +59,19 @@ }, { "id": 10914, - "logprob": -0.7817383, + "logprob": -0.7734375, "special": false, "text": " World" }, { "id": 16013, - "logprob": -0.6328125, + "logprob": -0.61816406, "special": false, "text": "!\")" }, { "id": 222, - "logprob": -0.0619812, + "logprob": -0.054870605, "special": false, "text": "\n" }, @@ -83,7 +83,7 @@ }, { "id": 610, - "logprob": -0.4086914, + "logprob": -0.4152832, "special": false, "text": "def" }, @@ -113,7 +113,7 @@ }, { "id": 444, - "logprob": -0.21826172, + "logprob": -0.21618652, "special": false, "text": "name" }, @@ -173,7 +173,7 @@ }, { "id": 11571, - "logprob": -0.10021973, + "logprob": -0.08892822, "special": false, "text": "!\"" }, diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json index 1fad0b96c14..dab437b9fa4 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_next_all_params.json @@ -30,19 +30,19 @@ }, { "id": 264, - "logprob": -0.37573242, + "logprob": -0.38061523, "special": false, "text": " a" }, { "id": 633, - "logprob": -0.09161377, + "logprob": -0.09301758, "special": false, "text": " new" }, { "id": 4480, - "logprob": -0.26171875, + "logprob": -0.26782227, "special": false, "text": " feature" }, @@ -78,7 +78,7 @@ }, { "id": 13, - "logprob": 0.0, + "logprob": -0.10632324, "special": false, "text": "\n" } diff --git a/integration-tests/models/test_chat_llama.py b/integration-tests/models/test_chat_llama.py index 1f7a4a59617..7d24add3633 100644 --- a/integration-tests/models/test_chat_llama.py +++ b/integration-tests/models/test_chat_llama.py @@ -35,6 +35,6 @@ async def test_flash_llama_simple(flash_llama_chat, response_snapshot): print(repr(response.choices[0].message.content)) assert ( response.choices[0].message.content - == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to explore in the middle of urban confines. In fact, typical times for humidity levels in Brooklyn include:\n\n- Early morning: 80-85% humidity, with occas" + == "As of your last question, the weather in Brooklyn, New York, is typically hot and humid throughout the year. 
The suburbs around New York City are jealously sheltered, and at least in the Lower Bronx, there are very few outdoor environments to appreciate nature.\n\nIn terms of temperature, the warmest times of the year are from June to August, when average high temperatures typically range from around 73°F or 23°C" ) assert response == response_snapshot diff --git a/launcher/src/main.rs b/launcher/src/main.rs index 9a90a6733b8..8e5c9dcdd87 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -8,7 +8,7 @@ use nix::unistd::Pid; use serde::Deserialize; use std::env; use std::ffi::OsString; -use std::io::{BufRead, BufReader, Lines}; +use std::io::{BufRead, BufReader}; use std::os::unix::process::{CommandExt, ExitStatusExt}; use std::path::Path; use std::process::{Child, Command, ExitStatus, Stdio}; @@ -18,12 +18,103 @@ use std::sync::{mpsc, Arc}; use std::thread; use std::thread::sleep; use std::time::{Duration, Instant}; -use std::{fs, io}; +use std::{ + fs, io, + io::{Read, Write}, +}; use thiserror::Error; use tracing_subscriber::{filter::LevelFilter, EnvFilter}; mod env_runtime; +fn get_config( + model_id: &str, + revision: &Option, +) -> Result> { + let mut path = std::path::Path::new(model_id).to_path_buf(); + let model_id = model_id.to_string(); + let filename = if !path.exists() { + // Assume it's a hub id + + let api = if let Ok(token) = std::env::var("HF_TOKEN") { + // env variable has precedence over on file token. + ApiBuilder::new().with_token(Some(token)).build()? + } else { + Api::new()? + }; + let repo = if let Some(ref revision) = revision { + api.repo(Repo::with_revision( + model_id, + RepoType::Model, + revision.to_string(), + )) + } else { + api.model(model_id) + }; + repo.get("config.json")? + } else { + path.push("config.json"); + path + }; + + let content = std::fs::read_to_string(filename)?; + let config: RawConfig = serde_json::from_str(&content)?; + + let config: Config = config.into(); + Ok(config) +} + +fn resolve_attention(config: &Option, lora_adapters: &Option) -> (String, String) { + let mut prefix_caching: Option = std::env::var("USE_PREFIX_CACHING").ok(); + let mut attention: Option = std::env::var("ATTENTION").ok(); + if let Some(config) = config { + if prefix_caching.is_none() { + if config.vision_config.is_some() { + tracing::info!("Disabling prefix caching because of VLM model"); + prefix_caching = Some("0".to_string()); + } else if config.is_encoder_decoder { + tracing::info!("Disabling prefix caching because of seq2seq model"); + prefix_caching = Some("0".to_string()); + } + } + match config.head_dim { + Some(h) if h == 64 || h == 128 || h == 256 => { + if lora_adapters.is_some() && prefix_caching.is_none() { + tracing::info!("Disabling prefix caching because of lora adapters"); + prefix_caching = Some("0".to_string()); + } + match config.model_type.as_deref() { + Some("gemma2") | Some("falcon") | Some("deepseek_v2") => { + // Required because gemma2 needs bfloat16 which is not supported by + // flashinfer ? 
+ if attention.is_none() { + tracing::info!( + "Forcing flash decoding because model {} requires it", + config.model_type.as_ref().unwrap() + ); + attention = Some("flashdecoding".to_string()); + } + } + Some("t5") => {} + _ => {} + } + } + _ => { + if attention.is_none() { + tracing::info!("Forcing flash decoding because head dim is not supported by flashinfer, also disabling prefix caching"); + attention = Some("flashdecoding".to_string()); + } + if prefix_caching.is_none() { + prefix_caching = Some("0".to_string()); + } + } + } + } + let prefix_caching = prefix_caching.unwrap_or("true".to_string()); + let attention = attention.unwrap_or("flashinfer".to_string()); + (prefix_caching, attention) +} + #[derive(Deserialize)] struct RawConfig { max_position_embeddings: Option, @@ -31,6 +122,12 @@ struct RawConfig { model_type: Option, max_seq_len: Option, quantization_config: Option, + n_embd: Option, + hidden_size: Option, + num_attention_heads: Option, + head_dim: Option, + vision_config: Option, + is_encoder_decoder: Option, } #[derive(Deserialize)] @@ -38,10 +135,17 @@ struct QuantizationConfig { quant_method: Option, } +#[derive(Deserialize)] +struct VisionConfig {} + #[derive(Deserialize)] struct Config { max_position_embeddings: Option, quantize: Option, + head_dim: Option, + model_type: Option, + vision_config: Option, + is_encoder_decoder: bool, } impl From for Config { @@ -51,9 +155,32 @@ impl From for Config { .or(other.max_seq_len) .or(other.n_positions); let quantize = other.quantization_config.and_then(|q| q.quant_method); + let head_dim = other.head_dim.or_else(|| { + match (other.hidden_size, other.n_embd, other.num_attention_heads) { + (Some(hidden_size), _, Some(num_attention_heads)) + if hidden_size % num_attention_heads == 0 => + { + Some(hidden_size / num_attention_heads) + } + // Legacy + (_, Some(hidden_size), Some(num_attention_heads)) + if hidden_size % num_attention_heads == 0 => + { + Some(hidden_size / num_attention_heads) + } + _ => None, + } + }); + let model_type = other.model_type; + let vision_config = other.vision_config; + let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false); Config { max_position_embeddings, quantize, + head_dim, + model_type, + vision_config, + is_encoder_decoder, } } } @@ -731,6 +858,7 @@ fn shard_manager( .args(shard_args) .env_clear() .envs(envs) + .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .process_group(0) @@ -752,12 +880,13 @@ fn shard_manager( }; // Redirect STDOUT to the console + let mut pstdin = p.stdin.take().unwrap(); let shard_stdout_reader = BufReader::new(p.stdout.take().unwrap()); let shard_stderr_reader = BufReader::new(p.stderr.take().unwrap()); //stdout tracing thread thread::spawn(move || { - log_lines(shard_stdout_reader.lines()); + log_lines(shard_stdout_reader); }); // We read stderr in another thread as it seems that lines() can block in some cases let (err_sender, err_receiver) = mpsc::channel(); @@ -766,6 +895,18 @@ fn shard_manager( err_sender.send(line).unwrap_or(()); } }); + // We read stdin in another thread as it seems that lines() can block in some cases + thread::spawn(move || { + let mut stdin = io::stdin(); // We get `Stdin` here. 
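+ // Forward any bytes read from the launcher's stdin to the shard subprocess's stdin (write errors are ignored).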
+ loop { + let mut buffer = vec![0; 4096]; + if let Ok(n) = stdin.read(&mut buffer) { + if n > 0 { + let _ = pstdin.write_all(&buffer[..n]); + } + } + } + }); let mut ready = false; let start_time = Instant::now(); @@ -872,19 +1013,36 @@ impl PythonLogMessage { } } -impl TryFrom<&String> for PythonLogMessage { +impl TryFrom<&[u8]> for PythonLogMessage { type Error = serde_json::Error; - fn try_from(value: &String) -> Result { - serde_json::from_str::(value) + fn try_from(value: &[u8]) -> Result { + serde_json::from_slice::(value) } } -fn log_lines(lines: Lines) { - for line in lines.map_while(Result::ok) { - match PythonLogMessage::try_from(&line) { - Ok(log) => log.trace(), - Err(_) => tracing::debug!("{line}"), +fn log_lines(mut bufread: BufReader) { + let mut buffer = vec![0u8; 8 * 4096]; + let mut stdout = std::io::stdout(); + loop { + let n = bufread.read(&mut buffer); + if let Ok(n) = n { + if n > 0 { + let mut lines = buffer[..n].split(|i| *i == b'\n').peekable(); + while let Some(line) = lines.next() { + match PythonLogMessage::try_from(line) { + Ok(log) => log.trace(), + // For interactive debugging ? + Err(_) => { + stdout.write_all(line).unwrap(); + if lines.peek().is_some() { + stdout.write_all(b"\n").unwrap(); + } + stdout.flush().unwrap(); + } + } + } + } } } } @@ -1044,7 +1202,7 @@ fn download_convert_model( let download_stdout = BufReader::new(download_process.stdout.take().unwrap()); thread::spawn(move || { - log_lines(download_stdout.lines()); + log_lines(download_stdout); }); let download_stderr = BufReader::new(download_process.stderr.take().unwrap()); @@ -1439,68 +1597,35 @@ fn main() -> Result<(), LauncherError> { tracing::info!("{:#?}", args); - let get_max_positions_quantize = - || -> Result<(usize, Option), Box> { - let model_id = args.model_id.clone(); - let mut path = std::path::Path::new(&args.model_id).to_path_buf(); - let filename = if !path.exists() { - // Assume it's a hub id - - let api = if let Ok(token) = std::env::var("HF_TOKEN") { - // env variable has precedence over on file token. - ApiBuilder::new().with_token(Some(token)).build()? - } else { - Api::new()? - }; - let repo = if let Some(ref revision) = args.revision { - api.repo(Repo::with_revision( - model_id, - RepoType::Model, - revision.to_string(), - )) - } else { - api.model(model_id) - }; - repo.get("config.json")? - } else { - path.push("config.json"); - path - }; - - let content = std::fs::read_to_string(filename)?; - let config: RawConfig = serde_json::from_str(&content)?; - - if config.model_type == Some("gemma2".to_string()) { - tracing::info!("Forcing flash decoding because of softcap usage"); - std::env::set_var("ATTENTION", "flashdecoding"); - } - let config: Config = config.into(); - let quantize = config.quantize; - - // Quantization usually means you're even more RAM constrained. - let max_default = 4096; - - if let Some(max_position_embeddings) = config.max_position_embeddings { - if max_position_embeddings > max_default { - let max = max_position_embeddings; - if args.max_input_tokens.is_none() - && args.max_total_tokens.is_none() - && args.max_batch_prefill_tokens.is_none() - { - tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. 
You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); - } - Ok((max_default, quantize)) - } else { - Ok((max_position_embeddings, quantize)) + let config: Option = get_config(&args.model_id, &args.revision).ok(); + let quantize = config.as_ref().and_then(|c| c.quantize); + // Quantization usually means you're even more RAM constrained. + let max_default = 4096; + + let max_position_embeddings = if let Some(config) = &config { + if let Some(max_position_embeddings) = config.max_position_embeddings { + if max_position_embeddings > max_default { + let max = max_position_embeddings; + if args.max_input_tokens.is_none() + && args.max_total_tokens.is_none() + && args.max_batch_prefill_tokens.is_none() + { + tracing::info!("Model supports up to {max} but tgi will now set its default to {max_default} instead. This is to save VRAM by refusing large prompts in order to allow more users on the same hardware. You can increase that size using `--max-batch-prefill-tokens={} --max-total-tokens={max} --max-input-tokens={}`.", max + 50, max - 1); } + max_default } else { - Err(Box::new(LauncherError::ArgumentValidation( - "no max defined".to_string(), - ))) + max_position_embeddings } - }; - let (max_position_embeddings, quantize): (usize, Option) = - get_max_positions_quantize().unwrap_or((4096, None)); + } else { + max_default + } + } else { + max_default + }; + let (prefix_caching, attention) = resolve_attention(&config, &args.lora_adapters); + tracing::info!("Using attention {attention} - Prefix caching {prefix_caching}"); + std::env::set_var("USE_PREFIX_CACHING", prefix_caching); + std::env::set_var("ATTENTION", attention); let max_input_tokens = { match (args.max_input_tokens, args.max_input_length) { diff --git a/load_tests/common.js b/load_tests/common.js index e0a105956e1..d890bf6710d 100644 --- a/load_tests/common.js +++ b/load_tests/common.js @@ -33,13 +33,13 @@ export function get_options() { // rate: 20, // timeUnit: '1s', // }, - load_test: { - executor: 'constant-arrival-rate', - duration: '60s', - preAllocatedVUs: 100, - rate: 1, - timeUnit: '1s', - }, + // load_test: { + // executor: 'constant-arrival-rate', + // duration: '60s', + // preAllocatedVUs: 100, + // rate: 1, + // timeUnit: '1s', + // }, // breakpoint: { // executor: 'ramping-arrival-rate', //Assure load increase if the system slows // preAllocatedVUs: 300, @@ -47,12 +47,12 @@ export function get_options() { // { duration: '60s', target: 100 }, // just slowly ramp-up to a HUGE load // ], // }, - // throughput: { - // executor: 'shared-iterations', - // vus: 100, - // iterations: 200, - // maxDuration: '40s', - // }, + throughput: { + executor: 'shared-iterations', + vus: 100, + iterations: 200, + maxDuration: '40s', + }, }, }; } diff --git a/proto/v3/generate.proto b/proto/v3/generate.proto index 68eea7ac9fa..34894bdaba4 100644 --- a/proto/v3/generate.proto +++ b/proto/v3/generate.proto @@ -137,6 +137,8 @@ message Request { optional string adapter_id = 11; /// Prefix length that can be retrieved from the KV cache. 
uint32 prefix_len = 12; + /// Context truncation + bool add_special_tokens = 13; } message Batch { diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index 81c0d38f6a7..240282d97e0 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -120,10 +120,11 @@ impl Infer { ) -> Result, InferError> { // Tokenize request let inputs = request.inputs; + let add_special_tokens = request.add_special_tokens; let truncate = request.parameters.truncate; let encoding = self .validation - .tokenize(inputs, truncate) + .tokenize(inputs, add_special_tokens, truncate) .await .map_err(|err| { tracing::error!("Tokenization {err}"); diff --git a/router/src/lib.rs b/router/src/lib.rs index ce4f7c46754..979f6dd1be9 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -22,6 +22,16 @@ pub enum Attention { FlashInfer, } +impl Attention { + pub fn block_size(&self) -> u32 { + match self { + Attention::FlashDecoding => 256, + Attention::FlashInfer => 1, + Attention::Paged => 16, + } + } +} + #[derive(Debug)] pub struct ParseError; @@ -1072,6 +1082,16 @@ pub(crate) struct GenerateRequest { pub inputs: String, #[serde(default = "default_parameters")] pub parameters: GenerateParameters, + + /// This is used internally because some requests + /// already contain the templated input therefore + /// we shouldn't add the special tokens. + #[serde(default = "default_true", skip)] + pub add_special_tokens: bool, +} + +fn default_true() -> bool { + true } #[derive(Clone, Debug, Deserialize, ToSchema)] @@ -1089,6 +1109,7 @@ impl From for GenerateRequest { fn from(req: CompatGenerateRequest) -> Self { Self { inputs: req.inputs, + add_special_tokens: true, parameters: req.parameters, } } diff --git a/router/src/server.rs b/router/src/server.rs index 8ebd1a3316d..f273a786993 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -158,6 +158,7 @@ async fn get_chat_tokenize( let generate_request = GenerateRequest { inputs, + add_special_tokens: false, parameters: GenerateParameters { best_of: None, temperature, @@ -754,6 +755,7 @@ async fn completions( .iter() .map(|prompt| GenerateRequest { inputs: prompt.to_string(), + add_special_tokens: true, parameters: GenerateParameters { best_of: None, temperature, @@ -1180,6 +1182,7 @@ async fn chat_completions( // build the request passing some parameters let generate_request = GenerateRequest { inputs: inputs.to_string(), + add_special_tokens: false, parameters: GenerateParameters { best_of: None, temperature, @@ -1386,6 +1389,7 @@ async fn vertex_compatibility( .map(|instance| { let generate_request = GenerateRequest { inputs: instance.inputs.clone(), + add_special_tokens: true, parameters: GenerateParameters { do_sample: true, max_new_tokens: instance.parameters.as_ref().and_then(|p| p.max_new_tokens), diff --git a/router/src/validation.rs b/router/src/validation.rs index 0024723c688..92491d88fb2 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -95,6 +95,7 @@ impl Validation { pub async fn tokenize( &self, inputs: String, + add_special_tokens: bool, truncate: Option, ) -> Result)>, ValidationError> { // If we have a fast tokenizer @@ -104,7 +105,11 @@ impl Validation { // Send request to the background validation task // Unwrap is safe here sender - .send(((inputs, truncate), response_sender, Span::current())) + .send(( + (inputs, add_special_tokens, truncate), + response_sender, + Span::current(), + )) .unwrap(); // Await on response channel @@ -121,11 +126,15 @@ impl Validation { async fn validate_input( &self, inputs: 
String, + add_special_tokens: bool, truncate: Option, max_new_tokens: Option, ) -> Result<(Vec, Option>, usize, u32), ValidationError> { // If we have a fast tokenizer - if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? { + if let Some((encoding, inputs)) = self + .tokenize(inputs.clone(), add_special_tokens, truncate) + .await? + { // Create response channel let input_length = if let Some(truncate) = truncate { std::cmp::min(encoding.len(), truncate) @@ -158,7 +167,8 @@ impl Validation { )); } - let input_ids = encoding.get_ids()[..input_length].to_owned(); + let ids = encoding.get_ids(); + let input_ids = ids[ids.len().saturating_sub(input_length)..].to_owned(); metrics::histogram!("tgi_request_input_length").record(input_length as f64); Ok((inputs, Some(input_ids), input_length, max_new_tokens)) @@ -324,7 +334,12 @@ impl Validation { // Validate inputs let (inputs, input_ids, input_length, max_new_tokens) = self - .validate_input(request.inputs, truncate, max_new_tokens) + .validate_input( + request.inputs, + request.add_special_tokens, + truncate, + max_new_tokens, + ) .await?; // TODO: we should build the FSM here and pass the compiled FSM instead of the grammar @@ -401,6 +416,7 @@ impl Validation { Ok(ValidGenerateRequest { inputs, input_ids: input_ids.map(Arc::new), + add_special_tokens: request.add_special_tokens, decoder_input_details, input_length: input_length as u32, truncate: truncate.unwrap_or(self.max_input_length) as u32, @@ -449,12 +465,15 @@ fn tokenizer_worker( mut receiver: mpsc::UnboundedReceiver, ) { // Loop over requests - while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { + while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) = + receiver.blocking_recv() + { parent_span.in_scope(|| { response_tx .send(prepare_input( inputs, truncate, + add_special_tokens, &tokenizer, config.as_ref(), preprocessor_config.as_ref(), @@ -591,6 +610,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String { fn prepare_input( inputs: String, _truncate: Option, + add_special_tokens: bool, tokenizer: &Tokenizer, config: Option<&Config>, preprocessor_config: Option<&HubPreprocessorConfig>, @@ -628,14 +648,14 @@ fn prepare_input( // Get the number of tokens in the input let encoding = tokenizer - .encode(tokenizer_query, true) + .encode(tokenizer_query, add_special_tokens) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; Ok((encoding, input_chunks)) } type TokenizerRequest = ( - (String, Option), + (String, bool, Option), oneshot::Sender), ValidationError>>, Span, ); @@ -720,6 +740,7 @@ pub struct ValidGenerateRequest { pub input_ids: Option>>, pub input_length: u32, pub truncate: u32, + pub add_special_tokens: bool, pub decoder_input_details: bool, pub parameters: ValidParameters, pub stopping_parameters: ValidStoppingParameters, @@ -826,7 +847,7 @@ mod tests { let max_new_tokens = 10; match validation - .validate_input("Hello".to_string(), None, Some(max_new_tokens)) + .validate_input("Hello".to_string(), true, None, Some(max_new_tokens)) .await { // Err(ValidationError::MaxNewTokens(1, 10)) => (), @@ -861,7 +882,7 @@ mod tests { let max_new_tokens = 10; match validation - .validate_input("Hello".to_string(), None, Some(max_new_tokens)) + .validate_input("Hello".to_string(), true, None, Some(max_new_tokens)) .await { Err(ValidationError::MaxTotalTokens(6, 1, 10)) => (), @@ -895,6 +916,7 @@ mod tests { match validation .validate(GenerateRequest { inputs: 
"Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { best_of: Some(2), do_sample: false, @@ -934,6 +956,7 @@ mod tests { match validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_p: Some(1.0), max_new_tokens: Some(5), @@ -949,6 +972,7 @@ mod tests { match validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_p: Some(0.99), max_new_tokens: Some(5), @@ -964,6 +988,7 @@ mod tests { let valid_request = validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_p: None, max_new_tokens: Some(5), @@ -1002,6 +1027,7 @@ mod tests { match validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_n_tokens: Some(5), max_new_tokens: Some(5), @@ -1017,6 +1043,7 @@ mod tests { validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_n_tokens: Some(4), max_new_tokens: Some(5), @@ -1029,6 +1056,7 @@ mod tests { validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_n_tokens: Some(0), max_new_tokens: Some(5), @@ -1041,6 +1069,7 @@ mod tests { let valid_request = validation .validate(GenerateRequest { inputs: "Hello".to_string(), + add_special_tokens: true, parameters: GenerateParameters { top_n_tokens: None, max_new_tokens: Some(5), @@ -1089,6 +1118,7 @@ mod tests { let chunks = match validation .tokenize( format!("test![](data:image/gif;base64,{})", PIXEL_GIF), + true, None, ) .await @@ -1148,6 +1178,7 @@ mod tests { "test![](data:image/gif;base64,{})![](data:image/gif;base64,{})", PIXEL_GIF, PIXEL_GIF ), + true, None, ) .await diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 8c77896e9e0..f392b161c58 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] # Released on: June 13, 2024 # https://releases.rs/docs/1.79.0/ -channel = "1.79.0" +channel = "1.80.0" components = ["rustfmt", "clippy"] diff --git a/server/Makefile b/server/Makefile index 51ea8b32206..9338b299090 100644 --- a/server/Makefile +++ b/server/Makefile @@ -7,6 +7,7 @@ include Makefile-selective-scan include Makefile-lorax-punica include Makefile-fbgemm include Makefile-exllamav2 +include Makefile-flashinfer unit-tests: pytest -s -vv -m "not private" tests diff --git a/server/Makefile-flashinfer b/server/Makefile-flashinfer new file mode 100644 index 00000000000..3abb0491827 --- /dev/null +++ b/server/Makefile-flashinfer @@ -0,0 +1,2 @@ +install-flashinfer: + pip install flashinfer==0.1.5 -i https://flashinfer.ai/whl/cu124/torch2.4 diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 16d2c4081ff..d99771f8ad0 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,7 +1,10 @@ import pytest - +import os from text_generation_server.pb import generate_pb2 +os.environ["USE_PREFIX_CACHING"] = "1" +os.environ["ATTENTION"] = "flashinfer" + @pytest.fixture def default_pb_parameters(): diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py index f162230cfb2..855f4dfc0f6 100644 --- a/server/text_generation_server/layers/attention/common.py +++ b/server/text_generation_server/layers/attention/common.py @@ -9,26 +9,46 @@ 
@dataclass class Seqlen: input_lengths: torch.Tensor + prefix_lengths: torch.Tensor cu_seqlen_q: Optional[torch.Tensor] cu_seqlen_k: Optional[torch.Tensor] + max_q: int + max_k: int - def __init__(self, input_lengths): + def __init__( + self, + input_lengths, + prefix_lengths, + cu_seqlen_q=None, + max_q=None, + max_k=None, + ): self.input_lengths = input_lengths + self.prefix_lengths = prefix_lengths device = self.input_lengths.device shape = self.input_lengths.shape - cu_seqlen_q = torch.arange( - shape[0] + 1, - device=device, - dtype=torch.int32, - ) + if cu_seqlen_q is None: + cu_seqlen_q = torch.arange( + shape[0] + 1, + device=device, + dtype=torch.int32, + ) + max_q = 1 + else: + assert max_q is not None + assert max_k is not None cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32) + # cuda graphs don't like this and this is necessary to clamp within mistral # Although FA2 might not want the clamping # cu_seqlen_k[0] = 0 - torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:]) + total = self.input_lengths + self.prefix_lengths + torch.cumsum(total, -1, out=cu_seqlen_k[1:]) self.cu_seqlen_q = cu_seqlen_q self.cu_seqlen_k = cu_seqlen_k + self.max_q = max_q + self.max_k = max_k def clamp(self, max): # Flash decoding doesn't need to clamp @@ -39,6 +59,11 @@ def clamp(self, max): @dataclass class Seqlen: input_lengths: torch.Tensor + prefix_lengths: torch.Tensor + cu_seqlen_q: torch.Tensor + max_q: int + max_k: int def clamp(self, max): + raise NotImplementedError("Not implemented seqlen for paged") return Seqlen(torch.clamp(self.input_lengths, max=max)) diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py index b3b7ea4fe0c..4b588b5cf40 100644 --- a/server/text_generation_server/layers/attention/cuda.py +++ b/server/text_generation_server/layers/attention/cuda.py @@ -222,18 +222,15 @@ def paged_attention( def attention( q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - cu_seqlens, - max_s, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, softcap=0.0, ): - assert window_size_left == -1, "Windowing is not supported with flash infer" from text_generation_server.layers.attention.flashinfer import ( prefill_with_paged_kv_state, ) @@ -244,18 +241,17 @@ def attention( paged_kv_cache=(key_cache, value_cache), logits_soft_cap=softcap, sm_scale=softmax_scale, + window_left=window_size_left, ) elif V2: def attention( q, - k, - v, key_cache: torch.Tensor, value_cache: torch.Tensor, - cu_seqlens, - max_s, + seqlen: Seqlen, + block_tables: torch.Tensor, softmax_scale, window_size_left=-1, causal=True, @@ -266,17 +262,17 @@ def attention( raise ValueError("`window_size_left` must be > 0 or -1") return flash_attn_2_cuda.varlen_fwd( q, - k, - v, + key_cache, + value_cache, out, - cu_seqlens, - cu_seqlens, - None, + seqlen.cu_seqlen_q, + seqlen.cu_seqlen_k, None, None, + block_tables, None, - max_s, - max_s, + seqlen.max_q, + seqlen.max_k, 0.0, softmax_scale, False, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 4fa9e66dcf0..e03cc30dc61 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -497,15 +497,14 @@ def get_model( else -1 ) - should_use_sliding_window = ( - sliding_window is not None and sliding_window != -1 and SUPPORTS_WINDOWING + use_sliding_window = 
sliding_window is not None and sliding_window != -1 + needs_sliding_window = ( + max_input_tokens is not None and max_input_tokens > sliding_window ) - - if should_use_sliding_window: - if max_input_tokens is not None and max_input_tokens > sliding_window: - raise ValueError( - f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." - ) + if use_sliding_window and needs_sliding_window and not SUPPORTS_WINDOWING: + raise ValueError( + f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." + ) if model_type == DEEPSEEK_V2: if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 1eb8c6c314b..fe19180a240 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -29,6 +29,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers import ( @@ -264,7 +265,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -296,12 +297,10 @@ def forward( # flash attention attn_output = attention( query, - key, - value, kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -313,7 +312,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -388,7 +387,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -402,7 +401,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -454,7 +453,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: torch.Tensor, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -477,7 +476,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -518,7 +517,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -531,7 +530,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index fc0dca5bfaf..b82b5473d58 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -29,6 
+29,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( FastLinear, @@ -309,7 +310,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -335,12 +336,10 @@ def forward( # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -352,7 +351,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -389,7 +388,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): normed_hidden_states, res = self.norm_1(hidden_states, residual) @@ -403,7 +402,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -622,7 +621,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): # Self Attention @@ -635,7 +634,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -679,7 +678,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -701,7 +700,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -734,7 +733,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -747,7 +746,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index b25becd5c7d..0585b40e6c5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -29,8 +29,8 @@ attention, paged_attention, reshape_and_cache, + Seqlen, ) -from text_generation_server.layers.attention.common import Seqlen from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.import_utils import SYSTEM @@ -298,7 +298,7 @@ def forward( kv_cache: Tuple[torch.Tensor, torch.Tensor], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: Seqlen, + seqlen: Seqlen, max_s: int, ): if self.q_lora_rank is None: @@ -363,12 +363,10 @@ def forward( # flash attention attn_output = attention( query, - key, - value, kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -380,7 +378,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -666,7 +664,7 @@ def forward( kv_cache, block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: Seqlen, + seqlen: Seqlen, max_s: int, ): normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -680,7 +678,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, 
max_s, ) @@ -729,7 +727,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -751,7 +749,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -781,7 +779,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -794,7 +792,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index faf0f32587c..d16e805f6b7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -30,6 +30,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -213,7 +214,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -236,12 +237,10 @@ def forward( # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, causal=self.causal, window_size_left=self.window_size, @@ -256,7 +255,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, softcap=self.softcap, ) @@ -343,7 +342,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -357,7 +356,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -408,7 +407,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -430,7 +429,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -477,7 +476,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -491,7 +490,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 33738a59d93..34be4cb8267 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -30,6 +30,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -207,7 +208,7 @@ def forward( kv_cache, block_tables, slots, - 
input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -230,12 +231,10 @@ def forward( # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, causal=self.causal, ) @@ -248,7 +247,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -320,7 +319,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): normed_hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -334,7 +333,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -382,7 +381,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = inputs_embeds @@ -404,7 +403,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -449,7 +448,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -463,7 +462,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index d30b5a0ab01..403fa90843b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -29,6 +29,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -213,7 +214,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): query, key, value = self.query_key_value(hidden_states).split( @@ -230,12 +231,10 @@ def forward( # flash attention attn_output = attention( query, - key, - value, kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -247,7 +246,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -316,7 +315,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): residual = hidden_states @@ -329,7 +328,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -382,7 +381,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -398,7 +397,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -435,7 +434,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -451,7 +450,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s=max_s, 
prefill_cache_indices=prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py index eb667384343..35ab2791f3c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gptj_modeling.py @@ -29,6 +29,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -167,7 +168,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): query, key, value = self.query_key_value(hidden_states).split( @@ -192,10 +193,10 @@ def forward( # flash attention attn_output = attention( query, - key, - value, - cu_seqlen_prefill, - max_s, + kv_cache[0], + kv_cache[1], + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -207,7 +208,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -268,7 +269,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): hidden_states, residual = self.input_layernorm(hidden_states, residual) @@ -281,7 +282,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -328,7 +329,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: @@ -351,7 +352,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -382,7 +383,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -395,7 +396,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices=prefill_cache_indices, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 3253d2dc07c..5b228f9ffe6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -32,6 +32,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -194,7 +195,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -218,12 +219,10 @@ def forward( # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -235,7 +234,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -375,7 +374,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ): @@ -390,7 +389,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -479,7 +478,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: 
torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -504,7 +503,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, adapter_data, ) @@ -548,7 +547,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -562,7 +561,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 5a150267d91..30ca3fafcb5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -31,6 +31,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -185,7 +186,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -217,12 +218,10 @@ def forward( # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, window_size_left=self.max_past, ) @@ -235,7 +234,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -356,7 +355,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -372,7 +371,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -424,7 +423,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -448,7 +447,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, adapter_data, @@ -499,7 +498,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -512,7 +511,7 @@ def forward( elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = input_lengths.clamp(max=self.max_past_tensor) + seqlen = seqlen.clamp(max=self.max_past_tensor) inputs_embeds = self.embed_tokens(input_ids) hidden_states = self.model( @@ -522,7 +521,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s, prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index ad426ffe75a..c5d60af1ddf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ 
b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -35,6 +35,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( FastLinear, @@ -243,7 +244,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -274,12 +275,10 @@ def forward( # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, window_size_left=self.max_past, ) @@ -292,7 +291,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -498,7 +497,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -513,7 +512,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -568,7 +567,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -592,7 +591,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -627,7 +626,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -640,7 +639,7 @@ def forward( elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = input_lengths.clamp(max=self.max_past_tensor) + seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, @@ -649,7 +648,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s, prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b684e035f53..fda648f9307 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -31,6 +31,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -147,7 +148,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -171,12 +172,10 @@ def forward( # flash attention attn_output = attention( qkv[:, 0], - qkv[:, 1], - qkv[:, 2], kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -188,7 +187,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -258,7 +257,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): if self.use_parallel_residual: @@ -272,7 +271,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -296,7 +295,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -350,7 +349,7 @@ def forward( kv_cache: 
List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) @@ -372,7 +371,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -404,7 +403,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -417,7 +416,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py index e08a2aade77..d044b492626 100644 --- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py @@ -19,6 +19,7 @@ from typing import Optional, List, Tuple from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear +from text_generation_server.layers.attention import Seqlen from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, @@ -70,7 +71,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -107,7 +108,7 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index efe27c137c8..37adb8be1d2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -10,6 +10,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -159,7 +160,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): # Compute query, key, value and split @@ -192,12 +193,10 @@ def forward( if cu_seqlen_prefill is not None: attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -209,7 +208,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -276,7 +275,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): hidden_states, res = self.input_layernorm(hidden_states, residual) @@ -289,7 +288,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -341,7 +340,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) @@ -363,7 +362,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + 
seqlen, max_s, ) @@ -396,7 +395,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -409,7 +408,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 879b8abd79c..5aac28a3057 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -9,6 +9,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -104,7 +105,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -135,12 +136,10 @@ def forward( # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, window_size_left=self.max_past, ) @@ -153,7 +152,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -225,7 +224,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -240,7 +239,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -296,7 +295,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -320,7 +319,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -361,7 +360,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, @@ -374,7 +373,7 @@ def forward( elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = input_lengths.clamp(max=self.max_past_tensor) + seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, @@ -383,7 +382,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s, prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index c72a9b90b7f..1c55dd91951 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -19,6 +19,7 @@ attention, paged_attention, reshape_and_cache, + Seqlen, ) @@ -181,7 +182,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -206,12 +207,10 @@ def forward( # flash attention attn_output = 
attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -223,7 +222,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -296,7 +295,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.query_key_value(hidden_states) @@ -343,7 +342,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -429,7 +428,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): if self.parallel_attn: @@ -443,7 +442,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -465,7 +464,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -552,7 +551,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): # Layer norm. @@ -567,7 +566,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -628,7 +627,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) @@ -650,7 +649,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -680,7 +679,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -693,7 +692,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 109304be904..19025c4c5a2 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -9,6 +9,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -268,7 +269,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): qkv = self.c_attn(hidden_states) @@ -291,12 +292,10 @@ def forward( # flash attention attn_output = attention( query, - torch.select(key_value, dim=1, index=0), - torch.select(key_value, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, ) # Decode @@ -308,7 +307,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -373,7 +372,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ): hidden_states, residual = self.ln_1(hidden_states, residual) @@ -383,7 +382,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -437,7 +436,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, ) -> torch.Tensor: hidden_states = self.wte(input_ids) + self.wpe(position_ids) @@ -454,7 +453,7 
@@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, ) @@ -486,7 +485,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -499,7 +498,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, ) if lm_head_indices is not None: diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 200d4ef0c19..2f9ecd0de92 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -30,6 +30,7 @@ paged_attention, attention, reshape_and_cache, + Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, @@ -209,7 +210,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -240,12 +241,10 @@ def forward( # flash attention attn_output = attention( query, - torch.select(kv, dim=1, index=0), - torch.select(kv, dim=1, index=1), kv_cache[0], kv_cache[1], - cu_seqlen_prefill, - max_s, + seqlen, + block_tables, self.softmax_scale, window_size_left=self.max_past, ) @@ -258,7 +257,7 @@ def forward( self.kv_head_mapping, self.softmax_scale, block_tables, - input_lengths, + seqlen, max_s, ) @@ -381,7 +380,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ): @@ -396,7 +395,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -449,7 +448,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], @@ -473,7 +472,7 @@ def forward( kv_cache[i], block_tables, slots, - input_lengths, + seqlen, max_s, prefill_cache_indices, ) @@ -521,7 +520,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -534,7 +533,7 @@ def forward( elif self.max_past is not None: # Clamp in decode mode as paged attention requires clamped values whereas the flash attention # kernel requires the true values - input_lengths = input_lengths.clamp(max=self.max_past_tensor) + seqlen = seqlen.clamp(max=self.max_past_tensor) hidden_states = self.model( input_ids, @@ -543,7 +542,7 @@ def forward( kv_cache, block_tables, slots, - input_lengths, + seqlen, max_s, true_max_s, prefill_cache_indices, diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 7e4deaf887a..a829c374128 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -25,6 +25,7 @@ from text_generation_server.models.custom_modeling.vlm import ( load_text_model, ) +from text_generation_server.layers.attention import Seqlen from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from 
text_generation_server.layers import ( @@ -740,7 +741,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -826,7 +827,7 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, true_max_s=max_s, prefill_cache_indices=None, diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 29f5b9c7154..32e9d3348b3 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -23,6 +23,7 @@ from transformers.activations import ACT2FN from transformers.image_processing_utils import select_best_resolution +from text_generation_server.layers.attention import Seqlen from text_generation_server.models.custom_modeling.vlm import ( load_text_model, load_vision_model, @@ -170,7 +171,7 @@ def forward( kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, - input_lengths: torch.Tensor, + seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor], lm_head_indices: Optional[torch.Tensor] = None, @@ -276,7 +277,7 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, true_max_s=max_s, prefill_cache_indices=None, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index dd4203e068e..9a60d06ccb7 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -43,7 +43,7 @@ ATTENTION, BLOCK_SIZE, CUDA_GRAPHS, - PREFIX_CACHING, + TGI_WIGGLE_ROOM, get_adapter_to_index, ) from text_generation_server.layers.attention import Seqlen @@ -189,16 +189,21 @@ def to_pb(self) -> generate_pb2.CachedBatch: def batch_tokenized_inputs( cls, requests: Iterable[generate_pb2.Request], tokenizer ): - batch_inputs = [] - max_truncation = 0 + max_length = 0 + all_input_ids = [] + batch_size = 0 for r in requests: - batch_inputs.append(concat_text_chunks(r.input_chunks.chunks)) - max_truncation = max(max_truncation, r.truncate) - - batch_tokenized_inputs = tokenizer( - batch_inputs, truncation=True, max_length=max_truncation - )["input_ids"] - return batch_tokenized_inputs + batch_size += 1 + inputs = concat_text_chunks(r.input_chunks.chunks) + input_ids = tokenizer( + inputs, + truncation=True, + max_length=r.truncate, + add_special_tokens=r.add_special_tokens, + )["input_ids"] + max_length = max(max_length, len(input_ids)) + all_input_ids.append(input_ids) + return all_input_ids @classmethod def from_tokenized( @@ -257,22 +262,15 @@ def from_tokenized( # request id -> idx in list mapping requests_idx_mapping[r.id] = i - tokenized_input = tokenized_input[-r.truncate :] - if ( - tokenized_input[0] == tokenizer.bos_token_id - and tokenized_input[1] == tokenizer.bos_token_id - ): - tokenized_input = tokenized_input[1:] - orig_input_length = len(tokenized_input) - if PREFIX_CACHING: - prefix_len = r.prefix_len - if prefix_len == orig_input_length: - assert prefix_len > 0 - prefix_len -= 1 - else: - prefix_len = 0 + prefix_len = r.prefix_len + assert ( + prefix_len <= 
orig_input_length + ), f"Prefix {prefix_len} vs input {orig_input_length}" + if prefix_len == orig_input_length: + assert prefix_len > 0 + prefix_len -= 1 prefix_ids.append(tokenized_input[:prefix_len]) tokenized_input = tokenized_input[prefix_len:] @@ -998,7 +996,7 @@ def __init__( config.sliding_window = None self.num_layers = config.num_hidden_layers - self.num_heads = config.num_attention_heads + self.num_heads = config.num_attention_heads // self.process_group.size() # Validation is done in the model itself if num_kv_heads is None: num_kv_heads = getattr(config, "num_key_value_heads", None) @@ -1160,8 +1158,15 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): "block_tables": block_tables, "slots": slots, "input_lengths": input_lengths_tensor, + "prefix_lengths": prefix_lengths_tensor, } - input_lengths_ = Seqlen(input_lengths=input_lengths_tensor) + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) graph = torch.cuda.CUDAGraph() self.cuda_graphs[bs]["graph"] = graph @@ -1204,7 +1209,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths_, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1213,7 +1218,13 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - input_lengths_tensor = Seqlen(input_lengths=input_lengths_tensor) + seqlen = Seqlen( + input_lengths=input_lengths_tensor, + prefix_lengths=prefix_lengths_tensor, + cu_seqlen_q=None, + max_q=1, + max_k=max_s, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1221,7 +1232,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): kv_cache=self.kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths_tensor, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=None, lm_head_indices=None, @@ -1268,7 +1279,7 @@ def warmup(self, batch: FlashCausalLMBatch): num_blocks = ( # Leave 5% for some wiggle room - int((free_memory * 0.95) // total_cache_size) + int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size) # Add batch.num_blocks as we allocated it above, so it is included in the peak memory. + batch_num_blocks ) @@ -1360,18 +1371,26 @@ def tunableop_warmup(self, seqlen: int): # Dummy value, some models (starcoder2) don't accept `None`. input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device) - input_lengths = Seqlen(input_lengths=input_lengths) + prefix_lens_tensor = torch.zeros(seqlen, dtype=torch.int32, device=self.device) + cu_seqlen_prefill = torch.tensor( + [0, seqlen], device=self.device, dtype=torch.int32 + ) + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=1, + max_k=seqlen, + ) # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation. 
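# A minimal sketch of the `Seqlen` container these call sites assume. The field names are
# taken from the constructors above; the actual class lives in
# text_generation_server.layers.attention (common.py) and may carry extra backend-specific
# logic, so treat everything beyond the field list as an assumption.
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class Seqlen:
    input_lengths: torch.Tensor          # new (non-cached) tokens per request
    prefix_lengths: torch.Tensor         # tokens already present in the KV cache per request
    cu_seqlen_q: Optional[torch.Tensor]  # cumulative query lengths; None on the decode path
    max_q: int                           # longest query in the batch (1 when decoding)
    max_k: int                           # longest total (prefix + input) length in the batch

    def clamp(self, max):
        # Mirrors the `seqlen.clamp(max=self.max_past_tensor)` calls in the sliding-window
        # models (Mistral, Mixtral, Qwen2, Starcoder2); assumed here to cap only the lengths
        # that paged attention consumes -- the real implementation may differ per backend.
        return Seqlen(
            input_lengths=torch.clamp(self.input_lengths, max=max),
            prefix_lengths=self.prefix_lengths,
            cu_seqlen_q=self.cu_seqlen_q,
            max_q=self.max_q,
            max_k=self.max_k,
        )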
self.model.forward( input_ids=input_ids, position_ids=position_ids, - cu_seqlen_prefill=torch.tensor( - [0, seqlen], device=self.device, dtype=torch.int32 - ), + cu_seqlen_prefill=cu_seqlen_prefill, kv_cache=self.kv_cache, block_tables=None, - input_lengths=input_lengths, + seqlen=seqlen, slots=slots, max_s=seqlen, lm_head_indices=None, @@ -1451,8 +1470,7 @@ def forward( cuda_graph = None if cu_seqlen_prefill is not None or cuda_graph is None: - input_lengths = input_lengths + prefix_lens_tensor - if PREFIX_CACHING: + if ATTENTION == "flashinfer": block_tables = block_tables_to_ragged( block_tables=block_tables, input_lengths=batch.input_lengths, @@ -1462,11 +1480,18 @@ def forward( block_tables=block_tables, cu_seqlen_prefill=cu_seqlen_prefill, input_lengths=batch.input_lengths, - input_lengths_tensor=input_lengths, + input_lengths_tensor=input_lengths + prefix_lens_tensor, prefix_lens=batch.prefix_lens, prefix_lens_tensor=prefix_lens_tensor, ): - input_lengths = Seqlen(input_lengths=input_lengths) + max_k = (input_lengths + prefix_lens_tensor).max().item() + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=max_s, + max_k=max_k, + ) logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -1474,7 +1499,7 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices, diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index d5133f5e217..6c518c2caa5 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -5,19 +5,22 @@ from text_generation_server.utils.log import log_master -PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING", "0").lower() in {"1", "true"} +PREFIX_CACHING = os.getenv("USE_PREFIX_CACHING").lower() in {"1", "true"} log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}") -ATTENTION = os.getenv("ATTENTION", "flashinfer" if PREFIX_CACHING else "paged") +ATTENTION = os.getenv("ATTENTION") _expected = {"paged", "flashdecoding", "flashinfer"} assert ( ATTENTION in _expected ), f"Attention is not valid {ATTENTION}, expected {_expected}" log_master(logger.info, f"Using Attention = {ATTENTION}") -if PREFIX_CACHING and ATTENTION != "flashinfer": +if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}: raise RuntimeError("Prefix caching is only supported with flashinfer") MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None +TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.95")) +assert TGI_WIGGLE_ROOM > 0 +assert TGI_WIGGLE_ROOM < 1 # This is overridden by the cli BLOCK_SIZE: int diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 2ed1a119d9d..d6cb36fab9b 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -372,7 +372,14 @@ def forward( prefix_lens=batch.prefix_lens, prefix_lens_tensor=prefix_lens_tensor, ): - input_lengths = Seqlen(input_lengths=input_lengths) + max_k = (input_lengths + prefix_lens_tensor).max().item() + seqlen = Seqlen( + input_lengths=input_lengths, + prefix_lengths=prefix_lens_tensor, + cu_seqlen_q=cu_seqlen_prefill, + max_q=max_s, + max_k=max_k, + ) logits, 
speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, @@ -380,7 +387,7 @@ def forward( kv_cache=kv_cache, block_tables=block_tables, slots=slots, - input_lengths=input_lengths, + seqlen=seqlen, max_s=max_s, prefill_cache_indices=batch.prefill_cache_indices, lm_head_indices=lm_head_indices,
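
For downstream code that copies these modeling files, the per-model changes above are mechanical. The sketch below condenses the old and new call shapes; the helper names `attn_call_old`/`attn_call_new` are illustrative only, the `attention`/`paged_attention` shapes are taken from the hunks above (the old prefill shape shown is the one used by most models, GPT-J differed slightly), and per-model kwargs such as `causal`, `window_size_left` and `softcap` are omitted.

from text_generation_server.layers.attention import (
    attention,
    paged_attention,
    Seqlen,
)

def attn_call_old(query, key, value, kv_cache, cu_seqlen_prefill, block_tables,
                  input_lengths, max_s, softmax_scale, kv_head_mapping):
    # Before this patch: prefill attends over the raw key/value tensors and is driven by
    # cu_seqlen_prefill/max_s; decode takes a bare input_lengths tensor.
    if cu_seqlen_prefill is not None:
        return attention(query, key, value, kv_cache[0], kv_cache[1],
                         cu_seqlen_prefill, max_s, softmax_scale)
    return paged_attention(query, kv_cache[0], kv_cache[1], kv_head_mapping,
                           softmax_scale, block_tables, input_lengths, max_s)

def attn_call_new(query, kv_cache, cu_seqlen_prefill, block_tables,
                  seqlen: Seqlen, max_s, softmax_scale, kv_head_mapping):
    # After this patch: prefill reads keys/values from the already-written KV cache and
    # receives the Seqlen metadata plus block tables, so cached prefix tokens can take part
    # in prefill attention; decode is unchanged apart from Seqlen replacing the lengths tensor.
    if cu_seqlen_prefill is not None:
        return attention(query, kv_cache[0], kv_cache[1], seqlen,
                         block_tables, softmax_scale)
    return paged_attention(query, kv_cache[0], kv_cache[1], kv_head_mapping,
                           softmax_scale, block_tables, seqlen, max_s)

Note also, from the `globals.py` hunk: `USE_PREFIX_CACHING` and `ATTENTION` no longer have in-code defaults (they are presumably exported by the launcher), prefix caching is now accepted with either `flashinfer` or `flashdecoding`, and the warmup memory headroom is configurable through `TGI_WIGGLE_ROOM` (default 0.95, which must lie strictly between 0 and 1).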