From 1d1b1efa01170d10c16b11fe65875f8751100bd7 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Thu, 18 Jul 2024 14:00:13 +0000 Subject: [PATCH 01/14] fix(server): fix cohere (#2249) --- .../models/custom_modeling/flash_cohere_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 25719b999dc..49f5a81bc5c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -259,8 +259,8 @@ def forward( cu_seqlen_prefill, kv_cache, block_tables, - input_lengths, slots, + input_lengths, max_s, ): qkv = self.query_key_value(hidden_states) From ba291dad9f5cd0061e5ee4cab332db42c434e369 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 09:37:39 +0200 Subject: [PATCH 02/14] Improve the handling of quantized weights (#2250) * Improve the handling of quantized weights Handling of quantized weights was split between two mechanisms: - For quantized checkpoints, we used the new weight loader infrastructure. - For quantization while loading (EETQ, FP8, bitsandbytes) we instead relied on conditional in `get_linear`. Weight loaders support context managers to selectively load particular layers with different weight loaders, which is useful for models like Idefics2 AWQ, which uses a quantized text model, but unquantized vision and connector models. However, the context manager would be overrided by `get_linear`, which string-checks `quantizer`. Also, the context manager would not work with EETQ, FP8, and bitsandbytes. This change migrates all quantizers to the weight loader infrastructure. This has several benefits: - We can use context managers with all quantizers. - All the implementation details move down to the quantizer layers, `get_linear` does not need to know how to handle quantizer linear layers. - All quantizer weights are strongly typed, we don't pass around raw tensors. - We don't have to pass around the `quantizer` string everywhere. * Exclude non-MLP layers when using FP8 quantization with Llama --- .../test_flash_llama_marlin.json | 89 +++++ .../test_flash_llama_marlin24_all_params.json | 89 +++++ .../test_flash_llama_marlin24_load.json | 358 ++++++++++++++++++ .../models/test_completion_prompts.py | 2 + .../models/test_flash_llama_marlin_24.py | 66 ++++ server/tests/utils/test_weights.py | 22 +- server/text_generation_server/layers/bnb.py | 31 +- server/text_generation_server/layers/eetq.py | 18 + server/text_generation_server/layers/exl2.py | 13 +- server/text_generation_server/layers/fp8.py | 11 +- .../layers/gptq/__init__.py | 68 +++- .../text_generation_server/layers/linear.py | 172 +-------- .../text_generation_server/layers/marlin.py | 27 +- .../layers/tensor_parallel.py | 14 +- .../custom_modeling/flash_cohere_modeling.py | 4 +- .../custom_modeling/flash_dbrx_modeling.py | 4 +- .../custom_modeling/flash_gemma2_modeling.py | 4 +- .../custom_modeling/flash_gemma_modeling.py | 4 +- .../custom_modeling/flash_gpt2_modeling.py | 8 +- .../custom_modeling/flash_llama_modeling.py | 62 ++- .../custom_modeling/flash_mixtral_modeling.py | 4 +- .../custom_modeling/flash_neox_modeling.py | 4 +- .../custom_modeling/flash_phi_modeling.py | 4 +- .../custom_modeling/flash_rw_modeling.py | 2 +- .../flash_santacoder_modeling.py | 9 +- .../flash_starcoder2_modeling.py | 4 +- .../models/custom_modeling/idefics2.py | 20 +- .../models/custom_modeling/mpt_modeling.py | 2 +- .../utils/quantization.py | 35 +- .../text_generation_server/utils/weights.py | 51 ++- 30 files changed, 936 insertions(+), 265 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json create mode 100644 integration-tests/models/test_flash_llama_marlin_24.py diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json new file mode 100644 index 00000000000..94883de5fc6 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json new file mode 100644 index 00000000000..58cacb80298 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 5229, + "logprob": -0.6645508, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": 0.0, + "special": false, + "text": ":" + }, + { + "id": 6527, + "logprob": -2.2324219, + "special": false, + "text": " Could" + }, + { + "id": 451, + "logprob": 0.0, + "special": false, + "text": " not" + }, + { + "id": 6088, + "logprob": -1.6074219, + "special": false, + "text": " parse" + }, + { + "id": 1243, + "logprob": -1.6298828, + "special": false, + "text": " test" + }, + { + "id": 1206, + "logprob": -0.72558594, + "special": false, + "text": " case" + }, + { + "id": 1024, + "logprob": -0.40429688, + "special": false, + "text": " name" + }, + { + "id": 515, + "logprob": 0.0, + "special": false, + "text": " from" + }, + { + "id": 525, + "logprob": -1.2519531, + "special": false, + "text": " '" + } + ], + "top_tokens": null + }, + "generated_text": "Test request failed: Could not parse test case name from '" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json new file mode 100644 index 00000000000..96a40fa42ed --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_marlin_24/test_flash_llama_marlin24_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1, + "logprob": null, + "text": "" + }, + { + "id": 4321, + "logprob": -9.0859375, + "text": "Test" + }, + { + "id": 2009, + "logprob": -16.359375, + "text": "request" + } + ], + "seed": null, + "tokens": [ + { + "id": 5229, + "logprob": -2.7988281, + "special": false, + "text": " failed" + }, + { + "id": 29901, + "logprob": -0.91259766, + "special": false, + "text": ":" + }, + { + "id": 853, + "logprob": -2.8496094, + "special": false, + "text": " Un" + }, + { + "id": 23765, + "logprob": -1.1894531, + "special": false, + "text": "supported" + }, + { + "id": 4714, + "logprob": -1.5917969, + "special": false, + "text": " browser" + }, + { + "id": 29892, + "logprob": -0.34765625, + "special": false, + "text": "," + }, + { + "id": 1873, + "logprob": -1.2695312, + "special": false, + "text": " version" + }, + { + "id": 470, + "logprob": -0.25170898, + "special": false, + "text": " or" + }, + { + "id": 7481, + "logprob": -0.21411133, + "special": false, + "text": " platform" + }, + { + "id": 13, + "logprob": -1.1162109, + "special": false, + "text": "\n" + } + ], + "top_tokens": null + }, + "generated_text": " failed: Unsupported browser, version or platform\n" + } +] diff --git a/integration-tests/models/test_completion_prompts.py b/integration-tests/models/test_completion_prompts.py index 0efb6693862..d787873b066 100644 --- a/integration-tests/models/test_completion_prompts.py +++ b/integration-tests/models/test_completion_prompts.py @@ -100,6 +100,8 @@ async def test_flash_llama_completion_many_prompts_stream( chunk = [c.replace("data:", "") for c in chunk] # remove empty strings chunk = [c for c in chunk if c] + # remove completion marking chunk + chunk = [c for c in chunk if c != " [DONE]"] # parse json chunk = [json.loads(c) for c in chunk] diff --git a/integration-tests/models/test_flash_llama_marlin_24.py b/integration-tests/models/test_flash_llama_marlin_24.py new file mode 100644 index 00000000000..3eb94f02e18 --- /dev/null +++ b/integration-tests/models/test_flash_llama_marlin_24.py @@ -0,0 +1,66 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_marlin24_handle(launcher): + with launcher( + "nm-testing/Llama-2-7b-pruned2.4-Marlin_24", quantize="marlin" + ) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_marlin(flash_llama_marlin24_handle): + await flash_llama_marlin24_handle.health(300) + return flash_llama_marlin24_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snapshot): + response = await flash_llama_marlin.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_marlin24_load( + flash_llama_marlin, generate_load, response_snapshot +): + responses = await generate_load( + flash_llama_marlin, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/tests/utils/test_weights.py b/server/tests/utils/test_weights.py index 36b27be8cde..d2d2b76e51d 100644 --- a/server/tests/utils/test_weights.py +++ b/server/tests/utils/test_weights.py @@ -2,6 +2,7 @@ import torch from text_generation_server.utils.weights import ( DefaultWeightsLoader, + UnquantizedWeight, Weights, WeightsLoader, ) @@ -363,7 +364,10 @@ def __init__( self.process_group = process_group self.prefix = prefix self.weights_loader = ( - DefaultWeightsLoader() if weights_loader is None else weights_loader + # We don't need to get linear layers, so just wrap raw tensors. + DefaultWeightsLoader(lambda x: x) + if weights_loader is None + else weights_loader ) self._handles = {} @@ -632,6 +636,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -641,6 +646,7 @@ def test_get_weights_col_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -669,6 +675,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -678,6 +685,7 @@ def test_get_weights_col_gtpq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -774,6 +782,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -783,6 +792,7 @@ def test_get_weights_col_packed_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -851,6 +861,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -860,6 +871,7 @@ def test_get_weights_col_packed_gptq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -922,6 +934,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -931,6 +944,7 @@ def test_get_multi_weights_col_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -983,6 +997,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -992,6 +1007,7 @@ def test_get_multi_weights_col_gptq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -1051,6 +1067,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq): g_idx=None, bits=8.0, groupsize=2.0, + use_awq_kernel=True, use_exllama=False, ) @@ -1060,6 +1077,7 @@ def test_get_weights_row_awq(gptq_weights_loader_awq): assert w.g_idx == expected_weight.g_idx, "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" @@ -1125,6 +1143,7 @@ def test_get_weights_row_gptq(gptq_weights_loader): g_idx=torch.tensor([0, 1, 0, 1], dtype=torch.int32), bits=8.0, groupsize=2.0, + use_awq_kernel=False, use_exllama=False, ) @@ -1134,6 +1153,7 @@ def test_get_weights_row_gptq(gptq_weights_loader): assert torch.allclose(w.g_idx, expected_weight.g_idx), "g_idx mismatch" assert w.bits == expected_weight.bits, "bits mismatch" assert w.groupsize == expected_weight.groupsize, "groupsize mismatch" + assert w.use_awq_kernel == expected_weight.use_awq_kernel, "use_awq_kernel mismatch" assert w.use_exllama == expected_weight.use_exllama, "use_exllama mismatch" diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index ca39919ce67..925b0b2d6c8 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -1,8 +1,11 @@ -import torch -from loguru import logger +from dataclasses import dataclass from functools import lru_cache + import bitsandbytes as bnb +import torch from bitsandbytes.nn import Int8Params, Params4bit +from loguru import logger +from text_generation_server.utils.weights import Weight @lru_cache(1) @@ -12,6 +15,14 @@ def warn_deprecate_bnb(): ) +@dataclass +class BNBWeight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear8bitLt(self.weight, bias, has_fp16_weights=False, threshold=6.0) + + class Linear8bitLt(torch.nn.Module): def __init__( self, @@ -70,6 +81,22 @@ def forward(self, x: torch.Tensor): return out +@dataclass +class BNBFP4Weight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="fp4") + + +@dataclass +class BNBNF4Weight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return Linear4bit(self.weight, bias, quant_type="nf4") + + class Linear4bit(torch.nn.Module): def __init__(self, weight, bias, quant_type): super().__init__() diff --git a/server/text_generation_server/layers/eetq.py b/server/text_generation_server/layers/eetq.py index fd22b5c679e..f003f91452d 100644 --- a/server/text_generation_server/layers/eetq.py +++ b/server/text_generation_server/layers/eetq.py @@ -1,5 +1,23 @@ +from dataclasses import dataclass + import torch from EETQ import quant_weights, w8_a16_gemm +from text_generation_server.utils.weights import Weight + + +@dataclass +class EETQWeight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + try: + from text_generation_server.layers.eetq import EETQLinear + + return EETQLinear(self.weight, bias) + except ImportError: + raise ImportError( + "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" + ) class EETQLinear(torch.nn.Module): diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index 55cba1ccebe..1a5acfa858d 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -1,12 +1,12 @@ -import torch -from typing import List, Union from dataclasses import dataclass +from typing import List, Union -from text_generation_server.utils.weights import WeightsLoader, Weights +import torch +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader @dataclass -class Exl2Weight: +class Exl2Weight(Weight): """ Exllama2 exl2 quantized weights. """ @@ -25,6 +25,11 @@ def __post_init__(self): def device(self) -> torch.device: return self.q_weight.device + def get_linear(self, bias: torch.Tensor): + from text_generation_server.layers.gptq import ExllamaQuantLinear + + return ExllamaQuantLinear(self, bias) + class Exl2WeightsLoader(WeightsLoader): """Loader for exl2-quantized weights.""" diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index b76af8f1131..b56f568a0c9 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,7 +1,8 @@ -from enum import Enum, auto +from dataclasses import dataclass import torch from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.weights import Weight def get_fp8_linear() -> torch.nn.Module: @@ -37,6 +38,14 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): return qweight, scale +@dataclass +class Fp8Weight(Weight): + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + return get_fp8_linear()(self.weight, bias) + + class Fp8Linear(torch.nn.Module): def __init__( self, diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index aaa7a68a0b1..c98dbefc43c 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -1,24 +1,23 @@ -from dataclasses import dataclass -from loguru import logger import os +from dataclasses import dataclass from typing import List, Optional, Union -from safetensors import SafetensorError -from text_generation_server.utils.weights import Weights, WeightsLoader + import torch -from text_generation_server.utils.import_utils import ( - SYSTEM, -) +from loguru import logger +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader @dataclass -class GPTQWeight: +class GPTQWeight(Weight): qweight: torch.Tensor qzeros: torch.Tensor scales: torch.Tensor g_idx: Optional[torch.Tensor] bits: int groupsize: int + use_awq_kernel: bool use_exllama: bool def __post_init__(self): @@ -29,6 +28,50 @@ def __post_init__(self): def device(self) -> torch.device: return self.qweight.device + def get_linear(self, bias: torch.Tensor): + if self.use_awq_kernel: + if SYSTEM == "rocm": + raise NotImplementedError( + "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " + "to use Exllama/GPTQ kernels for AWQ inference." + ) + try: + from text_generation_server.layers.awq.quantize.qmodule import WQLinear + + return WQLinear( + w_bit=self.bits, + group_size=self.groupsize, + qweight=self.qweight, + qzeros=self.qzeros, + scales=self.scales, + bias=bias, + ) + except ImportError: + raise NotImplementedError( + "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" + ) + elif self.use_exllama: + try: + from text_generation_server.layers.gptq import ExllamaQuantLinear + except ImportError: + raise NotImplementedError( + f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" + ) + + return ExllamaQuantLinear(self, bias) + else: + from text_generation_server.layers.gptq.quant_linear import QuantLinear + + return QuantLinear( + self.qweight, + self.qzeros, + self.scales, + self.g_idx, + bias, + self.bits, + self.groupsize, + ) + try: major, _minor = torch.cuda.get_device_capability() @@ -45,6 +88,8 @@ def device(self) -> torch.device: if V2: from text_generation_server.layers.gptq.exllamav2 import ( QuantLinear as ExllamaQuantLinear, + ) + from text_generation_server.layers.gptq.exllamav2 import ( create_exllama_buffers, set_device, ) @@ -53,6 +98,8 @@ def device(self) -> torch.device: else: from text_generation_server.layers.gptq.exllama import ( Ex4bitLinear as ExllamaQuantLinear, + ) + from text_generation_server.layers.gptq.exllama import ( create_exllama_buffers, set_device, ) @@ -162,6 +209,7 @@ def get_weights_col_packed( g_idx=g_idx, bits=self.bits, groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", use_exllama=False, ) @@ -255,6 +303,7 @@ def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int) g_idx=g_idx, bits=self.bits, groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", use_exllama=use_exllama, ) @@ -336,8 +385,8 @@ def get_weights_row(self, weights: Weights, prefix: str): use_exllama = False from text_generation_server.layers.gptq import ( - HAS_EXLLAMA, CAN_EXLLAMA, + HAS_EXLLAMA, GPTQWeight, ) @@ -389,6 +438,7 @@ def get_weights_row(self, weights: Weights, prefix: str): g_idx=g_idx, bits=self.bits, groupsize=self.groupsize, + use_awq_kernel=self.quantize == "awq", use_exllama=use_exllama, ) diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py index babd86b09a3..a97cc43a152 100644 --- a/server/text_generation_server/layers/linear.py +++ b/server/text_generation_server/layers/linear.py @@ -1,7 +1,8 @@ from typing import Optional + import torch -from torch.nn import functional as F from text_generation_server.utils.import_utils import SYSTEM +from torch.nn import functional as F if SYSTEM == "rocm": try: @@ -90,167 +91,14 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: return F.linear(inp, self.weight, self.bias) -def get_linear(weight, bias, quantize): - if quantize is None: +def get_linear(weight, bias): + # Weights that are loaded through methods that are not + # quantization-aware are still bare tensors. We may want + # to change this in the future. + if isinstance(weight, torch.Tensor): if SYSTEM == "rocm": - linear = FastLinearROCm(weight, bias) - else: - linear = FastLinear(weight, bias) - elif quantize == "eetq": - try: - from text_generation_server.layers.eetq import EETQLinear - - linear = EETQLinear(weight, bias) - except ImportError: - raise ImportError( - "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" - ) - elif quantize == "fp8": - from text_generation_server.layers.fp8 import get_fp8_linear - - linear = get_fp8_linear()(weight, bias) - elif quantize == "bitsandbytes": - try: - from text_generation_server.layers.bnb import ( - warn_deprecate_bnb, - Linear8bitLt, - ) - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - warn_deprecate_bnb() - linear = Linear8bitLt( - weight, - bias, - has_fp16_weights=False, - threshold=6.0, - ) - if bias is not None: - linear.bias = nn.Parameter(bias) - elif quantize == "bitsandbytes-fp4": - try: - from text_generation_server.layers.bnb import Linear4bit - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - linear = Linear4bit( - weight, - bias, - quant_type="fp4", - ) - elif quantize == "bitsandbytes-nf4": - try: - from text_generation_server.layers.bnb import Linear4bit - except ImportError: - raise NotImplementedError( - f"Bitsandbytes is missing install it with `pip install bitsandbytes`." - ) - linear = Linear4bit( - weight, - bias, - quant_type="nf4", - ) - elif quantize == "exl2": - from text_generation_server.layers.exl2 import Exl2Weight - - if not isinstance(weight, Exl2Weight): - raise NotImplementedError( - f"The passed weight is not `exl2` compatible, loader needs to be updated." - ) - - from text_generation_server.layers.gptq import ExllamaQuantLinear - - linear = ExllamaQuantLinear(weight, bias) - - elif quantize == "gptq": - from text_generation_server.layers.gptq import GPTQWeight - from text_generation_server.layers.marlin import ( - GPTQMarlinLinear, - GPTQMarlinWeight, - ) - - if isinstance(weight, GPTQMarlinWeight): - linear = GPTQMarlinLinear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, GPTQWeight): - if weight.use_exllama: - try: - from text_generation_server.layers.gptq import ( - ExllamaQuantLinear, - ) - except ImportError: - raise NotImplementedError( - f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`" - ) - - linear = ExllamaQuantLinear(weight, bias) - else: - from text_generation_server.layers.gptq.quant_linear import QuantLinear - - linear = QuantLinear( - weight.qweight, - weight.qzeros, - weight.scales, - weight.g_idx, - bias, - weight.bits, - weight.groupsize, - ) + return FastLinearROCm(weight, bias) else: - raise NotImplementedError( - f"The passed weight is not `gptq` compatible, loader needs to be updated." - ) - - elif quantize == "awq": - from text_generation_server.layers.gptq import GPTQWeight + return FastLinear(weight, bias) - if not isinstance(weight, GPTQWeight): - raise NotImplementedError( - f"The passed weight is not `awq` compatible, loader needs to be updated." - ) - if SYSTEM == "rocm": - raise NotImplementedError( - "AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead " - "to use Exllama/GPTQ kernels for AWQ inference." - ) - try: - from text_generation_server.layers.awq.quantize.qmodule import WQLinear - - linear = WQLinear( - w_bit=weight.bits, - group_size=weight.groupsize, - qweight=weight.qweight, - qzeros=weight.qzeros, - scales=weight.scales, - bias=bias, - ) - except ImportError: - raise NotImplementedError( - "You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly" - ) - elif quantize == "marlin": - from text_generation_server.layers.marlin import ( - GPTQMarlin24Linear, - GPTQMarlin24Weight, - MarlinLinear, - MarlinWeight, - ) - - if isinstance(weight, GPTQMarlin24Weight): - linear = GPTQMarlin24Linear( - weight=weight, - bias=bias, - ) - elif isinstance(weight, MarlinWeight): - linear = MarlinLinear(weight=weight, bias=bias) - else: - raise NotImplementedError( - f"The passed weight is not `marlin` compatible, loader needs to be updated." - ) - else: - raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") - return linear + return weight.get_linear(bias) diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index 9777a47e53a..e7f017a4dac 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -7,7 +7,7 @@ from text_generation_server.layers.fp8 import fp8_quantize from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.log import log_once -from text_generation_server.utils.weights import Weights, WeightsLoader +from text_generation_server.utils.weights import Weight, Weights, WeightsLoader try: import marlin_kernels @@ -63,8 +63,7 @@ def get_weights_col_packed( return weight def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: + if self.is_marlin_24: try: B = torch.cat( [weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1 @@ -101,8 +100,7 @@ def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int) return weight def get_weights_row(self, weights: Weights, prefix: str): - is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" - if is_marlin_24: + if self.is_marlin_24: try: B = weights.get_sharded(f"{prefix}.B_24", dim=0) except RuntimeError: @@ -201,7 +199,7 @@ def permute_scales(scales: torch.Tensor): @dataclass -class GPTQMarlinWeight: +class GPTQMarlinWeight(Weight): """ Repacked GPTQ Marlin weights. """ @@ -219,6 +217,12 @@ def __post_init__(self): assert self.g_idx.dtype == torch.int32 assert self.perm.dtype == torch.int32 + def get_linear(self, bias: torch.Tensor): + return GPTQMarlinLinear( + weight=self, + bias=bias, + ) + def repack_gptq_for_marlin( *, @@ -376,6 +380,12 @@ def __post_init__(self): assert self.B_meta.dtype == torch.int16 assert self.s.dtype == torch.float16 + def get_linear(self, bias: torch.Tensor): + return GPTQMarlin24Linear( + weight=self, + bias=bias, + ) + class GPTQMarlin24Linear(nn.Module): def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]): @@ -567,7 +577,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor): @dataclass -class MarlinWeight: +class MarlinWeight(Weight): """ Marlin weights. @@ -583,6 +593,9 @@ def __post_init__(self): assert self.B.dtype == torch.int32 assert self.s.dtype == torch.float16 + def get_linear(self, bias: torch.Tensor): + return MarlinLinear(weight=self, bias=bias) + class MarlinLinear(nn.Module): def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]): diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 011f105b632..9dddb8ae30c 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -77,7 +77,7 @@ def load(config, prefix: str, weights): quantize = config.quantize return TensorParallelHead( - get_linear(weight, bias=None, quantize=quantize), + get_linear(weight, bias=None), process_group=weights.process_group, should_gather=should_gather, ) @@ -134,7 +134,7 @@ def load_gate_up(cls, config, prefix: str, weights, bias: bool): raise NotImplementedError("packed_gate_up only implemented without bias") else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -157,7 +157,7 @@ def load_qkv( raise NotImplementedError("packed_qkv only implemented for baichuan") else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -167,7 +167,7 @@ def load(cls, config, prefix: str, weights, bias: bool): bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @classmethod @@ -177,7 +177,7 @@ def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): for prefix in prefixes: weight = weights.get_weights_col(prefix) b = weights.get_tensor(f"{prefix}.bias") if bias else None - linears.append(get_linear(weight, b, config.quantize)) + linears.append(get_linear(weight, b)) linear = LayerConcat(linears) else: weight = weights.get_multi_weights_col(prefixes, dim=dim) @@ -186,7 +186,7 @@ def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int): bias = torch.cat(b, dim=dim) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return cls(linear) @@ -205,7 +205,7 @@ def load(cls, config, prefix: str, weights, bias: bool): else: bias = None return cls( - get_linear(weight, bias, config.quantize), + get_linear(weight, bias), process_group=weights.process_group, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index 49f5a81bc5c..c7b29d13ad3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -186,9 +186,7 @@ def _load_gqa(config, prefix: str, weights): else: bias = None - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) class FlashCohereAttention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 444116870c9..7426fc55e09 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -247,10 +247,10 @@ def _load_experts_quantized(config, prefix, weights, cls): if cls == TensorParallelRowLinear: expert_slice = expert_slice.t().contiguous() - linear = get_linear(expert_slice, None, config.quantize) + linear = get_linear(expert_slice, None) experts.append(cls(linear, weights.process_group)) else: - linear = get_linear(expert_slice, None, config.quantize) + linear = get_linear(expert_slice, None) experts.append(cls(linear)) return experts diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index a3ce55213ff..5273b15d6b8 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) class FlashGemma2Attention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 34a7efa2550..829ad427dc9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -155,9 +155,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) class FlashGemmaAttention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index cbfcb1b8fc5..a55a4af3777 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -82,7 +82,7 @@ def _load_qkv_gptq(config, prefix: str, weights): bias = torch.cat(tensors, dim=0) bias = bias.to(device=weights.device) - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def _load_qkv(config, prefix: str, weights, head_size, num_heads): @@ -129,7 +129,7 @@ def _load_qkv(config, prefix: str, weights, head_size, num_heads): 3 * num_heads * head_size ], f"{weight.shape} != {[3 * num_heads * head_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_row(config, prefix: str, weights, bias: bool): @@ -147,7 +147,7 @@ def load_row(config, prefix: str, weights, bias: bool): bias = None return TensorParallelRowLinear( - get_linear(weight, bias, config.quantize), process_group=weights.process_group + get_linear(weight, bias), process_group=weights.process_group ) @@ -163,7 +163,7 @@ def load_col(config, prefix: str, weights, bias: bool): else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) class FlashGPT2Attention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 78832341c3f..5237a484941 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -18,6 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager from typing import List, Optional, Tuple import torch @@ -25,7 +26,6 @@ from torch import nn from transformers.activations import ACT2FN -from typing import Optional, List, Tuple from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention import ( @@ -42,10 +42,16 @@ TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) +from text_generation_server.layers.fp8 import Fp8Weight from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + Weights, +) if SYSTEM == "rocm": try: @@ -105,6 +111,19 @@ def load_attention(config, prefix: str, weights, layer_id): ) +@contextmanager +def no_fp8(weights: Weights): + weights_loader = weights.weights_loader + if ( + isinstance(weights_loader, DefaultWeightsLoader) + and weights_loader.weight_class is Fp8Weight + ): + weights_loader = DefaultWeightsLoader(UnquantizedWeight) + + with weights.use_loader(weights_loader): + yield + + class FlashLlamaAttention(torch.nn.Module): def __init__( self, @@ -330,12 +349,15 @@ def forward(self, hidden_states, adapter_data): class FlashLlamaLayer(nn.Module): def __init__(self, index, prefix, config, weights): super().__init__() - self.self_attn = FlashLlamaAttention( - index=index, - prefix=f"{prefix}.self_attn", - config=config, - weights=weights, - ) + + with no_fp8(weights): + self.self_attn = FlashLlamaAttention( + index=index, + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + self.mlp = LlamaMLP( prefix=f"{prefix}.mlp", config=config, weights=weights, index=index ) @@ -470,23 +492,27 @@ class FlashLlamaForCausalLM(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() - self.embed_tokens = TensorParallelEmbedding( - prefix=( - "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens" - ), - weights=weights, - ) + with no_fp8(weights): + self.embed_tokens = TensorParallelEmbedding( + prefix=( + "model.embed_tokens" + if not prefix + else f"{prefix}.model.embed_tokens" + ), + weights=weights, + ) self.model = FlashLlamaModel(prefix, config, weights) if config.tie_word_embeddings: suffix = "model.embed_tokens" else: suffix = "lm_head" - self.lm_head = SpeculativeHead.load( - config, - prefix=suffix if not prefix else f"{prefix}.{suffix}", - weights=weights, - ) + with no_fp8(weights): + self.lm_head = SpeculativeHead.load( + config, + prefix=suffix if not prefix else f"{prefix}.{suffix}", + weights=weights, + ) def forward( self, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 49c0e9030fc..a1e36fc74bc 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights): config.hidden_size, ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" - return TensorParallelColumnLinear( - get_linear(weight, bias=None, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=None)) def _load_experts(config, prefix: str, mat, weights): diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 85dcb2a6cd8..996642303aa 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -56,7 +56,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.use_parallel_residual: return linear else: @@ -81,7 +81,7 @@ def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): bias = weights.get_sharded(f"{prefix}.bias", dim=0) bias = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1) - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.use_parallel_residual: return linear else: diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index 6c508264293..a1ce03b9c53 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -100,9 +100,7 @@ def _load_gqa(config, prefix: str, weights): ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" # this is the same as llama except for Phi uses bias=True - return TensorParallelColumnLinear( - get_linear(weight, bias=True, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=True)) class FlashPhiAttention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 65b40fed629..d7cad4805bd 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -31,7 +31,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) if config.parallel_attn: return linear else: diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 77b9d49cddd..2b939a1009d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -105,6 +105,7 @@ def _load_multi_mqa_gptq( g_idx=g_idx, bits=loader.bits, groupsize=loader.groupsize, + use_awq_kernel=loader.quantize == "awq", use_exllama=HAS_EXLLAMA, ) @@ -121,7 +122,7 @@ def _load_multi_mqa_gptq( bias = torch.cat([q_tensor, kv_tensor], dim=0) bias = bias.to(device=weights.device) - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) else: raise NotImplementedError("Gptq loading with santacoder is not implemented") @@ -193,7 +194,7 @@ def _load_multi_mqa( assert list(bias.shape) == [ (num_heads + 2) * head_size ], f"{weight.shape} != {[(num_heads + 2) * head_size]}" - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_col(config, prefix: str, weights, bias: bool): @@ -206,7 +207,7 @@ def load_col(config, prefix: str, weights, bias: bool): bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None - return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize)) + return TensorParallelColumnLinear(get_linear(weight, bias)) def load_row(config, prefix: str, weights, bias: bool): @@ -221,7 +222,7 @@ def load_row(config, prefix: str, weights, bias: bool): else: bias = None return TensorParallelRowLinear( - get_linear(weight, bias, config.quantize), process_group=weights.process_group + get_linear(weight, bias), process_group=weights.process_group ) diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 19556f783c6..89471955e2f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -149,9 +149,7 @@ def _load_gqa(config, prefix: str, weights): else: bias = None - return TensorParallelColumnLinear( - get_linear(weight, bias=bias, quantize=config.quantize) - ) + return TensorParallelColumnLinear(get_linear(weight, bias=bias)) class Starcoder2Attention(torch.nn.Module): diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index daf3329abf0..735c3899480 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -34,7 +34,7 @@ TensorParallelEmbedding, TensorParallelRowLinear, ) -from text_generation_server.utils.weights import DefaultWeightsLoader +from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -698,7 +698,7 @@ def __init__(self, prefix, config, weights): self.dtype = weights.dtype # The vision and connector models are not quantized. - with weights.use_loader(DefaultWeightsLoader()): + with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)): self.vision_model = Idefics2VisionTransformer( prefix=( f"{prefix}.model.vision_model" if prefix else "model.vision_model" @@ -707,16 +707,12 @@ def __init__(self, prefix, config, weights): weights=weights, ) - quantize = config.quantize - try: - config.quantize = None - self.connector = Idefics2Connector( - prefix=f"{prefix}.model.connector" if prefix else "model.connector", - config=config, - weights=weights, - ) - finally: - config.quantize = quantize + config.quantize = None + self.connector = Idefics2Connector( + prefix=f"{prefix}.model.connector" if prefix else "model.connector", + config=config, + weights=weights, + ) self.config = config self.image_seq_len = config.perceiver_config.resampler_n_latents diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index fb09a8f1740..bbe9f45a2a1 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -75,7 +75,7 @@ def load_col(config, prefix, weights, bias): bias = bias.to(device=weights.device) else: bias = None - linear = get_linear(weight, bias, config.quantize) + linear = get_linear(weight, bias) return TensorParallelColumnLinear(linear) diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index 07975bea01d..e8e22db8d33 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -1,11 +1,14 @@ -from typing import Optional -import os import json +import os from dataclasses import dataclass +from typing import Optional from huggingface_hub import hf_hub_download - -from text_generation_server.utils.weights import DefaultWeightsLoader, WeightsLoader +from text_generation_server.utils.weights import ( + DefaultWeightsLoader, + UnquantizedWeight, + WeightsLoader, +) @dataclass @@ -104,10 +107,30 @@ def get_loader( quantize=quantize, sym=quantizer_config.sym, ) + elif quantize == "bitsandbytes": + from text_generation_server.layers.bnb import BNBWeight + + return DefaultWeightsLoader(BNBWeight) + elif quantize == "bitsandbytes-fp4": + from text_generation_server.layers.bnb import BNBFP4Weight + + return DefaultWeightsLoader(BNBFP4Weight) + elif quantize == "bitsandbytes-nf4": + from text_generation_server.layers.bnb import BNBNF4Weight + + return DefaultWeightsLoader(BNBNF4Weight) + elif quantize == "eetq": + from text_generation_server.layers.eetq import EETQWeight + + return DefaultWeightsLoader(EETQWeight) elif quantize == "exl2": from text_generation_server.layers.exl2 import Exl2WeightsLoader return Exl2WeightsLoader() + elif quantize == "fp8": + from text_generation_server.layers.fp8 import Fp8Weight + + return DefaultWeightsLoader(Fp8Weight) elif quantize == "marlin": from text_generation_server.layers.marlin import MarlinWeightsLoader @@ -115,5 +138,7 @@ def get_loader( bits=quantizer_config.bits, is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) + elif quantize is None: + return DefaultWeightsLoader(UnquantizedWeight) else: - return DefaultWeightsLoader() + raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index b530af23f9c..6876b70096a 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,9 +1,13 @@ from abc import ABC, abstractmethod from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum, auto from pathlib import Path from typing import Dict, List, Optional, Union -from safetensors import safe_open + import torch +from safetensors import safe_open +from text_generation_server.utils.import_utils import SYSTEM class WeightsLoader(ABC): @@ -62,7 +66,39 @@ def get_weights_row(self, weights: "Weights", prefix: str): ... +class Weight(ABC): + """Instances of this type implement unquantized/quantized/to-be + quantized weights.""" + + @abstractmethod + def get_linear(self, bias: torch.Tensor): + """Create a linear layer from this weight.""" + ... + + +@dataclass +class UnquantizedWeight: + weight: torch.Tensor + + def get_linear(self, bias: torch.Tensor): + from text_generation_server.layers.linear import FastLinear, FastLinearROCm + + if SYSTEM == "rocm": + return FastLinearROCm(self.weight, bias) + else: + return FastLinear(self.weight, bias) + + class DefaultWeightsLoader(WeightsLoader): + """Weight loader that loads (unquantized) Torch tensors.""" + + def __init__(self, weight_class): + """Create a loader. Weights will be wrapped using the given `weights_class`, + normally this will be `UnquantizedWeight`, but a quantizer-specific class + such as `Fp8Weight` can be used to quantize the weights during loading. + """ + self.weight_class = weight_class + """ Loader that uses tensors as-is with the exception of applying sharding and/or concatenation. @@ -74,16 +110,21 @@ def get_weights_col_packed( prefix: str, block_sizes: Union[int, List[int]], ): - return weights.get_packed_sharded( - f"{prefix}.weight", dim=0, block_sizes=block_sizes + + return self.weight_class( + weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ), ) def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] - return torch.cat(w, dim=dim) + return self.weight_class(torch.cat(w, dim=dim)) def get_weights_row(self, weights: "Weights", prefix: str): - return weights.get_sharded(f"{prefix}.weight", dim=1) + return self.weight_class( + weights.get_sharded(f"{prefix}.weight", dim=1), + ) class Weights: From 80adb5be16084720913b3e0622c74cd2c1ac2396 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 12:55:59 +0200 Subject: [PATCH 03/14] Hotfix: fix of use of unquantized weights in Gemma GQA loading (#2255) --- .../models/custom_modeling/flash_gemma2_modeling.py | 9 +++++---- .../models/custom_modeling/flash_gemma_modeling.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 5273b15d6b8..8526d515f8e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -42,6 +42,7 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import UnquantizedWeight class Gemma2Config(PretrainedConfig): @@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" return TensorParallelColumnLinear(get_linear(weight, bias=None)) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 829ad427dc9..dfe6510c93a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -42,6 +42,7 @@ from text_generation_server.layers.layernorm import ( FastRMSNorm, ) +from text_generation_server.utils.weights import UnquantizedWeight class GemmaConfig(PretrainedConfig): @@ -144,16 +145,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.head_dim num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" return TensorParallelColumnLinear(get_linear(weight, bias=None)) From 18db78f295c32d871fe833b9e87d6ca6f21a950d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 14:42:19 +0200 Subject: [PATCH 04/14] Hotfix: various GPT-based model fixes (#2256) --- server/text_generation_server/models/__init__.py | 5 +++++ .../models/custom_modeling/flash_neox_modeling.py | 15 +++++++++++---- .../custom_modeling/flash_starcoder2_modeling.py | 9 +++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ba980195d57..32156a068f7 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -573,6 +573,10 @@ def get_model( ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + GPTNeoXConfig, + ) + return FlashCausalLM( model_id=model_id, model_class=FlashGPTNeoXForCausalLM, @@ -582,6 +586,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, + config_class=GPTNeoXConfig, ) elif sharded: return CausalLM( diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 996642303aa..623b164c338 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -24,7 +24,7 @@ from torch import nn from transformers.activations import ACT2FN from transformers.modeling_utils import PreTrainedModel -from transformers.models.gpt_neox import GPTNeoXConfig +from transformers.models.gpt_neox import GPTNeoXConfig as TransformersGPTNeoXConfig from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( @@ -45,6 +45,13 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight + + +class GPTNeoXConfig(TransformersGPTNeoXConfig): + attribute_map = { + "num_key_value_heads": "num_attention_heads", + } def load_row(config, prefix: str, weights, bias: bool): @@ -65,10 +72,10 @@ def load_row(config, prefix: str, weights, bias: bool): def load_qkv(config, prefix: str, weights, num_heads, head_size, hidden_size): weight = weights.get_multi_weights_col([prefix], dim=0) - if isinstance(weight, torch.Tensor): + if isinstance(weight, UnquantizedWeight): # Only on non quantized versions - weight = ( - weight.view( + weight.weight = ( + weight.weight.view( num_heads, 3, head_size, diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py index 89471955e2f..cfa891d45e9 100644 --- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py @@ -45,6 +45,7 @@ from text_generation_server.layers.rotary import ( PositionRotaryEmbedding, ) +from text_generation_server.utils.weights import UnquantizedWeight class Starcoder2Config(PretrainedConfig): @@ -129,16 +130,16 @@ def _load_gqa(config, prefix: str, weights): dim=0, ) - if config.quantize not in ["gptq", "awq", "marlin"]: - weight = weight.to(dtype=weights.dtype).to(device=weights.device) + if isinstance(weight, UnquantizedWeight): + weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device) head_size = config.hidden_size // config.num_attention_heads num_heads = config.num_attention_heads // weights.process_group.size() num_key_value_heads = config.num_key_value_heads // weights.process_group.size() - assert list(weight.shape) == [ + assert list(weight.weight.shape) == [ (num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size, - ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" + ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}" if config.use_bias: w = [ From 3b41e93a090d33140084427662d68f6690c4938e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 14:42:35 +0200 Subject: [PATCH 05/14] Hotfix: fix MPT after recent refactor (#2257) --- .../models/causal_lm.py | 9 ++++- .../models/custom_modeling/mpt_modeling.py | 38 +++++++++---------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 0ea82b1e5b0..87b9969ae4f 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -492,7 +492,7 @@ def __len__(self): @dataclass -class CausalLMBatchKeysLast(Batch): +class CausalLMBatchKeysLast(CausalLMBatch): keys_head_dim_last: bool = False @@ -544,7 +544,12 @@ def __init__( config.quantize = quantize config.speculator = speculator if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = config.pad_token_id + if config.pad_token_id is not None: + tokenizer.pad_token_id = config.pad_token_id + elif config.eos_token_id is not None: + tokenizer.pad_token_id = config.eos_token_id + elif tokenizer.eos_token_id is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id torch.distributed.barrier(group=self.process_group) weights_loader = get_loader( diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index bbe9f45a2a1..04c547b2fe0 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -337,17 +337,17 @@ def __init__( weights, ): super().__init__() - attn_impl = config.attn_config["attn_impl"] - self.attn_impl = config.attn_config["attn_impl"] - self.clip_qkv = config.attn_config["clip_qkv"] - self.qk_ln = config.attn_config["qk_ln"] + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads - self.softmax_scale = config.attn_config["softmax_scale"] + self.softmax_scale = config.attn_config.softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.attn_dropout_p = config.attn_config["attn_pdrop"] + self.attn_dropout_p = config.attn_config.attn_pdrop if self.n_heads % weights.process_group.size() != 0: raise ValueError( @@ -430,17 +430,17 @@ class MultiQueryAttention(nn.Module): def __init__(self, config, prefix, weights): super().__init__() - attn_impl = config.attn_config["attn_impl"] - self.attn_impl = config.attn_config["attn_impl"] - self.clip_qkv = config.attn_config["clip_qkv"] - self.qk_ln = config.attn_config["qk_ln"] + attn_impl = config.attn_config.attn_impl + self.attn_impl = config.attn_config.attn_impl + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln self.d_model = config.d_model d_model = config.d_model self.n_heads = config.n_heads - self.softmax_scale = config.attn_config["softmax_scale"] + self.softmax_scale = config.attn_config.softmax_scale if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.head_dim) - self.attn_dropout_p = config.attn_config["attn_pdrop"] + self.attn_dropout_p = config.attn_config.attn_pdrop # self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device) self.Wqkv = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias @@ -614,9 +614,9 @@ class MPTBlock(nn.Module): def __init__(self, config, prefix, weights): super().__init__() self.prefix = prefix - if config.attn_config["attn_type"] != "multihead_attention": + if config.attn_config.attn_type != "multihead_attention": raise NotImplementedError( - f"""Not implemented attn {config.attn_config["attn_type"]}""" + f"""Not implemented attn {config.attn_config.attn_type}""" ) resid_pdrop = config.resid_pdrop if config.no_bias: @@ -789,11 +789,11 @@ def __init__(self, prefix: str, config, weights): self.world_size = weights.process_group.size() self.rank = weights.process_group.rank() self.n_heads = config.n_heads - self.attn_impl = config.attn_config["attn_impl"] - self.prefix_lm = config.attn_config["prefix_lm"] - self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"] - self.alibi = config.attn_config["alibi"] - self.alibi_bias_max = config.attn_config["alibi_bias_max"] + self.attn_impl = config.attn_config.attn_impl + self.prefix_lm = config.attn_config.prefix_lm + self.attn_uses_sequence_id = config.attn_config.attn_uses_sequence_id + self.alibi = config.attn_config.alibi + self.alibi_bias_max = config.attn_config.alibi_bias_max if config.init_device == "mixed": if dist.get_local_rank() == 0: config.init_device = "cpu" From 3f37a667749b8130225f4ef727caf30f73f304a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 15:59:00 +0200 Subject: [PATCH 06/14] Hotfix: pass through model revision in `VlmCausalLM` (#2258) --- server/text_generation_server/models/vlm_causal_lm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index ace4880588d..f869f8b53ea 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -261,7 +261,12 @@ def __init__( **processor_kwargs, ) self.batch_class = batch_class - super().__init__(model_id=model_id, **kwargs) + super().__init__( + model_id=model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **kwargs, + ) @property def batch_type(self) -> Type[VlmCausalLMBatch]: From 4c19593a903dffe3b7d7e4edb479445aff0f001f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?= Date: Fri, 19 Jul 2024 16:17:56 +0200 Subject: [PATCH 07/14] usage stats and crash reports (#2220) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * draft of usage stats * fix wrong link * launcher doesn't need sysinfo dep * only tokenizer class instead of hole struct * unused import * fix clippy errors * update openAPI doc * cargo fmt * fix error in passing flags to router * try again to update docs * run pre-commit locally * Update router/src/main.rs Co-authored-by: Hugo Larcher * Update router/src/main.rs Co-authored-by: Hugo Larcher * on crash use anonymous error event * delete json_output and ngrok * more robust way of checking if is in container * more robust nvidia smi * parse xpu more robustly * fix errors * add nvidia-smi details in docs * cargo fmt * fix clippy * should make docs check pass * Update router/src/usage_stats.rs Co-authored-by: Hugo Larcher * error reason can't be in nested json * cargo fmt --------- Co-authored-by: Hugo Larcher Co-authored-by: Erik Kaunismäki --- Cargo.lock | 53 +++- docs/source/basic_tutorials/launcher.md | 16 ++ docs/source/usage_statistics.md | 73 +++++ launcher/src/main.rs | 16 ++ router/Cargo.toml | 4 + router/src/lib.rs | 8 +- router/src/main.rs | 86 +++++- router/src/usage_stats.rs | 355 ++++++++++++++++++++++++ 8 files changed, 599 insertions(+), 12 deletions(-) create mode 100644 docs/source/usage_statistics.md create mode 100644 router/src/usage_stats.rs diff --git a/Cargo.lock b/Cargo.lock index ffc98baa9de..923b5cbef6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -801,6 +801,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "ctrlc" version = "3.4.4" @@ -3402,9 +3423,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.118" +version = "1.0.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" dependencies = [ "itoa", "ryu", @@ -3650,15 +3671,16 @@ checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" [[package]] name = "sysinfo" -version = "0.30.12" +version = "0.30.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "732ffa00f53e6b2af46208fba5718d9662a421049204e156328b66791ffa15ae" +checksum = "0a5b4ddaee55fb2bea2bf0e5000747e5f5c0de765e5a5ff87f4cd106439f4bb3" dependencies = [ "cfg-if", "core-foundation-sys", "libc", "ntapi", "once_cell", + "rayon", "windows", ] @@ -3805,6 +3827,7 @@ dependencies = [ "axum-tracing-opentelemetry", "base64 0.22.1", "clap", + "csv", "futures", "futures-util", "hf-hub", @@ -3826,6 +3849,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "sysinfo", "text-generation-client", "thiserror", "tokenizers", @@ -3837,6 +3861,7 @@ dependencies = [ "tracing-subscriber", "utoipa", "utoipa-swagger-ui", + "uuid", "vergen", ] @@ -4508,9 +4533,25 @@ dependencies = [ [[package]] name = "uuid" -version = "1.9.1" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +dependencies = [ + "getrandom", + "rand", + "uuid-macro-internal", +] + +[[package]] +name = "uuid-macro-internal" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" +checksum = "ee1cd046f83ea2c4e920d6ee9f7c3537ef928d75dce5d84a87c2c5d6b3999a3a" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.68", +] [[package]] name = "v_frame" diff --git a/docs/source/basic_tutorials/launcher.md b/docs/source/basic_tutorials/launcher.md index 5e40146f58a..77f884902e6 100644 --- a/docs/source/basic_tutorials/launcher.md +++ b/docs/source/basic_tutorials/launcher.md @@ -424,6 +424,22 @@ Options: [env: LORA_ADAPTERS=] +``` +## DISABLE_USAGE_STATS +```shell + --disable-usage-stats + Disable sending of all usage statistics + + [env: DISABLE_USAGE_STATS=] + +``` +## DISABLE_CRASH_REPORTS +```shell + --disable-crash-reports + Disable sending of crash reports, but allow anonymous usage statistics + + [env: DISABLE_CRASH_REPORTS=] + ``` ## HELP ```shell diff --git a/docs/source/usage_statistics.md b/docs/source/usage_statistics.md new file mode 100644 index 00000000000..adf0d70fcd0 --- /dev/null +++ b/docs/source/usage_statistics.md @@ -0,0 +1,73 @@ + +# Collection of Usage Statistics + +Text Generation Inference collects anonymous usage statistics to help us improve the service. The collected data is used to improve TGI and to understand what causes failures. The data is collected transparently and any sensitive information is omitted. + +Data is sent twice, once on server startup and once when server stops. Also, usage statistics are only enabled when TGI is running in docker to avoid collecting data then TGI runs directly on the host machine. + +## What data is collected + +The code that collects the data is available [here](https://github.com/huggingface/text-generation-inference/blob/main/router/src/usage_stats.rs). +As of release 2.1.2 this is an example of the data collected: + +- From the TGI configuration: +```json +{ + "event_type": "start", + "disable_grammar_support": false, + "max_batch_prefill_tokens": 4096, + "max_batch_size": null, + "max_batch_total_tokens": null, + "max_best_of": 2, + "max_client_batch_size": 4, + "max_concurrent_requests": 128, + "max_input_tokens": 1024, + "max_stop_sequences": 4, + "max_top_n_tokens": 5, + "max_total_tokens": 2048, + "max_waiting_tokens": 20, + "messages_api_enabled": false, + "model_config": { + "model_type": "Bloom" + }, + "revision": null, + "tokenizer_class": "BloomTokenizerFast", + "validation_workers": 2, + "waiting_served_ratio": 1.2, + "docker_label": "latest", + "git_sha": "cfc118704880453d29bcbe4fbbd91dda501cf5fe", + "nvidia_env": { + "name": "NVIDIA A10G", + "pci_bus_id": "00000000:00:1E.0", + "driver_version": "535.183.01", + "pstate": "P8", + "pcie_link_gen_max": "4", + "pcie_link_gen_current": "1", + "temperature_gpu": "31", + "utilization_gpu": "0 %", + "utilization_memory": "0 %", + "memory_total": "23028 MiB", + "memory_free": "22515 MiB", + "memory_used": "0 MiB", + "reset_status_reset_required": "No", + "reset_status_drain_and_reset_recommended": "No", + "compute_cap": "8.6", + "ecc_errors_corrected_volatile_total": "0", + "mig_mode_current": "[N/A]", + "power_draw_instant": "10.86 W", + "power_limit": "300.00 W" + }, + "system_env": { + "cpu_count": 16, + "cpu_type": "AMD EPYC 7R32", + "total_memory": 66681196544, + "architecture": "x86_64", + "platform": "linux-unix-x86_64" + } +} + +``` + +## How to opt-out + +You can easily opt out by passing the `--disable-usage-stats` to the text-generation-launcher command. This will disable all usage statistics. You can also pass `--disable-crash-reports` which disables sending specific crash reports, but allows anonymous usage statistics. diff --git a/launcher/src/main.rs b/launcher/src/main.rs index d2ca38e5a0e..ef7b4712364 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -457,6 +457,14 @@ struct Args { /// startup that will be available to callers via the `adapter_id` field in a request. #[clap(long, env)] lora_adapters: Option, + + /// Disable sending of all usage statistics + #[clap(default_value = "false", long, env)] + disable_usage_stats: bool, + + /// Disable sending of crash reports, but allow anonymous usage statistics + #[clap(default_value = "false", long, env)] + disable_crash_reports: bool, } #[derive(Debug)] @@ -1201,6 +1209,14 @@ fn spawn_webserver( args.model_id, ]; + // Pass usage stats flags to router + if args.disable_usage_stats { + router_args.push("--disable-usage-stats".to_string()); + } + if args.disable_crash_reports { + router_args.push("--disable-crash-reports".to_string()); + } + // Grammar support if args.disable_grammar_support { router_args.push("--disable-grammar-support".to_string()); diff --git a/router/Cargo.toml b/router/Cargo.toml index 60fb5c9d109..0fc700a0659 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -52,6 +52,10 @@ regex = "1.10.3" once_cell = "1.19.0" image = "0.25.1" base64 = { workspace = true } +sysinfo = "0.30.13" +uuid = { version = "1.9.1", default-features = false, features = ["v4", "fast-rng", "macro-diagnostics"] } +csv = "1.3.0" + [build-dependencies] vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] } diff --git a/router/src/lib.rs b/router/src/lib.rs index f856406d6c1..52926c6c834 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -7,6 +7,8 @@ mod validation; #[cfg(feature = "kserve")] mod kserve; +pub mod usage_stats; + use serde::{Deserialize, Serialize}; use tracing::warn; use utoipa::ToSchema; @@ -40,13 +42,13 @@ pub struct HubModelInfo { pub pipeline_tag: Option, } -#[derive(Debug, Clone, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ChatTemplate { name: String, template: String, } -#[derive(Debug, Clone, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(untagged)] pub enum ChatTemplateVersions { Single(String), @@ -55,7 +57,7 @@ pub enum ChatTemplateVersions { use std::path::Path; -#[derive(Debug, Clone, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct HubTokenizerConfig { pub chat_template: Option, pub completion_template: Option, diff --git a/router/src/main.rs b/router/src/main.rs index b060d73cf30..bfc779131a6 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -14,6 +14,7 @@ use std::io::BufReader; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::{Path, PathBuf}; use text_generation_router::config::Config; +use text_generation_router::usage_stats; use text_generation_router::{ server, HubModelInfo, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig, }; @@ -87,6 +88,10 @@ struct Args { disable_grammar_support: bool, #[clap(default_value = "4", long, env)] max_client_batch_size: usize, + #[clap(long, env, default_value_t)] + disable_usage_stats: bool, + #[clap(long, env, default_value_t)] + disable_crash_reports: bool, } #[derive(Debug, Subcommand)] @@ -128,6 +133,8 @@ async fn main() -> Result<(), RouterError> { messages_api_enabled, disable_grammar_support, max_client_batch_size, + disable_usage_stats, + disable_crash_reports, command, } = args; @@ -324,6 +331,7 @@ async fn main() -> Result<(), RouterError> { tracing::warn!("Could not find tokenizer config locally and no API specified"); HubTokenizerConfig::default() }); + let tokenizer_class = tokenizer_config.tokenizer_class.clone(); let tokenizer: Option = tokenizer_filename.and_then(|filename| { let mut tokenizer = Tokenizer::from_file(filename).ok(); @@ -378,8 +386,47 @@ async fn main() -> Result<(), RouterError> { } }; + // Only send usage stats when TGI is run in container and the function returns Some + let is_container = matches!(usage_stats::is_container(), Ok(true)); + + let user_agent = if !disable_usage_stats && is_container { + let reduced_args = usage_stats::Args::new( + config.clone(), + tokenizer_class, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + revision, + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + ); + Some(usage_stats::UserAgent::new(reduced_args)) + } else { + None + }; + + if let Some(ref ua) = user_agent { + let start_event = + usage_stats::UsageStatsEvent::new(ua.clone(), usage_stats::EventType::Start, None); + tokio::spawn(async move { + start_event.send().await; + }); + }; + // Run server - server::run( + let result = server::run( master_shard_uds_path, model_info, compat_return_full_text, @@ -410,8 +457,41 @@ async fn main() -> Result<(), RouterError> { max_client_batch_size, print_schema_command, ) - .await?; - Ok(()) + .await; + + match result { + Ok(_) => { + if let Some(ref ua) = user_agent { + let stop_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Stop, + None, + ); + stop_event.send().await; + }; + Ok(()) + } + Err(e) => { + if let Some(ref ua) = user_agent { + if !disable_crash_reports { + let error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some(e.to_string()), + ); + error_event.send().await; + } else { + let unknow_error_event = usage_stats::UsageStatsEvent::new( + ua.clone(), + usage_stats::EventType::Error, + Some("unknow_error".to_string()), + ); + unknow_error_event.send().await; + } + }; + Err(RouterError::WebServer(e)) + } + } } /// Init logging using env variables LOG_LEVEL and LOG_FORMAT: diff --git a/router/src/usage_stats.rs b/router/src/usage_stats.rs new file mode 100644 index 00000000000..8559ae90f37 --- /dev/null +++ b/router/src/usage_stats.rs @@ -0,0 +1,355 @@ +use crate::config::Config; +use csv::ReaderBuilder; +use reqwest::header::HeaderMap; +use serde::Serialize; +use std::{ + fs::File, + io::{self, BufRead}, + path::Path, + process::Command, + time::Duration, +}; +use uuid::Uuid; + +const TELEMETRY_URL: &str = "https://huggingface.co/api/telemetry/tgi"; + +#[derive(Debug, Clone, Serialize)] +pub struct UserAgent { + pub uid: String, + pub args: Args, + pub env: Env, +} + +impl UserAgent { + pub fn new(reduced_args: Args) -> Self { + Self { + uid: Uuid::new_v4().to_string(), + args: reduced_args, + env: Env::new(), + } + } +} + +#[derive(Serialize, Debug)] +pub enum EventType { + Start, + Stop, + Error, +} + +#[derive(Debug, Serialize)] +pub struct UsageStatsEvent { + user_agent: UserAgent, + event_type: EventType, + #[serde(skip_serializing_if = "Option::is_none")] + error_reason: Option, +} + +impl UsageStatsEvent { + pub fn new(user_agent: UserAgent, event_type: EventType, error_reason: Option) -> Self { + Self { + user_agent, + event_type, + error_reason, + } + } + pub async fn send(&self) { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "application/json".parse().unwrap()); + let body = serde_json::to_string(&self).unwrap(); + let client = reqwest::Client::new(); + let _ = client + .post(TELEMETRY_URL) + .headers(headers) + .body(body) + .timeout(Duration::from_secs(5)) + .send() + .await; + } +} + +#[derive(Debug, Clone, Serialize)] +pub struct Args { + model_config: Option, + tokenizer_config: Option, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + revision: Option, + validation_workers: usize, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + disable_usage_stats: bool, + disable_crash_reports: bool, +} + +impl Args { + #[allow(clippy::too_many_arguments)] + pub fn new( + model_config: Option, + tokenizer_config: Option, + max_concurrent_requests: usize, + max_best_of: usize, + max_stop_sequences: usize, + max_top_n_tokens: u32, + max_input_tokens: usize, + max_total_tokens: usize, + waiting_served_ratio: f32, + max_batch_prefill_tokens: u32, + max_batch_total_tokens: Option, + max_waiting_tokens: usize, + max_batch_size: Option, + revision: Option, + validation_workers: usize, + messages_api_enabled: bool, + disable_grammar_support: bool, + max_client_batch_size: usize, + disable_usage_stats: bool, + disable_crash_reports: bool, + ) -> Self { + Self { + model_config, + tokenizer_config, + max_concurrent_requests, + max_best_of, + max_stop_sequences, + max_top_n_tokens, + max_input_tokens, + max_total_tokens, + waiting_served_ratio, + max_batch_prefill_tokens, + max_batch_total_tokens, + max_waiting_tokens, + max_batch_size, + revision, + validation_workers, + messages_api_enabled, + disable_grammar_support, + max_client_batch_size, + disable_usage_stats, + disable_crash_reports, + } + } +} + +/// This is more or less a copy of the code from the `text-generation-launcher` crate to avoid a dependency +#[derive(Serialize, Debug, Clone)] +pub struct Env { + git_sha: &'static str, + docker_label: &'static str, + nvidia_info: Option>, + xpu_info: Option>, + system_env: SystemInfo, +} + +#[derive(Debug, Serialize, Clone)] +struct NvidiaSmiInfo { + name: String, + pci_bus_id: String, + driver_version: String, + pstate: String, + pcie_link_gen_max: String, + pcie_link_gen_current: String, + temperature_gpu: String, + utilization_gpu: String, + utilization_memory: String, + memory_total: String, + memory_free: String, + memory_used: String, + reset_status_reset_required: String, + reset_status_drain_and_reset_recommended: String, + compute_cap: String, + ecc_errors_corrected_volatile_total: String, + mig_mode_current: String, + power_draw_instant: String, + power_limit: String, +} + +impl NvidiaSmiInfo { + fn new() -> Option> { + let output = Command::new("nvidia-smi") + .args([ + "--query-gpu=name,pci.bus_id,driver_version,pstate,pcie.link.gen.max,pcie.link.gen.gpucurrent,temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used,reset_status.reset_required,reset_status.drain_and_reset_recommended,compute_cap,ecc.errors.corrected.volatile.total,mig.mode.current,power.draw.instant,power.limit", + "--format=csv" + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + + let mut rdr = ReaderBuilder::new() + .has_headers(true) + .from_reader(stdout.as_bytes()); + + let mut infos = Vec::new(); + + for result in rdr.records() { + let record = result.ok()?; + infos.push(NvidiaSmiInfo { + name: record[0].to_string(), + pci_bus_id: record[1].to_string(), + driver_version: record[2].to_string(), + pstate: record[3].to_string(), + pcie_link_gen_max: record[4].to_string(), + pcie_link_gen_current: record[5].to_string(), + temperature_gpu: record[6].to_string(), + utilization_gpu: record[7].to_string(), + utilization_memory: record[8].to_string(), + memory_total: record[9].to_string(), + memory_free: record[10].to_string(), + memory_used: record[11].to_string(), + reset_status_reset_required: record[12].to_string(), + reset_status_drain_and_reset_recommended: record[13].to_string(), + compute_cap: record[14].to_string(), + ecc_errors_corrected_volatile_total: record[15].to_string(), + mig_mode_current: record[16].to_string(), + power_draw_instant: record[17].to_string(), + power_limit: record[18].to_string(), + }); + } + + Some(infos) + } +} + +#[derive(Debug, Serialize, Clone)] +struct XpuSmiInfo { + device_id: usize, + gpu_utilization: f32, + gpu_power: f32, + gpu_core_temperature: f32, + gpu_memory_bandwidth_utilization: f32, +} + +impl XpuSmiInfo { + /// based on this https://github.com/intel/xpumanager/blob/master/doc/smi_user_guide.md#dump-the-device-statistics-in-csv-format + fn new() -> Option> { + let output = Command::new("xpu-smi") + .args([ + "dump", "-d", "-1", "-m", + "0,1,3,17", // Metrics IDs: GPU Utilization, GPU Power, GPU Core Temperature, GPU Memory Bandwidth Utilization + "-n", "1", "-j", + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let stdout = String::from_utf8(output.stdout).ok()?; + let mut infos = Vec::new(); + + let json_data: serde_json::Value = match serde_json::from_str(&stdout) { + Ok(data) => data, + Err(_) => return None, + }; + + if let Some(metrics_data) = json_data.as_array() { + for entry in metrics_data { + let device_id = entry["deviceId"].as_u64()? as usize; + let gpu_utilization = entry["metrics"][0].as_f64()? as f32; + let gpu_power = entry["metrics"][1].as_f64()? as f32; + let gpu_core_temperature = entry["metrics"][2].as_f64()? as f32; + let gpu_memory_bandwidth_utilization = entry["metrics"][3].as_f64()? as f32; + + infos.push(XpuSmiInfo { + device_id, + gpu_utilization, + gpu_power, + gpu_core_temperature, + gpu_memory_bandwidth_utilization, + }); + } + } + + Some(infos) + } +} + +#[derive(Serialize, Debug, Clone)] +pub struct SystemInfo { + cpu_count: usize, + cpu_type: String, + total_memory: u64, + architecture: String, + platform: String, +} + +impl SystemInfo { + fn new() -> Self { + let mut system = sysinfo::System::new_all(); + system.refresh_all(); + + let cpu_count = system.cpus().len(); + let cpu_type = system.cpus()[0].brand().to_string(); + let total_memory = system.total_memory(); + let architecture = std::env::consts::ARCH.to_string(); + let platform = format!( + "{}-{}-{}", + std::env::consts::OS, + std::env::consts::FAMILY, + std::env::consts::ARCH + ); + Self { + cpu_count, + cpu_type, + total_memory, + architecture, + platform, + } + } +} + +impl Default for Env { + fn default() -> Self { + Self::new() + } +} + +impl Env { + pub fn new() -> Self { + Self { + system_env: SystemInfo::new(), + nvidia_info: NvidiaSmiInfo::new(), + xpu_info: XpuSmiInfo::new(), + git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), + docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), + } + } +} + +pub fn is_container() -> io::Result { + let path = Path::new("/proc/self/cgroup"); + let file = File::open(path)?; + let reader = io::BufReader::new(file); + + for line in reader.lines() { + let line = line?; + // Check for common container runtimes + if line.contains("/docker/") + || line.contains("/docker-") + || line.contains("/kubepods/") + || line.contains("/kubepods-") + || line.contains("containerd") + || line.contains("crio") + || line.contains("podman") + { + return Ok(true); + } + } + Ok(false) +} From 40f5dc3ed6d87a980018163ffee588bc3726aedd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Kaunism=C3=A4ki?= Date: Fri, 19 Jul 2024 16:34:04 +0200 Subject: [PATCH 08/14] add usage stats to toctree (#2260) quick fix --- docs/source/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 119c5662e6c..e97c00aa260 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -21,6 +21,8 @@ title: Messages API - local: architecture title: Internal Architecture + - local: usage_statistics + title: Usage Statistics title: Getting started - sections: - local: basic_tutorials/consuming_tgi From 68a9685f1b6beabe8a9c5893d54decc9d65aa299 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 19 Jul 2024 11:12:02 -0400 Subject: [PATCH 09/14] fix: adjust default tool choice (#2244) * fix: adjust default tool choice * feat: improve tool choice syntax and response parsing/errors * fix: remove dev tests * feat: add ToolChoice to docs --- docs/openapi.json | 15 ++- router/src/infer/mod.rs | 233 ++++++++++++++++++++-------------------- router/src/lib.rs | 20 ++-- router/src/server.rs | 53 +++++---- 4 files changed, 167 insertions(+), 154 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index 3e7050abbcb..7000c7b7c01 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -909,7 +909,7 @@ "tool_choice": { "allOf": [ { - "$ref": "#/components/schemas/ToolType" + "$ref": "#/components/schemas/ToolChoice" } ], "nullable": true @@ -2035,6 +2035,14 @@ } } }, + "ToolChoice": { + "allOf": [ + { + "$ref": "#/components/schemas/ToolType" + } + ], + "nullable": true + }, "ToolType": { "oneOf": [ { @@ -2055,6 +2063,11 @@ "$ref": "#/components/schemas/FunctionName" } } + }, + { + "type": "object", + "default": null, + "nullable": true } ] }, diff --git a/router/src/infer/mod.rs b/router/src/infer/mod.rs index f3b10450ae0..db9070d4806 100644 --- a/router/src/infer/mod.rs +++ b/router/src/infer/mod.rs @@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck; use crate::validation::{ValidGenerateRequest, Validation, ValidationError}; use crate::{ ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig, - HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, + HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice, }; use crate::{ FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools, @@ -332,126 +332,131 @@ impl ChatTemplate { pub struct ToolGrammar {} impl ToolGrammar { + // find a tool by name + fn find_tool_by_name(tools: &[Tool], name: &str) -> Result { + tools + .iter() + .find(|tool| tool.function.name == name) + .cloned() + .ok_or_else(|| InferError::ToolError(format!("Tool with name {} not found", name))) + } + pub fn apply( tools: Option>, - tool_choice: Option, + tool_choice: ToolChoice, ) -> Result, InferError> { - if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) { - // let tool_prompt = tool_prompt.unwrap_or_default(); - let tools_to_use = match tool_choice { - ToolType::FunctionName(name) => { - vec![req_tools - .iter() - .find(|tool| tool.function.name == *name) - .unwrap_or_else(|| panic!("Tool with name {} not found", name)) - .clone()] - } - ToolType::Function { function } => { - let tool = req_tools - .iter() - .find(|tool| tool.function.name == function.name) - .unwrap_or_else(|| panic!("Tool with name {} not found", function.name)) - .clone(); - vec![tool] + // if no tools are provided, we return None + let tools = match tools { + Some(tools) if !tools.is_empty() => tools, + _ => return Ok(None), + }; + + let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf); + + // if tools are provided and no tool_choice we default to the OneOf + let tools_to_use = match tool_choice { + ToolType::FunctionName(name) => { + vec![Self::find_tool_by_name(&tools, &name)?] + } + ToolType::Function { function } => { + vec![Self::find_tool_by_name(&tools, &function.name)?] + } + ToolType::OneOf => tools, + ToolType::NoTool => return Ok(None), + }; + + // adds the error notification function for LLM feedback if required + let mut text_response_properties = Map::new(); + text_response_properties.insert( + "error".to_string(), + serde_json::json!({ + "type": "string", + "description": "The error or issue to notify" + }), + ); + text_response_properties.insert( + "_name".to_string(), + serde_json::json!({ + "type": "string", + "const": "notify_error" + }), + ); + + let functions: HashMap = tools_to_use + .iter() + .map(|tool| { + let func = tool.function.clone(); + + // Clone the existing parameters, which are expected to be a JSON object + let mut params = if let Value::Object(params) = &func.arguments { + params.clone() + } else { + Map::new() + }; + + // Insert the function's description at the top level, outside of properties + params.insert( + "description".to_string(), + Value::String(func.description.clone().unwrap_or_default()), + ); + + // Ensure 'properties' exists and is an object + let properties = params + .entry("properties".to_string()) + .or_insert_with(|| json!({})) + .as_object_mut() + .unwrap(); + + // Insert the constant for the function name inside 'properties' + properties.insert( + "_name".to_string(), + json!({ + "type": "string", + "const": func.name.clone(), + // "description": "The name of the function" + }), + ); + + // Check if 'required' exists, and it is an array. If not, create an empty array. + let required = params + .entry("required".to_string()) + .or_insert_with(|| json!([])) + .as_array_mut() + .unwrap(); + + // Add 'name' to the 'required' array if it is not already present + if !required.iter().any(|r| r == "_name") { + required.push(json!("_name")); } - ToolType::OneOf => req_tools.to_owned(), - }; - // adds the error notification function for LLM feedback if required - let mut text_response_properties = Map::new(); - text_response_properties.insert( - "error".to_string(), - serde_json::json!({ - "type": "string", - "description": "The error or issue to notify" - }), - ); - text_response_properties.insert( - "_name".to_string(), + (func.name, Value::Object(params)) + }) + .chain([( + "notify_error".to_string(), serde_json::json!({ - "type": "string", - "const": "notify_error" + "properties": text_response_properties, + "required": ["error", "_name"], + "type": "object" }), - ); - - let functions: HashMap = tools_to_use - .iter() - .map(|tool| { - let func = tool.function.clone(); - - // Clone the existing parameters, which are expected to be a JSON object - let mut params = if let Value::Object(params) = &func.arguments { - params.clone() - } else { - Map::new() - }; - - // Insert the function's description at the top level, outside of properties - params.insert( - "description".to_string(), - Value::String(func.description.clone().unwrap_or_default()), - ); - - // Ensure 'properties' exists and is an object - let properties = params - .entry("properties".to_string()) - .or_insert_with(|| json!({})) - .as_object_mut() - .unwrap(); - - // Insert the constant for the function name inside 'properties' - properties.insert( - "_name".to_string(), - json!({ - "type": "string", - "const": func.name.clone(), - // "description": "The name of the function" - }), - ); - - // Check if 'required' exists, and it is an array. If not, create an empty array. - let required = params - .entry("required".to_string()) - .or_insert_with(|| json!([])) - .as_array_mut() - .unwrap(); - - // Add 'name' to the 'required' array if it is not already present - if !required.iter().any(|r| r == "_name") { - required.push(json!("_name")); - } - - (func.name, Value::Object(params)) - }) - .chain([( - "notify_error".to_string(), - serde_json::json!({ - "properties": text_response_properties, - "required": ["error", "_name"], - "type": "object" - }), - )]) - .collect(); - - let tools = Tools { - functions_map: FunctionsMap { functions }, - properties: Properties { - function: tools_to_use - .iter() - .map(|tool| FunctionRef { - ref_path: format!("#/$functions/{}", tool.function.name.clone()), - }) - .chain(std::iter::once(FunctionRef { - ref_path: "#/$functions/notify_error".to_string(), - })) - .collect(), - }, - }; - - return Ok(Some(tools)); - } - // Err(InferError::ToolError("No tools provided".to_string())) - Ok(None) + )]) + .collect(); + + let tools = Tools { + functions_map: FunctionsMap { functions }, + properties: Properties { + function: tools_to_use + .iter() + .map(|tool| FunctionRef { + ref_path: format!("#/$functions/{}", tool.function.name.clone()), + }) + .chain(std::iter::once(FunctionRef { + ref_path: "#/$functions/notify_error".to_string(), + })) + .collect(), + }, + }; + + Ok(Some(tools)) } } diff --git a/router/src/lib.rs b/router/src/lib.rs index 52926c6c834..b6e0d09dfac 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -826,7 +826,7 @@ pub(crate) struct ChatRequest { /// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter. #[serde(default)] #[schema(nullable = true, example = "null")] - pub tool_choice: Option, + pub tool_choice: ToolChoice, /// Response format constraints for the generation. /// @@ -848,6 +848,7 @@ pub enum ToolType { OneOf, FunctionName(String), Function { function: FunctionName }, + NoTool, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] @@ -855,27 +856,26 @@ pub struct FunctionName { pub name: String, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)] #[serde(from = "ToolTypeDeserializer")] pub struct ToolChoice(pub Option); #[derive(Deserialize)] #[serde(untagged)] enum ToolTypeDeserializer { - None(Option), - Some(ToolType), + String(String), + ToolType(ToolType), } impl From for ToolChoice { fn from(value: ToolTypeDeserializer) -> Self { match value { - ToolTypeDeserializer::None(opt) => match opt.as_deref() { - Some("none") => ToolChoice(None), - Some("auto") => ToolChoice(Some(ToolType::OneOf)), - Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))), - None => ToolChoice(Some(ToolType::OneOf)), + ToolTypeDeserializer::String(s) => match s.as_str() { + "none" => ToolChoice(Some(ToolType::NoTool)), + "auto" => ToolChoice(Some(ToolType::OneOf)), + _ => ToolChoice(Some(ToolType::FunctionName(s))), }, - ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)), + ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)), } } } diff --git a/router/src/server.rs b/router/src/server.rs index d3a280ca3c8..c56c39a3625 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -24,7 +24,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -1192,39 +1192,33 @@ async fn chat_completions( .as_secs(); let (tool_calls, output) = if tool_grammar.is_some() { - // gen_text should be valid json - let gen_text_value: Value = - serde_json::from_str(&generation.generated_text).map_err(|e| { - ( - StatusCode::UNPROCESSABLE_ENTITY, - Json(ErrorResponse { - error: e.to_string(), - error_type: "Input validation error".to_string(), - }), - ) - })?; + let gen_text_value: Value = serde_json::from_str(&generation.generated_text) + .map_err(|e| InferError::ToolError(e.to_string()))?; + + let function = gen_text_value.get("function").ok_or(InferError::ToolError( + "No function found in generated text".to_string(), + ))?; + + let name = function + .get("_name") + .and_then(Value::as_str) + .ok_or(InferError::ToolError( + "No _name found in generated text".to_string(), + ))? + .to_string(); + + let mut arguments = function.clone(); + if let Value::Object(ref mut props) = arguments { + props.remove("_name"); + } + let tool_calls = vec![ToolCall { id: "0".to_string(), r#type: "function".to_string(), function: FunctionDefinition { description: None, - name: gen_text_value - .get("function") - .and_then(|f| f.get("_name")) - .and_then(|name| name.as_str()) - .unwrap_or("default_function_name") - .to_string(), - // Serialize the JSON object obtained from "function" to an escaped JSON string - arguments: gen_text_value - .get("function") - .map(|f| { - let mut f_cloned = f.clone(); - if let Value::Object(ref mut props) = f_cloned { - props.remove("_name"); - } - f_cloned - }) - .unwrap_or_default(), + name, + arguments, }, }]; (Some(tool_calls), None) @@ -1498,6 +1492,7 @@ pub async fn run( ToolCall, Function, FunctionDefinition, + ToolChoice, ) ), tags( From e52be9bba2a11f6b2145fcd21be0ba860b977a4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 19 Jul 2024 17:23:20 +0200 Subject: [PATCH 10/14] Add support for Deepseek V2 (#2224) Deepseek V2 is a MoE model from Deepseek. Relevant variations compared to other models: - Grouped top-K in expert selection. - mscale in yarn is calculated using the `mscale` and `mscale_all_dim` configuration options. - `mscale_all_dim` is also used in scaling attention softmax. - Permuting of the query/key representations before applying rotary embeddings. - Some projections cannot be sharded (`q_a_proj`, `kv_a_proj_with_mqa`). So, we need weight loads that supports quantized weights. To this end `{Weights,WeightLoader}.get_weight` was added. - The query/key head dimensionality differs from that of the value, so we need to pad during attention. - Heads with size 192, needs an extension to our paged attention fork and we need to ensure that the KV cache is allocated with the correct size. - Shared experts. --- docs/source/supported_models.md | 1 + .../test_flash_deepseek_v2.json | 89 ++ .../test_flash_deepseek_v2_all_params.json | 89 ++ .../test_flash_deepseek_v2_load.json | 358 +++++++ .../models/test_flash_deepseek_v2.py | 63 ++ server/Makefile-vllm | 6 +- server/text_generation_server/layers/exl2.py | 46 +- .../layers/gptq/__init__.py | 109 ++ .../text_generation_server/layers/marlin.py | 29 + .../text_generation_server/layers/rotary.py | 27 +- .../text_generation_server/models/__init__.py | 44 +- .../flash_deepseek_v2_modeling.py | 983 ++++++++++++++++++ .../models/flash_causal_lm.py | 10 +- .../text_generation_server/utils/weights.py | 13 + 14 files changed, 1826 insertions(+), 41 deletions(-) create mode 100644 integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json create mode 100644 integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json create mode 100644 integration-tests/models/test_flash_deepseek_v2.py create mode 100644 server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 2bdd00de61d..bc124f3194c 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models on specific hardware ## Supported Models +- [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal) - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal) - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json new file mode 100644 index 00000000000..03f90367280 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.84375, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.34375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.8359375, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.0859375, + "special": false, + "text": " is" + }, + { + "id": 254, + "logprob": -1.5390625, + "special": false, + "text": " the" + }, + { + "id": 1022, + "logprob": -1.1875, + "special": false, + "text": " first" + }, + { + "id": 3458, + "logprob": -0.35546875, + "special": false, + "text": " step" + }, + { + "id": 279, + "logprob": -0.8828125, + "special": false, + "text": " in" + }, + { + "id": 254, + "logprob": -0.71484375, + "special": false, + "text": " the" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is the first step in the" +} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json new file mode 100644 index 00000000000..e84135cf5e8 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 2143, + "logprob": -1.828125, + "special": false, + "text": " sent" + }, + { + "id": 10081, + "logprob": -0.36914062, + "special": false, + "text": " successfully" + }, + { + "id": 13, + "logprob": 0.0, + "special": false, + "text": "." + }, + { + "id": 185, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 1380, + "logprob": -0.38671875, + "special": false, + "text": "We" + }, + { + "id": 543, + "logprob": -0.12695312, + "special": false, + "text": " will" + }, + { + "id": 752, + "logprob": -0.20117188, + "special": false, + "text": " get" + }, + { + "id": 279, + "logprob": 0.0, + "special": false, + "text": " in" + }, + { + "id": 5402, + "logprob": 0.0, + "special": false, + "text": " touch" + }, + { + "id": 366, + "logprob": 0.0, + "special": false, + "text": " with" + } + ], + "top_tokens": null + }, + "generated_text": "Test request sent successfully.\nWe will get in touch with" +} diff --git a/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json new file mode 100644 index 00000000000..a4b9784a321 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_deepseek_v2/test_flash_deepseek_v2_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.1875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 100000, + "logprob": null, + "text": "<|begin▁of▁sentence|>" + }, + { + "id": 3533, + "logprob": -9.625, + "text": "Test" + }, + { + "id": 3102, + "logprob": -11.25, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 185, + "logprob": -1.5546875, + "special": false, + "text": "\n" + }, + { + "id": 549, + "logprob": -2.8125, + "special": false, + "text": "The" + }, + { + "id": 1727, + "logprob": -2.375, + "special": false, + "text": " test" + }, + { + "id": 3102, + "logprob": -0.890625, + "special": false, + "text": " request" + }, + { + "id": 317, + "logprob": -1.1484375, + "special": false, + "text": " is" + }, + { + "id": 245, + "logprob": -1.5390625, + "special": false, + "text": " a" + }, + { + "id": 3102, + "logprob": -2.609375, + "special": false, + "text": " request" + }, + { + "id": 327, + "logprob": -0.75, + "special": false, + "text": " for" + }, + { + "id": 245, + "logprob": -1.1171875, + "special": false, + "text": " a" + }, + { + "id": 1727, + "logprob": -0.90625, + "special": false, + "text": " test" + } + ], + "top_tokens": null + }, + "generated_text": "\nThe test request is a request for a test" + } +] diff --git a/integration-tests/models/test_flash_deepseek_v2.py b/integration-tests/models/test_flash_deepseek_v2.py new file mode 100644 index 00000000000..010e08c9059 --- /dev/null +++ b/integration-tests/models/test_flash_deepseek_v2.py @@ -0,0 +1,63 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_deepseek_v2_handle(launcher): + with launcher("deepseek-ai/DeepSeek-V2-Lite", num_shard=2) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_deepseek_v2(flash_deepseek_v2_handle): + await flash_deepseek_v2_handle.health(300) + return flash_deepseek_v2_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_all_params(flash_deepseek_v2, response_snapshot): + response = await flash_deepseek_v2.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_deepseek_v2_load( + flash_deepseek_v2, generate_load, response_snapshot +): + responses = await generate_load( + flash_deepseek_v2, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot diff --git a/server/Makefile-vllm b/server/Makefile-vllm index 2f2b5ef6807..f1f805290e2 100644 --- a/server/Makefile-vllm +++ b/server/Makefile-vllm @@ -1,14 +1,14 @@ -commit_cuda := b5dfc61db88a81069e45b44f7cc99bd9e62a60fa +commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921 build-vllm-cuda: if [ ! -d 'vllm' ]; then \ pip install -U ninja packaging --no-cache-dir && \ git clone https://github.com/Narsil/vllm.git vllm; \ fi - cd vllm && git fetch && git checkout $(commit_cuda) && python setup.py build + cd vllm && git fetch origin && git checkout $(commit_cuda) && python setup.py build install-vllm-cuda: build-vllm-cuda - cd vllm && git fetch && git checkout $(commit_cuda) && pip install -e . + cd vllm && git fetch origin && git checkout $(commit_cuda) && pip install -e . build-vllm-rocm: if [ ! -d 'vllm' ]; then \ diff --git a/server/text_generation_server/layers/exl2.py b/server/text_generation_server/layers/exl2.py index 1a5acfa858d..a6e07f45343 100644 --- a/server/text_generation_server/layers/exl2.py +++ b/server/text_generation_server/layers/exl2.py @@ -34,15 +34,10 @@ def get_linear(self, bias: torch.Tensor): class Exl2WeightsLoader(WeightsLoader): """Loader for exl2-quantized weights.""" - def get_weights_col_packed( - self, - weights: Weights, - prefix: str, - block_sizes: Union[int, List[int]], - ): - raise RuntimeError("Column-packed weights are not supported for exl") - - def get_weights_col(self, weights: Weights, prefix: str): + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ try: q_weight = weights.get_tensor(f"{prefix}.q_weight") except RuntimeError: @@ -63,26 +58,21 @@ def get_weights_col(self, weights: Weights, prefix: str): q_groups=q_groups, ) + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + raise RuntimeError("Column-packed weights are not supported for exl") + + def get_weights_col(self, weights: Weights, prefix: str): + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) + def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int): raise ValueError("get_multi_weights_col is not supported for exl2") def get_weights_row(self, weights: Weights, prefix: str): - try: - q_weight = weights.get_tensor(f"{prefix}.q_weight") - except RuntimeError: - raise RuntimeError( - "Cannot load `exl2`-quantized weight, make sure the model is already quantized." - ) - - q_scale = weights.get_tensor(f"{prefix}.q_scale") - q_invperm = weights.get_tensor(f"{prefix}.q_invperm") - q_scale_max = weights.get_tensor(f"{prefix}.q_scale_max") - q_groups = weights.get_tensor(f"{prefix}.q_groups") - - return Exl2Weight( - q_weight=q_weight, - q_scale=q_scale, - q_invperm=q_invperm, - q_scale_max=q_scale_max, - q_groups=q_groups, - ) + # Sharding is not yet supported, so we return the weights as-is. + return self.get_weights(weights, prefix) diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py index c98dbefc43c..6feca27559f 100644 --- a/server/text_generation_server/layers/gptq/__init__.py +++ b/server/text_generation_server/layers/gptq/__init__.py @@ -134,6 +134,115 @@ def __init__( self.quantize = quantize self.sym = sym + def get_weights(self, weights: Weights, prefix: str): + from text_generation_server.layers.marlin import ( + can_use_gptq_marlin, + repack_gptq_for_marlin, + ) + + self._get_gptq_params(weights) + if can_use_gptq_marlin( + bits=self.bits, + groupsize=self.groupsize, + quant_method=self.quant_method, + quantize=self.quantize, + sym=self.sym, + ): + log_once(logger.info, "Using GPTQ-Marlin kernels") + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized" + ) + + g_idx = weights.get_tensor(f"{prefix}.g_idx") + scales = weights.get_tensor(f"{prefix}.scales") + + return repack_gptq_for_marlin( + qweight=qweight, + scales=scales, + g_idx=g_idx, + bits=self.bits, + desc_act=self.desc_act, + groupsize=self.groupsize, + sym=self.sym, + sharded_infeatures=False, + ) + + use_exllama = True + if self.bits != 4: + use_exllama = False + + if self.desc_act: + log_once(logger.warning, "Disabling exllama because desc_act=True") + use_exllama = False + + try: + qweight = weights.get_tensor(f"{prefix}.qweight") + except RuntimeError: + raise RuntimeError( + "Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" + ) + + if self.quantize == "gptq" and self.quant_method == "gptq": + g_idx = weights.get_tensor(f"{prefix}.g_idx") + else: + g_idx = None + + from text_generation_server.layers.gptq import ( + HAS_EXLLAMA, + CAN_EXLLAMA, + GPTQWeight, + ) + + if use_exllama: + if not HAS_EXLLAMA: + if CAN_EXLLAMA: + log_once( + logger.warning, + "Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True", + ) + use_exllama = False + else: + log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}") + + qzeros = weights.get_tensor(f"{prefix}.qzeros") + scales = weights.get_tensor(f"{prefix}.scales") + + if use_exllama and g_idx is not None: + g_idx = g_idx - g_idx[0] + + if self.quantize == "gptq" and self.quant_method == "awq": + log_once( + logger.info, "Converting AWQ model to Exllama/GPTQ packing format." + ) + from text_generation_server.layers.awq.conversion_utils import ( + fast_awq_to_gptq, + ) + + qweight, qzeros = fast_awq_to_gptq(qweight, qzeros) + if use_exllama: + g_idx = None + else: + g_idx = ( + torch.arange( + qweight.shape[0] * (32 // self.bits), + device=qweight.device, + ) + // self.groupsize + ).to(dtype=torch.int32) + + return GPTQWeight( + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + bits=self.bits, + groupsize=self.groupsize, + use_exllama=use_exllama, + ) + def get_weights_col_packed( self, weights: Weights, diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index e7f017a4dac..a913ff57a2d 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -33,6 +33,35 @@ def __init__(self, *, bits: int, is_marlin_24: bool): self.bits = bits self.is_marlin_24 = is_marlin_24 + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24" + if is_marlin_24: + try: + B = weights.get_tensor(f"{prefix}.B_24") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized." + ) + + B_meta = weights.get_tensor(f"{prefix}.B_meta") + s = weights.get_tensor(f"{prefix}.s") + weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits) + else: + try: + B = weights.get_tensor(f"{prefix}.B") + except RuntimeError: + raise RuntimeError( + "Cannot load `marlin` weight, make sure the model is already quantized." + ) + + s = weights.get_tensor(f"{prefix}.s") + weight = MarlinWeight(B=B, s=s) + + return weight + def get_weights_col_packed( self, weights: Weights, diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index 8c354b82b41..db78ee1cc94 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -1,6 +1,7 @@ import os import torch from torch import nn +from loguru import logger from text_generation_server.utils.import_utils import SYSTEM @@ -97,6 +98,8 @@ def static(cls, config, dim, base, device): ) elif rope_scaling["type"] == "yarn": scaling_factor = rope_scaling["factor"] + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], max_position_embeddings=rope_scaling[ @@ -109,6 +112,8 @@ def static(cls, config, dim, base, device): attn_factor=1, beta_fast=32, beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, ) elif rope_scaling["type"] in ["su", "longrope"]: short_factor = torch.tensor( @@ -181,6 +186,8 @@ def load(cls, config, prefix, weights): scaling_factor=scaling_factor, ) elif rope_scaling["type"] == "yarn": + mscale = rope_scaling.get("mscale", 1.0) + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0) return YarnPositionRotaryEmbedding( dim=2 * inv_freq.shape[0], max_position_embeddings=rope_scaling[ @@ -193,6 +200,8 @@ def load(cls, config, prefix, weights): attn_factor=1, beta_fast=32, beta_slow=1, + mscale=mscale, + mscale_all_dim=mscale_all_dim, ) else: raise NotImplementedError( @@ -346,10 +355,10 @@ def linear_ramp_mask(min, max, dim): return ramp_func -def get_mscale(scale=1): +def get_mscale(scale: float = 1.0, mscale: float = 1.0): if scale <= 1: return 1.0 - return 0.1 * math.log(scale) + 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 class YarnPositionRotaryEmbedding(PositionRotaryEmbedding): @@ -365,6 +374,8 @@ def __init__( attn_factor, beta_fast, beta_slow, + mscale: float, + mscale_all_dim: float, ): inv_freq = _create_inv_freq(dim, base, device) super().__init__(inv_freq, scaling_factor) @@ -375,8 +386,12 @@ def __init__( self.attn_factor = attn_factor self.beta_fast = beta_fast self.beta_slow = beta_slow + self.mscale_all_dim = mscale_all_dim + self.scaling_factor = scaling_factor self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor + get_mscale(self.scaling_factor, mscale) + / get_mscale(self.scaling_factor, mscale_all_dim) + * self.attn_factor ) # Get n-d magnitude scaling corrected for interpolation def _update_cos_sin_cache(self, dtype, device, seqlen): @@ -387,7 +402,7 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): or self._cos_cached.device != device or self._cos_cached.dtype != dtype ): - if seqlen > self.max_position_embeddings: + if seqlen > self.max_position_embeddings or True: inv_freq_extrapolation = _create_inv_freq( self.dim, self.base, self.inv_freq.device ) @@ -400,6 +415,7 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): self.base, self.max_position_embeddings, ) + inv_freq_mask = ( 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation @@ -409,9 +425,6 @@ def _update_cos_sin_cache(self, dtype, device, seqlen): ) self.inv_freq = inv_freq - self.mscale = float( - get_mscale(self.scaling_factor) * self.attn_factor - ) # Get n-d magnitude scaling corrected for interpolation self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 32156a068f7..690a8887675 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -61,6 +61,10 @@ try: from text_generation_server.models.flash_causal_lm import FlashCausalLM from text_generation_server.models.vlm_causal_lm import VlmCausalLM + from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import ( + FlashDeepseekV2ForCausalLM, + DeepseekV2Config, + ) from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, ) @@ -141,6 +145,11 @@ class ModelType(enum.Enum): + DEEPSEEK_V2 = { + "type": "deepseek_v2", + "name": "Deepseek V2", + "url": "https://huggingface.co/deepseek-ai/DeepSeek-V2", + } IDEFICS2 = { "type": "idefics2", "name": "Idefics 2", @@ -459,7 +468,40 @@ def get_model( f"The backend {SYSTEM} does not support sliding window attention that is used by the model type {model_type}. To use this model nonetheless with the {SYSTEM} backend, please launch TGI with the argument `--max-input-tokens` smaller than sliding_window={sliding_window} (got here max_input_tokens={max_input_tokens})." ) - if model_type == MAMBA: + if model_type == DEEPSEEK_V2: + if FLASH_ATTENTION: + head_size = max( + config_dict.get("qk_nope_dim", 128) + + config_dict.get("qk_rope_dim", 64), + config_dict.get("v_head_dim", 128), + ) + return FlashCausalLM( + model_id=model_id, + model_class=FlashDeepseekV2ForCausalLM, + revision=revision, + quantize=quantize, + speculator=speculator, + default_dtype=torch.bfloat16, + dtype=dtype, + trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DeepseekV2Config, + head_size=head_size, + ) + elif sharded: + raise NotImplementedError( + FLASH_ATT_ERROR_MESSAGE.format("Sharded Deepseek V2") + ) + else: + return CausalLM.fallback( + model_id, + revision, + quantize=quantize, + speculator=speculator, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) + elif model_type == MAMBA: return Mamba( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py new file mode 100644 index 00000000000..8ffbd143c02 --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -0,0 +1,983 @@ +# coding=utf-8 +# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed +from text_generation_server.layers import ( + FastLinear, + SpeculativeHead, + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelRowLinear, + get_linear, +) +from text_generation_server.layers.attention import ( + attention, + paged_attention, + reshape_and_cache, +) +from text_generation_server.layers.attention.common import Seqlen +from text_generation_server.layers.layernorm import FastRMSNorm +from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale +from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_once +from text_generation_server.utils.weights import Weights +from torch import nn +from transformers.activations import ACT2FN +from transformers.configuration_utils import PretrainedConfig + + +class DeepseekV2Config(PretrainedConfig): + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size=1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts=2, + n_routed_experts=160, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=8, + topk_group=3, + num_experts_per_tok=6, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Deepseek V2 models." + ) + + if ep_size != 1: + raise ValueError( + f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}" + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def _load_experts(config, prefix: str, mat: str, weights: Weights): + if config.quantize is not None: + raise NotImplementedError( + "Deepseek V2 does not support weight quantization yet." + ) + + assert mat in ["gate_proj", "up_proj", "down_proj"] + + world_size = weights.process_group.size() + rank = weights.process_group.rank() + + assert ( + config.moe_intermediate_size % world_size == 0 + ), f"The chosen size {config.moe_intermediate_size} is not compatible with sharding on {world_size} shards" + + block_size = config.moe_intermediate_size // world_size + start = rank * block_size + stop = (rank + 1) * block_size + + tensor = torch.empty( + (config.n_routed_experts * block_size, config.hidden_size), + dtype=weights.dtype, + device=weights.device, + ) + + for i in range(config.n_routed_experts): + slice_ = weights._get_slice(f"{prefix}.{i}.{mat}.weight") + + if mat == "down_proj": + expert_slice = slice_[:, start:stop].t().contiguous() + else: + expert_slice = slice_[start:stop] + tensor[i * block_size : (i + 1) * block_size] = expert_slice.to( + dtype=weights.dtype + ).to(device=weights.device) + return tensor + + +class DeepseekV2Attention(torch.nn.Module): + def __init__( + self, + prefix: str, + config, + weights: Weights, + ): + super().__init__() + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.kv_lora_rank = config.kv_lora_rank + self.q_lora_rank = config.q_lora_rank + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim + self.value_head_size = config.v_head_dim + self.head_pad_size = max(self.head_size, self.value_head_size) + + self.rotary_emb = PositionRotaryEmbedding.static( + config=config, + dim=self.qk_rope_head_dim, + base=config.rope_theta, + device=weights.device, + ) + + mscale = get_mscale( + self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim + ) + self.softmax_scale = self.head_size**-0.5 * mscale * mscale + + if self.num_heads % weights.process_group.size() != 0: + raise ValueError( + f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " + f"and `num_shards`: {weights.process_group.size()}" + ) + self.num_heads = self.num_heads // weights.process_group.size() + self.num_key_value_heads = ( + config.num_key_value_heads // weights.process_group.size() + ) + + if self.q_lora_rank is None: + self.q_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_proj", + weights=weights, + bias=config.attention_bias, + ) + else: + self.q_a_proj = get_linear( + weight=weights.get_weights(f"{prefix}.q_a_proj"), + bias=( + weights.get_tensor(f"{prefix}.q_a_proj.bias") + if config.attention_bias + else None + ), + quantize=config.quantize, + ) + self.q_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.q_a_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + self.q_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.q_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.kv_a_proj_with_mqa = get_linear( + weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"), + bias=( + weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias") + if config.attention_bias + else None + ), + quantize=config.quantize, + ) + + self.kv_a_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps + ) + + self.kv_b_proj = TensorParallelColumnLinear.load( + config, + prefix=f"{prefix}.kv_b_proj", + weights=weights, + bias=config.attention_bias, + ) + + self.o_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.o_proj", + weights=weights, + bias=False, + ) + self.num_groups = self.num_heads // self.num_key_value_heads + self.kv_head_mapping = torch.arange( + 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device + ).repeat_interleave(self.num_groups) + + def forward( + self, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + if self.q_lora_rank is None: + query = self.q_proj(hidden_states) + else: + query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0]) + query = query.view(-1, self.num_heads, self.head_size) + + _, query_pe = torch.split( + query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, key_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view( + -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size + ) + + key_nope, value = torch.split( + kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1 + ) + + batch_size, heads, head_dim = query_pe.shape + query_pe = ( + query_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + batch_size, heads, head_dim = key_pe.shape + key_pe = ( + key_pe.view(batch_size, heads, head_dim // 2, 2) + .transpose(2, 3) + .reshape(batch_size, heads, head_dim) + ) + self.rotary_emb(query_pe, key_pe, cos, sin) + + query[..., self.qk_nope_head_dim :] = query_pe + key = torch.empty_like(query) + key[..., : self.qk_nope_head_dim] = key_nope + key[..., self.qk_nope_head_dim :] = key_pe + + # We need to pad the heads because Flash Attention does not support + # qk and v with different head sizes. + query = torch.nn.functional.pad( + query, (0, self.head_pad_size - self.head_size), value=0 + ) + key = torch.nn.functional.pad( + key, (0, self.head_pad_size - self.head_size), value=0 + ) + value = torch.nn.functional.pad( + value, (0, self.head_pad_size - self.value_head_size), value=0 + ) + + reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) + + # Output tensor + attn_output = torch.empty_like(query) + + # Prefill + if cu_seqlen_prefill is not None: + # flash attention + attention( + query, + key, + value, + attn_output, + cu_seqlen_prefill, + max_s, + self.softmax_scale, + ) + # Decode + else: + paged_attention( + attn_output, + query, + kv_cache[0], + kv_cache[1], + self.kv_head_mapping, + self.softmax_scale, + block_tables, + input_lengths, + max_s, + ) + + # Remove padding. + attn_output = attn_output[..., : self.value_head_size] + + return self.o_proj( + attn_output.reshape(-1, self.num_heads * self.value_head_size) + ) + + +class DeepseekV2MLP(nn.Module): + def __init__(self, prefix: str, config, weights, intermediate_size: int): + super().__init__() + self.hidden_act = config.hidden_act + if self.hidden_act != "silu": + # Bail out because MoE only supports silu. + raise NotImplementedError( + "Currently only `silu` is supported as an activation for Deepseek V2." + ) + self.act = ACT2FN[self.hidden_act] + + self.gate_up_proj = TensorParallelColumnLinear.load_multi( + config, + prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], + weights=weights, + dim=0, + bias=False, + ) + + self.down_proj = TensorParallelRowLinear.load( + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, + ) + + self.intermediate_size = intermediate_size // weights.process_group.size() + + # TODO: This is a hotfix to be removed & properly refactored. + self.quantize = config.quantize + + def forward(self, hidden_states: torch.Tensor, reduce: bool = True): + if ( + SYSTEM == "rocm" + and self.hidden_act == "silu" + and hidden_states.shape[0] == 1 + and not self.quantize + ): + out = torch.empty( + hidden_states.shape[0], + self.intermediate_size, + dtype=hidden_states.dtype, + device="cuda", + ) + _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) + return self.down_proj(out, reduce=reduce) + else: + gate_up_states = self.gate_up_proj(hidden_states) + gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) + return self.down_proj( + self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce + ) + + +class BlockSparseMoE(nn.Module): + def __init__(self, prefix, config: DeepseekV2Config, weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = ( + config.moe_intermediate_size // weights.process_group.size() + ) + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + gate_proj = _load_experts( + config, f"{prefix}.experts", "gate_proj", weights + ).view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + + up_proj = _load_experts(config, f"{prefix}.experts", "up_proj", weights).view( + self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim + ) + + self.gate_up_proj = torch.cat([gate_proj, up_proj], dim=1) + + self.down_proj = ( + _load_experts(config, f"{prefix}.experts", "down_proj", weights) + .view(self.n_routed_experts, self.moe_intermediate_size, self.hidden_dim) + .transpose(1, 2) + .contiguous() + ) + + # Gating + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + router_logits = self.gate(x) + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + out = ( + fused_experts( + x, + self.gate_up_proj, + self.down_proj, + topk_weights, + topk_ids, + inplace=True, + ) + * self.routed_scaling_factor + ) + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out.view(*x.shape) + + +class DenseMoE(nn.Module): + def __init__(self, prefix: str, config: DeepseekV2Config, weights: Weights): + super().__init__() + + self.hidden_dim = config.hidden_size + self.moe_intermediate_size = config.moe_intermediate_size + self.n_routed_experts = config.n_routed_experts + self.n_expert_group = config.n_group + self.topk_group = config.topk_group + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + + # Gating + # + # Seems like no one quantizes the gate. + self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False) + + self.experts = [ + DeepseekV2MLP( + f"{prefix}.experts.{i}", config, weights, self.moe_intermediate_size + ) + for i in range(self.n_routed_experts) + ] + + if config.n_shared_experts is not None: + self.shared_experts = DeepseekV2MLP( + prefix=f"{prefix}.shared_experts", + config=config, + weights=weights, + intermediate_size=config.moe_intermediate_size + * config.n_shared_experts, + ) + else: + self.shared_experts = None + + self.process_group = weights.process_group + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gate_logits: (sequence_length, n_experts) + """ + # optional reshape + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + if self.shared_experts is not None: + shared_output = self.shared_experts(x, reduce=False) + else: + shared_output = None + + # gate_logits: (sequence_length, n_experts) + router_logits = self.gate(x) + + topk_weights, topk_ids = grouped_topk( + x, + router_logits, + self.top_k, + renormalize=self.norm_topk_prob, + num_expert_group=self.n_expert_group, + topk_group=self.topk_group, + ) + + out = self.moe_infer_gpu(x, topk_ids, topk_weights) * self.routed_scaling_factor + + if shared_output is not None: + out = out + shared_output + + # Reduce sum + if self.process_group.size() > 1: + torch.distributed.all_reduce(out, group=self.process_group) + + return out + + def moe_infer_gpu( + self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor + ): + weights = torch.zeros( + topk_ids.shape[0], len(self.experts), dtype=x.dtype, device=x.device + ) + weights.scatter_(1, topk_ids, topk_weight) + + out = x.new_zeros(x.shape[0], self.hidden_dim) + for i, expert in enumerate(self.experts): + # Add expert output to out with masking + out += expert(x, reduce=False) * weights[:, i].view(-1, 1) + return out + + +class DeepseekV2Layer(nn.Module): + def __init__(self, prefix, layer_id, config, weights): + super().__init__() + prefix = f"{prefix}.layers.{layer_id}" + + self.self_attn = DeepseekV2Attention( + prefix=f"{prefix}.self_attn", + config=config, + weights=weights, + ) + + if ( + config.n_routed_experts is not None + and layer_id >= config.first_k_dense_replace + and layer_id % config.moe_layer_freq == 0 + ): + moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE + self.mlp = moe_cls(f"{prefix}.mlp", config, weights) + else: + self.mlp = DeepseekV2MLP( + prefix=f"{prefix}.mlp", + config=config, + weights=weights, + intermediate_size=config.intermediate_size, + ) + + self.input_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = FastRMSNorm.load( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + cu_seqlen_prefill: torch.Tensor, + kv_cache, + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: Seqlen, + max_s: int, + ): + normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + attn_output = self.self_attn( + normed_hidden_states, + cos, + sin, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + + # faster post attention rms norm + normed_attn_res_output, residual = self.post_attention_layernorm( + attn_output, residual + ) + + output = self.mlp(normed_attn_res_output) + + return output, residual + + +class DeepseekV2Model(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.embed_tokens = TensorParallelEmbedding( + prefix=f"{prefix}.embed_tokens", weights=weights + ) + + self.layers = nn.ModuleList( + [ + DeepseekV2Layer( + prefix, + layer_id, + config, + weights, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = FastRMSNorm.load( + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps + ) + + self.head_size = self.layers[0].self_attn.head_size + self.num_heads = self.layers[0].self_attn.num_heads + self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + # Get rotary cos and sin for this forward + # Avoid to index in each layer + cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( + position_ids, max_s, hidden_states.dtype + ) + + residual = None + for i, layer in enumerate(self.layers): + hidden_states, residual = layer( + hidden_states, + residual, + cos, + sin, + cu_seqlen_prefill, + kv_cache[i], + block_tables, + slots, + input_lengths, + max_s, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class FlashDeepseekV2ForCausalLM(torch.nn.Module): + def __init__(self, prefix: str, config, weights: Weights): + super().__init__() + + self.model = DeepseekV2Model( + "model" if not prefix else f"{prefix}.model", config, weights + ) + self.lm_head = SpeculativeHead.load( + config, + prefix="lm_head" if not prefix else f"{prefix}.lm_head", + weights=weights, + ) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlen_prefill: Optional[torch.Tensor], + kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], + block_tables: torch.Tensor, + slots: torch.Tensor, + input_lengths: torch.Tensor, + max_s: int, + prefill_cache_indices: Optional[torch.Tensor], + lm_head_indices: Optional[torch.Tensor] = None, + adapter_data: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self.model( + input_ids, + position_ids, + cu_seqlen_prefill, + kv_cache, + block_tables, + slots, + input_lengths, + max_s, + ) + if lm_head_indices is not None: + hidden_states = hidden_states[lm_head_indices] + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits + + +# Functions below are from vLLM: +# +# https://github.com/vllm-project/vllm/blob/f7160d946a0a07703e72d81ba9ecf3913f192605/vllm/model_executor/layers/fused_moe/fused_moe.py#L397 +# +# Remove after we have synced our version with upstream. + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], +) -> Dict[str, int]: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + if M <= E: + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } + return config + + +def fused_experts( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] + + import triton.language as tl + from vllm import _custom_ops as ops + from vllm.model_executor.layers.fused_moe.fused_moe import ( + get_moe_configs, + invoke_fused_moe_kernel, + moe_align_block_size, + ) + + M, _ = hidden_states.shape + E, N, _ = w1.shape + + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config( + M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None + ) + + intermediate_cache1 = torch.empty( + (M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + intermediate_cache3 = torch.empty( + (M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config["BLOCK_SIZE_M"], E + ) + compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 + + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + if inplace: + return torch.sum( + intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states, + ) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2ca9eef33da..2888f1f79e3 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -839,7 +839,9 @@ def __init__( default_dtype=torch.float16, aliases=None, # Used for Santacoder override of config - num_kv_heads=None, + num_kv_heads: Optional[int] = None, + # Deepseek V2 uses different QK and V dims. + head_size: Optional[int] = None, skip_special_tokens: bool = True, ): self.process_group, rank, world_size = initialize_torch_distributed() @@ -922,7 +924,11 @@ def __init__( else num_kv_heads ) assert self.num_kv_heads > 0 - self.head_size = config.hidden_size // config.num_attention_heads + + if head_size is None: + self.head_size = config.hidden_size // config.num_attention_heads + else: + self.head_size = head_size self.cuda_graphs = {} self.kv_cache = [] diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 6876b70096a..91592df090b 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -21,6 +21,13 @@ class WeightsLoader(ABC): with the format, etc. """ + @abstractmethod + def get_weights(self, weights: "Weights", prefix: str): + """ + Get weights at the given prefix and apply without tensor paralllism. + """ + ... + @abstractmethod def get_weights_col_packed( self, @@ -104,6 +111,9 @@ def __init__(self, weight_class): and/or concatenation. """ + def get_weights(self, weights: "Weights", prefix: str): + return weights.get_tensor(f"{prefix}.weight") + def get_weights_col_packed( self, weights: "Weights", @@ -299,6 +309,9 @@ def get_packed_sharded( return tensor + def get_weights(self, prefix: str): + return self.weights_loader.get_weights(self, prefix) + def get_weights_col_packed_qkv( self, prefix: str, From 11123a8e992d4595612cab33085e716f6373b760 Mon Sep 17 00:00:00 2001 From: Adrien Date: Sat, 20 Jul 2024 07:06:40 +0200 Subject: [PATCH 11/14] re-push to internal registry (#2242) * re-push to internal registry Signed-off-by: Adrien * fix name Signed-off-by: Adrien * debug Signed-off-by: Adrien * debug Signed-off-by: Adrien * wip Signed-off-by: Adrien * wip Signed-off-by: Adrien * wip debug Signed-off-by: Adrien * add debug Signed-off-by: Adrien * should Signed-off-by: Adrien * wip Signed-off-by: Adrien * ww Signed-off-by: Adrien * wip Signed-off-by: Adrien * wip Signed-off-by: Adrien * ww Signed-off-by: Adrien * wip Signed-off-by: Adrien * wip Signed-off-by: Adrien * debug Signed-off-by: Adrien * w Signed-off-by: Adrien * revert tests Signed-off-by: Adrien * last reverts Signed-off-by: Adrien * another one Signed-off-by: Adrien --------- Signed-off-by: Adrien --- .github/workflows/build.yaml | 28 +++++++++++++++++++--------- .github/workflows/load_test.yaml | 3 ++- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index cd9f19ba02f..abe161db0ec 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -27,8 +27,8 @@ jobs: concurrency: group: ${{ github.workflow }}-build-and-push-image-${{ inputs.hardware }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true - # TODO see with @Glegendre to get CPU runner here instead - runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci] + runs-on: + group: aws-r7i-8xlarge-priv permissions: contents: write packages: write @@ -49,7 +49,7 @@ jobs: export dockerfile="Dockerfile" export label_extension="" export docker_devices="" - export runs_on="nvidia-gpu" + export runs_on="aws-g5-12xlarge" ;; rocm) export dockerfile="Dockerfile_amd" @@ -79,9 +79,15 @@ jobs: uses: docker/setup-buildx-action@v3 with: install: true - config-inline: | + buildkitd-config-inline: | [registry."docker.io"] - mirrors = ["registry.github-runners.huggingface.tech"] + mirrors = ["registry-us-east-1-mirror.prod.aws.ci.huggingface.tech"] + - name: Login to internal Container Registry + uses: docker/login-action@v3 + with: + username: ${{ secrets.REGISTRY_USERNAME }} + password: ${{ secrets.REGISTRY_PASSWORD }} + registry: registry.internal.huggingface.tech - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v3 @@ -103,7 +109,8 @@ jobs: uses: docker/metadata-action@v5 with: images: | - registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference + registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference + registry.internal.huggingface.tech/api-inference/community/text-generation-inference tags: | type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}${{ env.LABEL }} # If main, release or tag @@ -115,7 +122,8 @@ jobs: flavor: | latest=auto images: | - registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference + registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference + registry.internal.huggingface.tech/api-inference/community/text-generation-inferenceca ghcr.io/huggingface/text-generation-inference db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference tags: | @@ -141,7 +149,7 @@ jobs: - name: Final id: final run: | - echo "docker_image=registry-push.github-runners.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT" + echo "docker_image=registry-us-east-1.prod.aws.ci.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT}}${{ env.LABEL }}" >> "$GITHUB_OUTPUT" echo "docker_devices=${{ env.DOCKER_DEVICES }}" >> "$GITHUB_OUTPUT" echo "runs_on=${{ env.RUNS_ON }}" >> "$GITHUB_OUTPUT" echo "label=${{ env.LABEL }}" >> "$GITHUB_OUTPUT" @@ -150,7 +158,8 @@ jobs: group: ${{ github.workflow }}-${{ github.job }}-${{ needs.build-and-push.outputs.label }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true needs: build-and-push - runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] + runs-on: + group: ${{ needs.build-and-push.outputs.runs_on }} if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' env: PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }} @@ -174,3 +183,4 @@ jobs: export HF_TOKEN=${{ secrets.HF_TOKEN }} echo $DOCKER_IMAGE pytest -s -vv integration-tests ${PYTEST_FLAGS} + diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml index 637df4727e1..ecfe0fdaaf6 100644 --- a/.github/workflows/load_test.yaml +++ b/.github/workflows/load_test.yaml @@ -15,7 +15,8 @@ jobs: concurrency: group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true - runs-on: [self-hosted, nvidia-gpu , multi-gpu, 4-a10, ci] + runs-on: + group: aws-g5-12xlarge env: DOCKER_VOLUME: /cache steps: From e5c1d6d61110b0d335830d5befe1e36cb9946eef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Sat, 20 Jul 2024 12:26:06 +0200 Subject: [PATCH 12/14] Add FP8 release test (#2261) --- .../test_flash_llama_fp8.json | 89 +++++ .../test_flash_llama_fp8_all_params.json | 89 +++++ .../test_flash_llama_fp8_load.json | 358 ++++++++++++++++++ .../models/test_flash_llama_fp8.py | 62 +++ 4 files changed, 598 insertions(+) create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json create mode 100644 integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json create mode 100644 integration-tests/models/test_flash_llama_fp8.py diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json new file mode 100644 index 00000000000..85cfb91f1ec --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7900391, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3554688, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0039062, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.4489746, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8100586, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013015747, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json new file mode 100644 index 00000000000..dcb4d063a27 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_all_params.json @@ -0,0 +1,89 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": 0, + "tokens": [ + { + "id": 25, + "logprob": -0.8535156, + "special": false, + "text": ":" + }, + { + "id": 2209, + "logprob": -2.4804688, + "special": false, + "text": " Is" + }, + { + "id": 279, + "logprob": -0.7167969, + "special": false, + "text": " the" + }, + { + "id": 734, + "logprob": -2.625, + "special": false, + "text": " function" + }, + { + "id": 330, + "logprob": -0.35131836, + "special": false, + "text": " \"" + }, + { + "id": 4110, + "logprob": -2.4101562, + "special": false, + "text": "Create" + }, + { + "id": 264, + "logprob": -0.23181152, + "special": false, + "text": " a" + }, + { + "id": 502, + "logprob": -0.25512695, + "special": false, + "text": " new" + }, + { + "id": 1052, + "logprob": -1.2792969, + "special": false, + "text": " file" + }, + { + "id": 1, + "logprob": -1.2529297, + "special": false, + "text": "\"" + } + ], + "top_tokens": null + }, + "generated_text": "Test request: Is the function \"Create a new file\"" +} diff --git a/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json new file mode 100644 index 00000000000..36c87c0975a --- /dev/null +++ b/integration-tests/models/__snapshots__/test_flash_llama_fp8/test_flash_llama_fp8_load.json @@ -0,0 +1,358 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 128000, + "logprob": null, + "text": "<|begin_of_text|>" + }, + { + "id": 2323, + "logprob": -9.421875, + "text": "Test" + }, + { + "id": 1715, + "logprob": -10.546875, + "text": " request" + } + ], + "seed": null, + "tokens": [ + { + "id": 369, + "logprob": -2.1816406, + "special": false, + "text": " for" + }, + { + "id": 279, + "logprob": -2.6992188, + "special": false, + "text": " the" + }, + { + "id": 220, + "logprob": -3.6308594, + "special": false, + "text": " " + }, + { + "id": 679, + "logprob": -1.7988281, + "special": false, + "text": "201" + }, + { + "id": 24, + "logprob": -1.3535156, + "special": false, + "text": "9" + }, + { + "id": 12, + "logprob": -2.0058594, + "special": false, + "text": "-" + }, + { + "id": 2366, + "logprob": -0.45410156, + "special": false, + "text": "202" + }, + { + "id": 15, + "logprob": -0.037109375, + "special": false, + "text": "0" + }, + { + "id": 2978, + "logprob": -0.8095703, + "special": false, + "text": " school" + }, + { + "id": 1060, + "logprob": -0.013053894, + "special": false, + "text": " year" + } + ], + "top_tokens": null + }, + "generated_text": " for the 2019-2020 school year" + } +] diff --git a/integration-tests/models/test_flash_llama_fp8.py b/integration-tests/models/test_flash_llama_fp8.py new file mode 100644 index 00000000000..fe5df590c2a --- /dev/null +++ b/integration-tests/models/test_flash_llama_fp8.py @@ -0,0 +1,62 @@ +import pytest + + +@pytest.fixture(scope="module") +def flash_llama_fp8_handle(launcher): + with launcher("meta-llama/Meta-Llama-3-8B", num_shard=2, quantize="fp8") as handle: + yield handle + + +@pytest.fixture(scope="module") +async def flash_llama_fp8(flash_llama_fp8_handle): + await flash_llama_fp8_handle.health(300) + return flash_llama_fp8_handle.client + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot): + response = await flash_llama_fp8.generate( + "Test request", max_new_tokens=10, decoder_input_details=True + ) + + assert response.details.generated_tokens == 10 + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot): + response = await flash_llama_fp8.generate( + "Test request", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response == response_snapshot + + +@pytest.mark.release +@pytest.mark.asyncio +@pytest.mark.private +async def test_flash_llama_fp8_load(flash_llama_fp8, generate_load, response_snapshot): + responses = await generate_load( + flash_llama_fp8, "Test request", max_new_tokens=10, n=4 + ) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + + assert responses == response_snapshot From 53ec0b790bb086be3772cb7da14c5f0f006105c4 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Sat, 20 Jul 2024 17:02:04 +0000 Subject: [PATCH 13/14] feat(fp8): use fbgemm kernels and load fp8 weights directly (#2248) * feat(fp8): add support for fbgemm * allow loading fp8 weights directly * update outlines * fix makefile * build fbgemm * avoid circular import and fix dockerfile * add default dtype * refactored weights loader * fix auto conversion * fix quantization config parsing * force new nccl on install * missing get_weights implementation * increase timeout --- .github/workflows/build.yaml | 1 - Dockerfile | 17 +- server/Makefile | 7 +- server/Makefile-fbgemm | 15 + server/fbgemm_remove_unused.patch | 306 ++++++++++++++++++ server/fix_torch90a.sh | 11 + server/text_generation_server/cli.py | 11 +- .../layers/attention/rocm.py | 6 +- server/text_generation_server/layers/bnb.py | 16 +- server/text_generation_server/layers/eetq.py | 4 +- server/text_generation_server/layers/fp8.py | 183 ++++++++++- .../layers/gptq/exllamav2.py | 3 +- .../text_generation_server/layers/marlin.py | 13 +- .../text_generation_server/models/__init__.py | 23 +- .../custom_modeling/flash_llama_modeling.py | 50 ++- .../models/flash_causal_lm.py | 31 +- .../text_generation_server/models/globals.py | 12 +- server/text_generation_server/models/model.py | 11 +- .../models/vlm_causal_lm.py | 7 +- server/text_generation_server/utils/dist.py | 6 +- server/text_generation_server/utils/log.py | 13 +- .../utils/quantization.py | 41 ++- .../text_generation_server/utils/weights.py | 58 +++- 23 files changed, 736 insertions(+), 109 deletions(-) create mode 100644 server/Makefile-fbgemm create mode 100644 server/fbgemm_remove_unused.patch create mode 100755 server/fix_torch90a.sh diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index abe161db0ec..6c968053251 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -183,4 +183,3 @@ jobs: export HF_TOKEN=${{ secrets.HF_TOKEN }} echo $DOCKER_IMAGE pytest -s -vv integration-tests ${PYTEST_FLAGS} - diff --git a/Dockerfile b/Dockerfile index 3f2e8ef0154..54ddd5efd2f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -161,6 +161,17 @@ COPY server/custom_kernels/ . # Build specific version of transformers RUN python setup.py build +# Build FBGEMM CUDA kernels +FROM kernel-builder AS fbgemm-builder + +WORKDIR /usr/src + +COPY server/Makefile-fbgemm Makefile +COPY server/fbgemm_remove_unused.patch fbgemm_remove_unused.patch +COPY server/fix_torch90a.sh fix_torch90a.sh + +RUN make build-fbgemm + # Build vllm CUDA kernels FROM kernel-builder AS vllm-builder @@ -225,10 +236,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31 # Copy build artifacts from marlin kernels builder COPY --from=marlin-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages COPY --from=lorax-punica-builder /usr/src/lorax-punica/server/punica_kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - -# Copy builds artifacts from vllm builder +# Copy build artifacts from fbgemm builder +COPY --from=fbgemm-builder /usr/src/fbgemm/fbgemm_gpu/_skbuild/linux-x86_64-3.10/cmake-install /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages - # Copy build artifacts from mamba builder COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages diff --git a/server/Makefile b/server/Makefile index 33940655620..209fc44e484 100644 --- a/server/Makefile +++ b/server/Makefile @@ -5,6 +5,7 @@ include Makefile-awq include Makefile-eetq include Makefile-selective-scan include Makefile-lorax-punica +include Makefile-fbgemm unit-tests: pytest -s -vv -m "not private" tests @@ -27,8 +28,9 @@ install-server: gen-server install: install-cuda echo "Installed server" -install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention +install-cuda: install-server install-flash-attention-v2-cuda install-vllm-cuda install-flash-attention install-fbgemm pip install -e ".[bnb]" + pip install nvidia-nccl-cu12==2.22.3 install-rocm: install-server install-flash-attention-v2-rocm install-vllm-rocm @@ -36,5 +38,6 @@ run-dev: SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded export-requirements: - poetry export -o requirements_cuda.txt --without-hashes --with cuda + poetry export -o requirements_cuda.txt --without-hashes poetry export -o requirements_rocm.txt --without-hashes + poetry export -o requirements_intel.txt --without-hashes diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm new file mode 100644 index 00000000000..38f8f31f79e --- /dev/null +++ b/server/Makefile-fbgemm @@ -0,0 +1,15 @@ +fbgemm_commit := 9cf0429b726931cfab72b8264730bea682f32fca + +build-fbgemm: + chmod +x fix_torch90a.sh && ./fix_torch90a.sh && \ + git clone https://github.com/pytorch/FBGEMM.git fbgemm && \ + cp fbgemm_remove_unused.patch fbgemm && \ + cd fbgemm && git fetch && git checkout $(fbgemm_commit) && git apply fbgemm_remove_unused.patch && \ + git submodule update --init --recursive && \ + cd fbgemm_gpu && \ + pip install -r requirements.txt && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build + +install-fbgemm: build-fbgemm + cd fbgemm/fbgemm_gpu && \ + CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install diff --git a/server/fbgemm_remove_unused.patch b/server/fbgemm_remove_unused.patch new file mode 100644 index 00000000000..ad6af81191a --- /dev/null +++ b/server/fbgemm_remove_unused.patch @@ -0,0 +1,306 @@ +diff --git a/fbgemm_gpu/CMakeLists.txt b/fbgemm_gpu/CMakeLists.txt +index 2244ea6f..96265a48 100644 +--- a/fbgemm_gpu/CMakeLists.txt ++++ b/fbgemm_gpu/CMakeLists.txt +@@ -94,14 +94,14 @@ endif() + # Build Experimental Modules + ################################################################################ + +-if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) +- # TODO: Figure out NCCL/RCCL integration with ROCm +- add_subdirectory(experimental/example) +-endif() +- +-if(NOT FBGEMM_CPU_ONLY) +- add_subdirectory(experimental/gemm) +-endif() ++# if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) ++# # TODO: Figure out NCCL/RCCL integration with ROCm ++# add_subdirectory(experimental/example) ++# endif() ++ ++# if(NOT FBGEMM_CPU_ONLY) ++# add_subdirectory(experimental/gemm) ++# endif() + + if(NOT FBGEMM_CPU_ONLY AND NOT USE_ROCM) + # CUTLASS currently doesn't build on ROCm and CK hasnt yet been added: +diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake +index c56773fe..0c0d349e 100644 +--- a/fbgemm_gpu/FbgemmGpu.cmake ++++ b/fbgemm_gpu/FbgemmGpu.cmake +@@ -446,53 +446,55 @@ set_source_files_properties(${fbgemm_sources} + ################################################################################ + + set(fbgemm_gpu_sources_static_cpu +- codegen/training/forward/embedding_forward_split_cpu.cpp +- codegen/inference/embedding_forward_quantized_host_cpu.cpp +- codegen/training/backward/embedding_backward_dense_host_cpu.cpp +- codegen/utils/embedding_bounds_check_host_cpu.cpp +- src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp +- src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp +- src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp +- src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp +- src/input_combine_ops/input_combine_cpu.cpp +- src/layout_transform_ops/layout_transform_ops_cpu.cpp ++ # codegen/training/forward/embedding_forward_split_cpu.cpp ++ # codegen/inference/embedding_forward_quantized_host_cpu.cpp ++ # codegen/training/backward/embedding_backward_dense_host_cpu.cpp ++ # codegen/utils/embedding_bounds_check_host_cpu.cpp ++ # src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp ++ # src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp ++ # src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_meta.cpp ++ # src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp ++ # src/input_combine_ops/input_combine_cpu.cpp ++ # src/layout_transform_ops/layout_transform_ops_cpu.cpp + src/quantize_ops/quantize_ops_cpu.cpp + src/quantize_ops/quantize_ops_meta.cpp +- src/sparse_ops/sparse_ops_cpu.cpp +- src/sparse_ops/sparse_ops_meta.cpp +- src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp +- src/split_embeddings_cache/linearize_cache_indices.cpp +- src/split_embeddings_cache/lfu_cache_populate_byte.cpp +- src/split_embeddings_cache/lru_cache_populate_byte.cpp +- src/split_embeddings_cache/lxu_cache.cpp +- src/split_embeddings_cache/split_embeddings_cache_ops.cpp +- codegen/training/index_select/batch_index_select_dim0_ops.cpp +- codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) ++ # src/sparse_ops/sparse_ops_cpu.cpp ++ # src/sparse_ops/sparse_ops_meta.cpp ++ # src/embedding_inplace_ops/embedding_inplace_update_cpu.cpp ++ # src/split_embeddings_cache/linearize_cache_indices.cpp ++ # src/split_embeddings_cache/lfu_cache_populate_byte.cpp ++ # src/split_embeddings_cache/lru_cache_populate_byte.cpp ++ # src/split_embeddings_cache/lxu_cache.cpp ++ # src/split_embeddings_cache/split_embeddings_cache_ops.cpp ++ # codegen/training/index_select/batch_index_select_dim0_ops.cpp ++ # codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp) ++) + + if(NOT FBGEMM_CPU_ONLY) + list(APPEND fbgemm_gpu_sources_static_cpu +- codegen/inference/embedding_forward_quantized_host.cpp +- codegen/utils/embedding_bounds_check_host.cpp +- src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp +- src/layout_transform_ops/layout_transform_ops_gpu.cpp +- src/memory_utils/memory_utils.cpp +- src/memory_utils/memory_utils_ops.cpp +- src/memory_utils/memory_utils_ops_cpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp ++ # codegen/inference/embedding_forward_quantized_host.cpp ++ # codegen/utils/embedding_bounds_check_host.cpp ++ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning_gpu.cpp ++ # src/layout_transform_ops/layout_transform_ops_gpu.cpp ++ # src/memory_utils/memory_utils.cpp ++ # src/memory_utils/memory_utils_ops.cpp ++ # src/memory_utils/memory_utils_ops_cpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_gpu.cpp ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_gpu.cpp + src/quantize_ops/quantize_ops_gpu.cpp +- src/sparse_ops/sparse_ops_gpu.cpp +- src/split_embeddings_utils/split_embeddings_utils.cpp +- src/split_embeddings_cache/split_embeddings_cache_ops.cu +- src/metric_ops/metric_ops_host.cpp +- src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp +- src/input_combine_ops/input_combine_gpu.cpp +- codegen/training/index_select/batch_index_select_dim0_host.cpp) ++ # src/sparse_ops/sparse_ops_gpu.cpp ++ # src/split_embeddings_utils/split_embeddings_utils.cpp ++ # src/split_embeddings_cache/split_embeddings_cache_ops.cu ++ # src/metric_ops/metric_ops_host.cpp ++ # src/embedding_inplace_ops/embedding_inplace_update_gpu.cpp ++ # src/input_combine_ops/input_combine_gpu.cpp ++ # codegen/training/index_select/batch_index_select_dim0_host.cpp) ++ ) + + if(NVML_LIB_PATH OR USE_ROCM) + message(STATUS "Adding merge_pooled_embeddings sources") +@@ -516,36 +518,36 @@ endif() + + if(NOT FBGEMM_CPU_ONLY) + set(fbgemm_gpu_sources_static_gpu +- codegen/utils/embedding_bounds_check.cu +- codegen/inference/embedding_forward_quantized_split_lookup.cu +- src/embedding_inplace_ops/embedding_inplace_update.cu +- src/histogram_binning_calibration_ops.cu +- src/input_combine_ops/input_combine.cu +- src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu +- src/memory_utils/memory_utils.cu +- src/memory_utils/memory_utils_ops.cu +- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu +- src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu +- src/jagged_tensor_ops/dense_to_jagged_forward.cu +- src/jagged_tensor_ops/jagged_dense_bmm_forward.cu +- src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu +- src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu +- src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu +- src/jagged_tensor_ops/jagged_index_add_2d_forward.cu +- src/jagged_tensor_ops/jagged_index_select_2d_forward.cu +- src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu +- src/jagged_tensor_ops/jagged_softmax_backward.cu +- src/jagged_tensor_ops/jagged_softmax_forward.cu +- src/jagged_tensor_ops/jagged_tensor_ops.cu +- src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu +- src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu +- src/jagged_tensor_ops/jagged_unique_indices.cu +- src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu +- src/layout_transform_ops/layout_transform_ops.cu +- src/metric_ops/metric_ops.cu +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu +- src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu +- src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu ++ # codegen/utils/embedding_bounds_check.cu ++ # codegen/inference/embedding_forward_quantized_split_lookup.cu ++ # src/embedding_inplace_ops/embedding_inplace_update.cu ++ # src/histogram_binning_calibration_ops.cu ++ # src/input_combine_ops/input_combine.cu ++ # src/intraining_embedding_pruning_ops/intraining_embedding_pruning.cu ++ # src/memory_utils/memory_utils.cu ++ # src/memory_utils/memory_utils_ops.cu ++ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu ++ # src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu ++ # src/jagged_tensor_ops/dense_to_jagged_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_bmm_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu ++ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu ++ # src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu ++ # src/jagged_tensor_ops/jagged_index_add_2d_forward.cu ++ # src/jagged_tensor_ops/jagged_index_select_2d_forward.cu ++ # src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu ++ # src/jagged_tensor_ops/jagged_softmax_backward.cu ++ # src/jagged_tensor_ops/jagged_softmax_forward.cu ++ # src/jagged_tensor_ops/jagged_tensor_ops.cu ++ # src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu ++ # src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu ++ # src/jagged_tensor_ops/jagged_unique_indices.cu ++ # src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu ++ # src/layout_transform_ops/layout_transform_ops.cu ++ # src/metric_ops/metric_ops.cu ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu ++ # src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu ++ # src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu + src/quantize_ops/quantize_bfloat16.cu + src/quantize_ops/quantize_fp8_rowwise.cu + src/quantize_ops/quantize_fused_8bit_rowwise.cu +@@ -554,39 +556,40 @@ if(NOT FBGEMM_CPU_ONLY) + src/quantize_ops/quantize_msfp.cu + src/quantize_ops/quantize_padded_fp8_rowwise.cu + src/quantize_ops/quantize_mx.cu +- src/sparse_ops/sparse_async_cumsum.cu +- src/sparse_ops/sparse_block_bucketize_features.cu +- src/sparse_ops/sparse_bucketize_features.cu +- src/sparse_ops/sparse_batched_unary_embeddings.cu +- src/sparse_ops/sparse_compute_frequency_sequence.cu +- src/sparse_ops/sparse_expand_into_jagged_permute.cu +- src/sparse_ops/sparse_group_index.cu +- src/sparse_ops/sparse_index_add.cu +- src/sparse_ops/sparse_index_select.cu +- src/sparse_ops/sparse_invert_permute.cu +- src/sparse_ops/sparse_pack_segments_backward.cu +- src/sparse_ops/sparse_pack_segments_forward.cu +- src/sparse_ops/sparse_permute_1d.cu +- src/sparse_ops/sparse_permute_2d.cu +- src/sparse_ops/sparse_permute102.cu +- src/sparse_ops/sparse_permute_embeddings.cu +- src/sparse_ops/sparse_range.cu +- src/sparse_ops/sparse_reorder_batched_ad.cu +- src/sparse_ops/sparse_segment_sum_csr.cu +- src/sparse_ops/sparse_zipf.cu +- src/split_embeddings_cache/lfu_cache_find.cu +- src/split_embeddings_cache/lfu_cache_populate.cu +- src/split_embeddings_cache/lfu_cache_populate_byte.cu +- src/split_embeddings_cache/lru_cache_find.cu +- src/split_embeddings_cache/lru_cache_populate.cu +- src/split_embeddings_cache/lru_cache_populate_byte.cu +- src/split_embeddings_cache/lxu_cache.cu +- src/split_embeddings_cache/linearize_cache_indices.cu +- src/split_embeddings_cache/reset_weight_momentum.cu +- src/split_embeddings_utils/generate_vbe_metadata.cu +- src/split_embeddings_utils/get_infos_metadata.cu +- src/split_embeddings_utils/radix_sort_pairs.cu +- src/split_embeddings_utils/transpose_embedding_input.cu) ++ # src/sparse_ops/sparse_async_cumsum.cu ++ # src/sparse_ops/sparse_block_bucketize_features.cu ++ # src/sparse_ops/sparse_bucketize_features.cu ++ # src/sparse_ops/sparse_batched_unary_embeddings.cu ++ # src/sparse_ops/sparse_compute_frequency_sequence.cu ++ # src/sparse_ops/sparse_expand_into_jagged_permute.cu ++ # src/sparse_ops/sparse_group_index.cu ++ # src/sparse_ops/sparse_index_add.cu ++ # src/sparse_ops/sparse_index_select.cu ++ # src/sparse_ops/sparse_invert_permute.cu ++ # src/sparse_ops/sparse_pack_segments_backward.cu ++ # src/sparse_ops/sparse_pack_segments_forward.cu ++ # src/sparse_ops/sparse_permute_1d.cu ++ # src/sparse_ops/sparse_permute_2d.cu ++ # src/sparse_ops/sparse_permute102.cu ++ # src/sparse_ops/sparse_permute_embeddings.cu ++ # src/sparse_ops/sparse_range.cu ++ # src/sparse_ops/sparse_reorder_batched_ad.cu ++ # src/sparse_ops/sparse_segment_sum_csr.cu ++ # src/sparse_ops/sparse_zipf.cu ++ # src/split_embeddings_cache/lfu_cache_find.cu ++ # src/split_embeddings_cache/lfu_cache_populate.cu ++ # src/split_embeddings_cache/lfu_cache_populate_byte.cu ++ # src/split_embeddings_cache/lru_cache_find.cu ++ # src/split_embeddings_cache/lru_cache_populate.cu ++ # src/split_embeddings_cache/lru_cache_populate_byte.cu ++ # src/split_embeddings_cache/lxu_cache.cu ++ # src/split_embeddings_cache/linearize_cache_indices.cu ++ # src/split_embeddings_cache/reset_weight_momentum.cu ++ # src/split_embeddings_utils/generate_vbe_metadata.cu ++ # src/split_embeddings_utils/get_infos_metadata.cu ++ # src/split_embeddings_utils/radix_sort_pairs.cu ++ # src/split_embeddings_utils/transpose_embedding_input.cu) ++ ) + + set_source_files_properties(${fbgemm_gpu_sources_static_gpu} + PROPERTIES COMPILE_OPTIONS +diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +index 01f1d6ab..a6b8d7a8 100644 +--- a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt ++++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +@@ -25,23 +25,24 @@ set(fbgemm_sources_include_directories + ${THIRDPARTY}/json/include + ${NCCL_INCLUDE_DIRS}) + +-set(attention_ops_sources +- src/attention/attention.cpp +- src/attention/gqa_attn_splitk.cu) ++# set(attention_ops_sources ++# src/attention/attention.cpp ++# src/attention/gqa_attn_splitk.cu) + + set(quantize_ops_sources + src/quantize/cutlass_extensions.cu + src/quantize/quantize.cu + src/quantize/quantize.cpp) + +-set(comm_ops_sources +- src/comm/car.cu +- src/comm/car.cpp) ++# set(comm_ops_sources ++# src/comm/car.cu ++# src/comm/car.cpp) + + set(experimental_gen_ai_cpp_source_files +- ${attention_ops_sources} ++ # ${attention_ops_sources} + ${quantize_ops_sources} +- ${comm_ops_sources}) ++ # ${comm_ops_sources} ++) + + set_source_files_properties(${experimental_gen_ai_cpp_source_files} + PROPERTIES INCLUDE_DIRECTORIES diff --git a/server/fix_torch90a.sh b/server/fix_torch90a.sh new file mode 100755 index 00000000000..5e444828217 --- /dev/null +++ b/server/fix_torch90a.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# This script is required to patch torch < 2.4 +# It adds the 90a cuda target (H100) +# This target is required to build FBGEMM kernels + +torch_cuda_arch=$(python -c "import torch; print(torch.__file__)" | sed 's/\/__init__.py//; s|$|/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake|') + +sed -i '189s/\[0-9]\\\\\.\[0-9](/[0-9]\\\\.[0-9]a?(/' $torch_cuda_arch +sed -i '245s/\[0-9()]+\+"/[0-9()]+a?"/' $torch_cuda_arch +sed -i '246s/\[0-9]+\+"/[0-9]+a?"/' $torch_cuda_arch diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index fe839cf4f9b..8ec2a5ae66a 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -8,6 +8,7 @@ from enum import Enum from huggingface_hub import hf_hub_download +from text_generation_server.utils.log import log_master app = typer.Typer() @@ -87,15 +88,17 @@ def serve( ) if len(lora_adapter_ids) > 0: - logger.warning( - f"LoRA adapters are enabled. This is an experimental feature and may not work as expected." + log_master( + logger.warning, + f"LoRA adapters are enabled. This is an experimental feature and may not work as expected.", ) # TODO: enable lora with cuda graphs. for now disable cuda graphs if lora is enabled # and warn the user if len(lora_adapter_ids) > 0 and os.getenv("CUDA_GRAPHS", None) is not None: - logger.warning( - f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs." + log_master( + logger.warning, + f"LoRa adapter are not supported with CUDA Graphs. Disabling CUDA Graphs.", ) global CUDA_GRAPHS CUDA_GRAPHS = None diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py index 99c490d5f50..54da63e855c 100644 --- a/server/text_generation_server/layers/attention/rocm.py +++ b/server/text_generation_server/layers/attention/rocm.py @@ -3,6 +3,7 @@ from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers.attention import Seqlen +from text_generation_server.utils.log import log_master from loguru import logger major, minor = torch.cuda.get_device_capability() @@ -136,7 +137,10 @@ def paged_attention( try: import flash_attn_2_cuda - logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + log_master( + logger.info, + "ROCm: using Flash Attention 2 Composable Kernel implementation.", + ) except ImportError as e: if major >= 8: architecture_suffix = f"-{SYSTEM}" diff --git a/server/text_generation_server/layers/bnb.py b/server/text_generation_server/layers/bnb.py index 925b0b2d6c8..aae2bd1a243 100644 --- a/server/text_generation_server/layers/bnb.py +++ b/server/text_generation_server/layers/bnb.py @@ -4,19 +4,11 @@ import bitsandbytes as bnb import torch from bitsandbytes.nn import Int8Params, Params4bit -from loguru import logger -from text_generation_server.utils.weights import Weight - - -@lru_cache(1) -def warn_deprecate_bnb(): - logger.warning( - "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" - ) +from text_generation_server.utils.weights import UnquantizedWeight @dataclass -class BNBWeight(Weight): +class BNBWeight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -82,7 +74,7 @@ def forward(self, x: torch.Tensor): @dataclass -class BNBFP4Weight(Weight): +class BNBFP4Weight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -90,7 +82,7 @@ def get_linear(self, bias: torch.Tensor): @dataclass -class BNBNF4Weight(Weight): +class BNBNF4Weight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): diff --git a/server/text_generation_server/layers/eetq.py b/server/text_generation_server/layers/eetq.py index f003f91452d..b1e5235a030 100644 --- a/server/text_generation_server/layers/eetq.py +++ b/server/text_generation_server/layers/eetq.py @@ -2,11 +2,11 @@ import torch from EETQ import quant_weights, w8_a16_gemm -from text_generation_server.utils.weights import Weight +from text_generation_server.utils.weights import UnquantizedWeight @dataclass -class EETQWeight(Weight): +class EETQWeight(UnquantizedWeight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index b56f568a0c9..cdf16d6b5b2 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,8 +1,29 @@ +import torch + from dataclasses import dataclass +from typing import Optional, Union, List +from loguru import logger -import torch from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.weights import Weight +from text_generation_server.utils.weights import ( + Weight, + WeightsLoader, + UnquantizedWeight, + Weights, +) +from text_generation_server.utils.log import log_master, log_once + +FBGEMM_MM_AVAILABLE = False +FBGEMM_DYN_AVAILABLE = False +try: + import fbgemm_gpu.experimental.gen_ai + + if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() + FBGEMM_MM_AVAILABLE = major == 9 + FBGEMM_DYN_AVAILABLE = major >= 8 +except (ImportError, ModuleNotFoundError): + log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") def get_fp8_linear() -> torch.nn.Module: @@ -21,12 +42,17 @@ def get_fp8_linear() -> torch.nn.Module: return Fp8Linear -def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): - device = weight.device +def fp8_quantize(weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn): + if FBGEMM_DYN_AVAILABLE: + qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( + weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype + ) + return qweight, scale + # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) # Calculate the scale as dtype max divided by absmax - scale = finfo.max / weight.abs().max().clamp(min=1e-12) + scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound) # scale and clamp the tensor to bring it to # the representative range of float8 data type # (as default cast is unsaturated) @@ -38,27 +64,166 @@ def fp8_quantize(weight, qdtype=torch.float8_e4m3fn): return qweight, scale +class HybridFP8UnquantLoader(WeightsLoader): + """Weight loader that loads FP8 and unquantized Torch tensors.""" + + def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool): + self.activation_scale_ub = activation_scale_ub + self.to_fp8 = to_fp8 + + def get_weights(self, weights: "Weights", prefix: str): + w = weights.get_tensor(f"{prefix}.weight") + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_col_packed( + self, + weights: Weights, + prefix: str, + block_sizes: Union[int, List[int]], + ): + w = weights.get_packed_sharded( + f"{prefix}.weight", dim=0, block_sizes=block_sizes + ) + + if w.dtype == torch.float8_e4m3fn: + # FP8 branch + scale = weights.get_packed_sharded( + f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False + ) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int): + w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes] + w = torch.cat(w, dim=dim) + + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = [ + weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False) + for p in prefixes + ] + scale = torch.cat(scale, dim=0) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + def get_weights_row(self, weights: "Weights", prefix: str): + w = weights.get_sharded(f"{prefix}.weight", dim=1) + # FP8 branch + if w.dtype == torch.float8_e4m3fn: + scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False) + return Fp8Weight( + weight=w, + weight_scale=scale, + activation_scale_ub=self.activation_scale_ub, + dtype=weights.dtype, + ) + if self.to_fp8: + return Fp8Weight(weight=w, dtype=weights.dtype) + + return UnquantizedWeight(w) + + @dataclass class Fp8Weight(Weight): weight: torch.Tensor + dtype: torch.dtype + weight_scale: Optional[torch.Tensor] = None + activation_scale_ub: Optional[float] = None def get_linear(self, bias: torch.Tensor): - return get_fp8_linear()(self.weight, bias) + if self.weight_scale is None: + return get_fp8_linear().from_unquant(self.weight, bias, self.dtype) + return get_fp8_linear().from_fp8( + self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype + ) class Fp8Linear(torch.nn.Module): def __init__( self, - weight, + qweight, + scale, + scale_upper_bound, bias, + dtype, ) -> None: super().__init__() - self.dtype = weight.dtype - self.qweight, self.scale = fp8_quantize(weight) + self.dtype = dtype + self.qweight = qweight + self.scale = scale + self.scale_upper_bound = ( + torch.tensor( + [scale_upper_bound], dtype=torch.float32, device=qweight.device + ) + if scale_upper_bound is not None + else None + ) self.bias = bias if bias is not None else None + @classmethod + def from_unquant(cls, weight, bias, dtype): + qweight, scale = fp8_quantize(weight) + return cls( + qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype + ) + + @classmethod + def from_fp8(cls, weight, scale, input_scale, bias, dtype): + return cls( + qweight=weight, + scale=scale, + scale_upper_bound=input_scale, + bias=bias, + dtype=dtype, + ) + def forward(self, input: torch.Tensor) -> torch.Tensor: + if FBGEMM_MM_AVAILABLE: + qinput, scale = fp8_quantize( + input, scale_upper_bound=self.scale_upper_bound + ) + + y = torch.ops.fbgemm.f8f8bf16_rowwise( + qinput, + self.qweight, + scale, + self.scale, + use_fast_accum=True, + bias=self.bias, + ) + return y.to(self.dtype) + qinput, scale = fp8_quantize(input) output, _ = torch._scaled_mm( qinput, diff --git a/server/text_generation_server/layers/gptq/exllamav2.py b/server/text_generation_server/layers/gptq/exllamav2.py index 4d45822bed2..dc3b832f9df 100644 --- a/server/text_generation_server/layers/gptq/exllamav2.py +++ b/server/text_generation_server/layers/gptq/exllamav2.py @@ -9,11 +9,12 @@ from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.layers.gptq import GPTQWeight +from text_generation_server.utils.log import log_master try: from exllamav2_kernels import make_q_matrix, gemm_half_q_half except ImportError: - logger.error("exllamav2_kernels not installed.") + log_master(logger.warning, "exllamav2_kernels not installed.") raise # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension diff --git a/server/text_generation_server/layers/marlin.py b/server/text_generation_server/layers/marlin.py index a913ff57a2d..40271c35c00 100644 --- a/server/text_generation_server/layers/marlin.py +++ b/server/text_generation_server/layers/marlin.py @@ -503,7 +503,8 @@ class GPTQMarlinFP8Linear(nn.Module): def __init__( self, - weight: torch.Tensor, + qweight: torch.Tensor, + scale: torch.Tensor, bias: Optional[torch.Tensor], ) -> None: super().__init__() @@ -513,7 +514,6 @@ def __init__( log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") - qweight, scale = fp8_quantize(weight) scale = scale.to(torch.float16) qweight, scales = repack_fp8_for_marlin(qweight, scale) @@ -529,6 +529,15 @@ def __init__( out_features // 64 * 16, dtype=torch.int, device=qweight.device ) + @classmethod + def from_unquant(cls, weight, bias, _dtype): + qweight, scale = fp8_quantize(weight) + return cls(qweight=qweight, scale=scale, bias=bias) + + @classmethod + def from_fp8(cls, weight, scale, _input_scale, bias, _dtype): + return cls(qweight=weight, scale=scale, bias=bias) + def forward(self, A: torch.Tensor) -> torch.Tensor: assert marlin_kernels is not None diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 690a8887675..a43cdfed6ae 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -34,6 +34,7 @@ ) from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.log import log_master # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. @@ -47,9 +48,7 @@ __all__ = [ "Model", - "BLOOMSharded", "CausalLM", - "GalacticaSharded", "Seq2SeqLM", "get_model", ] @@ -125,7 +124,7 @@ ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") + log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}") SUPPORTS_WINDOWING = False FLASH_ATTENTION = False @@ -137,7 +136,7 @@ try: from text_generation_server.models.mamba import Mamba except ImportError as e: - logger.warning(f"Could not import Mamba: {e}") + log_master(logger.warning, f"Could not import Mamba: {e}") MAMBA_AVAILABLE = False if MAMBA_AVAILABLE: @@ -311,6 +310,12 @@ def get_model( if quantize in ["awq", "exl2", "gptq", "marlin"]: # These quantizers only work with float16 params. dtype = torch.float16 + elif quantize == "fp8": + from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE + + if FBGEMM_MM_AVAILABLE: + # fbgemm kernels are fp8xfp8->bf16 + dtype = torch.bfloat16 else: # Keep it as default for now and let # every model resolve their own default dtype. @@ -433,7 +438,9 @@ def get_model( speculate = get_speculate() if speculate > 0: - logger.info(f"Using speculation {method} with {speculate} input ids.") + log_master( + logger.info, f"Using speculation {method} with {speculate} input ids." + ) if model_type is None: # TODO: fix how we determine model type for Mamba @@ -448,10 +455,10 @@ def get_model( if quantization_config is not None and quantize is None: method = quantization_config.get("quant_method", None) if method in {"gptq", "awq", "exl2"}: - logger.info(f"Auto selecting quantization method {method}") + log_master(logger.info, f"Auto selecting quantization method {method}") quantize = method else: - logger.info(f"Unknown quantization method {method}") + log_master(logger.warning, f"Unknown quantization method {method}") if quantize == "exl2" and sharded: raise RuntimeError( @@ -593,7 +600,7 @@ def get_model( ) except RuntimeError as e: # Lots of legacy models with various weight names. - logger.warning(f"Couldn't load flash gpt2 variant: {e}") + log_master(logger.warning, f"Couldn't load flash gpt2 variant: {e}") return CausalLM.fallback( model_id, revision, diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 5237a484941..f7980d2d172 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -33,7 +33,6 @@ attention, reshape_and_cache, ) -from text_generation_server.models.globals import FLASH_DECODING from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, @@ -42,16 +41,15 @@ TensorParallelMultiAdapterLinear, TensorParallelAdapterRowLinear, ) -from text_generation_server.layers.fp8 import Fp8Weight from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.layernorm import ( FastRMSNorm, ) from text_generation_server.utils.weights import ( - DefaultWeightsLoader, UnquantizedWeight, Weights, ) +from text_generation_server.layers.fp8 import HybridFP8UnquantLoader if SYSTEM == "rocm": try: @@ -113,12 +111,12 @@ def load_attention(config, prefix: str, weights, layer_id): @contextmanager def no_fp8(weights: Weights): + """De-activate fp8 auto conversion for the duration of this context manager""" weights_loader = weights.weights_loader - if ( - isinstance(weights_loader, DefaultWeightsLoader) - and weights_loader.weight_class is Fp8Weight - ): - weights_loader = DefaultWeightsLoader(UnquantizedWeight) + if isinstance(weights_loader, HybridFP8UnquantLoader) and weights_loader.to_fp8: + weights_loader = HybridFP8UnquantLoader( + weights_loader.activation_scale_ub, to_fp8=False + ) with weights.use_loader(weights_loader): yield @@ -418,7 +416,22 @@ def __init__(self, prefix, config, weights): process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() - self.layers = nn.ModuleList( + + # Skip fp8 quant for first and last layers + self.layers = nn.ModuleList() + with no_fp8(weights): + self.layers.append( + FlashLlamaLayer( + index=0, + prefix=( + "model.layers.0" if not prefix else "{prefix}.model.layers.0" + ), + config=config, + weights=weights, + ) + ) + + self.layers.extend( [ FlashLlamaLayer( index=layer_id, @@ -430,9 +443,26 @@ def __init__(self, prefix, config, weights): config=config, weights=weights, ) - for layer_id in range(config.num_hidden_layers) + # Skip first and last layers + for layer_id in range(1, config.num_hidden_layers - 1) ] ) + + with no_fp8(weights): + last_layer_id = config.num_hidden_layers - 1 + self.layers.append( + FlashLlamaLayer( + index=last_layer_id, + prefix=( + f"model.layers.{last_layer_id}" + if not prefix + else f"{prefix}.model.layers.{last_layer_id}" + ), + config=config, + weights=weights, + ) + ) + self.norm = FastRMSNorm.load( prefix="model.norm" if not prefix else f"{prefix}.model.norm", weights=weights, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 2888f1f79e3..cfffafa1b7a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -23,14 +23,13 @@ from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.models import Model +from text_generation_server.utils.log import log_master from text_generation_server.utils.tokens import batch_top_tokens -from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate from text_generation_server.utils import ( initialize_torch_distributed, weight_files, Weights, - hub, ) from text_generation_server.models.types import ( Batch, @@ -1156,31 +1155,36 @@ def warmup(self, batch: FlashCausalLMBatch): f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv", ) - logger.info( - f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`." + log_master( + logger.info, + f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`.", ) if os.path.isfile(tunableop_filepath): - logger.info( - f"The file {tunableop_filepath} already exists and will be reused." + log_master( + logger.info, + f"The file {tunableop_filepath} already exists and will be reused.", ) torch.cuda.tunable.read_file(tunableop_filepath) os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True) for seqlen in tuning_sequences: - logger.info(f"Warming up TunableOp for seqlen={seqlen}") + log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}") self.tunableop_warmup(seqlen) torch.cuda.tunable.write_file(tunableop_filepath) torch.cuda.tunable.tuning_enable(False) else: - logger.info( - "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp." + log_master( + logger.info, + "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp.", ) if CUDA_GRAPHS: try: - logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}") + log_master( + logger.info, f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}" + ) # Warmup cuda graphs for bs in CUDA_GRAPHS: if self.speculate is None or self.speculate + 1 <= bs: @@ -1188,7 +1192,9 @@ def warmup(self, batch: FlashCausalLMBatch): except torch.cuda.OutOfMemoryError: logger.exception(f"Decode cuda graph warmup failed") else: - logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).") + log_master( + logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})." + ) return int(num_blocks * BLOCK_SIZE) @@ -1540,8 +1546,7 @@ def generate_token( left = 0 if n_accepted_ids > 1: - if RANK == 0: - logger.debug(f"Speculated ids {n_accepted_ids - 1}") + log_master(logger.debug, f"Speculated ids {n_accepted_ids - 1}") current_stopped = False for j in range(index, index + n_accepted_ids): diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 06035ccdd99..ac42df30c04 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,15 +1,16 @@ import torch import os from loguru import logger -from typing import Dict +from typing import Dict, Optional + +from text_generation_server.utils.log import log_master MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"} BLOCK_SIZE: int = 256 if FLASH_DECODING else 16 if FLASH_DECODING: - logger.info("Using FLASH_DECODING") - + log_master(logger.info, "Using FLASH_DECODING") cuda_graphs = os.getenv("CUDA_GRAPHS") if cuda_graphs is not None: @@ -26,11 +27,9 @@ if cuda_graphs is not None: cuda_graphs.sort(reverse=True) - CUDA_GRAPHS = cuda_graphs # This is overridden at model loading. -global MODEL_ID MODEL_ID = None @@ -41,8 +40,7 @@ def set_model_id(model_id: str): # NOTE: eventually we should move this into the router and pass back the # index in all cases. -global ADAPTER_TO_INDEX -ADAPTER_TO_INDEX: Dict[str, int] = None +ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None def set_adapter_to_index(adapter_to_index: Dict[str, int]): diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 09130b85990..e7748bb99f6 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -15,6 +15,7 @@ AdapterParameters, AdapterSource, ) +from text_generation_server.utils.log import log_master from loguru import logger @@ -204,8 +205,9 @@ def load_adapter( f"order to use the dynamic adapter loading feature." ) - logger.info( - f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}" + log_master( + logger.info, + f"Loading adapter weights into model: {','.join(adapter_parameters.adapter_ids)}", ) weight_names = tuple([v[0] for v in self.target_to_layer.values()]) ( @@ -240,8 +242,9 @@ def load_adapter( layer_weights.add_adapter(adapter_index, adapter_weights) if len(unused_weight_names) > 0: - logger.warning( - f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}" + log_master( + logger.warning, + f"{','.join(adapter_parameters.adapter_ids)} unused adapter weights: {unused_weight_names}", ) if adapter_tokenizer is not None: diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index f869f8b53ea..308d5a3d963 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -1,4 +1,3 @@ -from itertools import repeat import torch from PIL import Image from io import BytesIO @@ -13,6 +12,7 @@ FlashCausalLMBatch, FlashCausalLM, ) +from text_generation_server.utils.log import log_master from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -56,8 +56,9 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str num_features = get_number_of_features(height, width, config) from loguru import logger - logger.info( - f"Found {num_features} features in image of resolution {height}x{width}" + log_master( + logger.info, + f"Found {num_features} features in image of resolution {height}x{width}", ) return "" * num_features diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 36d63e86d2c..82aeba6ce9f 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -56,7 +56,7 @@ def initialize_torch_distributed(): backend = "nccl" options = ProcessGroupNCCL.Options() options.is_high_priority_stream = True - options._timeout = timedelta(seconds=60) + options._timeout = timedelta(seconds=120) else: backend = "gloo" options = None @@ -76,7 +76,7 @@ def initialize_torch_distributed(): backend="ccl", world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), + timeout=timedelta(seconds=120), pg_options=options, ) else: @@ -84,7 +84,7 @@ def initialize_torch_distributed(): backend=backend, world_size=WORLD_SIZE, rank=RANK, - timeout=timedelta(seconds=60), + timeout=timedelta(seconds=120), pg_options=options, ) else: diff --git a/server/text_generation_server/utils/log.py b/server/text_generation_server/utils/log.py index b1456f1e551..4385c71ee96 100644 --- a/server/text_generation_server/utils/log.py +++ b/server/text_generation_server/utils/log.py @@ -1,6 +1,15 @@ from functools import lru_cache +from text_generation_server.utils.dist import RANK @lru_cache(10) -def log_once(log, msg: str): - log(msg) +def log_once(log, msg: str, master=True): + if master: + log_master(log, msg) + else: + log(msg) + + +def log_master(log, msg: str): + if RANK == 0: + log(msg) diff --git a/server/text_generation_server/utils/quantization.py b/server/text_generation_server/utils/quantization.py index e8e22db8d33..c3c038fe51a 100644 --- a/server/text_generation_server/utils/quantization.py +++ b/server/text_generation_server/utils/quantization.py @@ -11,6 +11,7 @@ ) +# TODO: Split this config to have a single config type per quant method @dataclass class _QuantizerConfig: bits: int @@ -21,6 +22,11 @@ class _QuantizerConfig: sym: bool +@dataclass +class _FP8QuantizerConfig: + activation_scale_ub: float + + # We should probably do this with Pytantic JSON deserialization, # but for now we'll stay close to the old _set_gptq_params. def _get_quantizer_config(model_id, revision): @@ -39,6 +45,13 @@ def _get_quantizer_config(model_id, revision): filename = hf_hub_download(model_id, filename=filename, revision=revision) with open(filename, "r") as f: data = json.load(f) + + # FP8 config + if data["quantization_config"]["quant_method"] == "fbgemm_fp8": + return _FP8QuantizerConfig( + activation_scale_ub=data["quantization_config"]["activation_scale_ub"] + ) + bits = data["quantization_config"]["bits"] groupsize = data["quantization_config"]["group_size"] # Order is important here, desc_act is missing on some real models @@ -99,6 +112,12 @@ def get_loader( if quantize in {"awq", "gptq"}: from text_generation_server.layers.gptq import GPTQWeightsLoader + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." + ) + return GPTQWeightsLoader( bits=quantizer_config.bits, desc_act=quantizer_config.desc_act, @@ -127,18 +146,28 @@ def get_loader( from text_generation_server.layers.exl2 import Exl2WeightsLoader return Exl2WeightsLoader() - elif quantize == "fp8": - from text_generation_server.layers.fp8 import Fp8Weight - - return DefaultWeightsLoader(Fp8Weight) elif quantize == "marlin": from text_generation_server.layers.marlin import MarlinWeightsLoader + # TODO: improve check once we have one config type per quantize value + if not isinstance(quantizer_config, _QuantizerConfig): + raise ValueError( + f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." + ) + return MarlinWeightsLoader( bits=quantizer_config.bits, is_marlin_24=quantizer_config.checkpoint_format == "marlin_24", ) - elif quantize is None: - return DefaultWeightsLoader(UnquantizedWeight) + elif quantize == "fp8" or quantize is None: + from text_generation_server.layers.fp8 import HybridFP8UnquantLoader + + # Since the default for the quantize config is _QuantizerConfig, + # we need to add this check to not get an attribute error + activation_scale_ub = None + if isinstance(quantizer_config, _FP8QuantizerConfig): + activation_scale_ub = quantizer_config.activation_scale_ub + + return HybridFP8UnquantLoader(activation_scale_ub, to_fp8=quantize == "fp8") else: raise ValueError(f"Unknown quantization method: {quantize}") diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 91592df090b..66bb6051da0 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -1,12 +1,12 @@ +import torch + from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass -from enum import Enum, auto from pathlib import Path -from typing import Dict, List, Optional, Union - -import torch +from typing import Dict, List, Optional, Union, Type from safetensors import safe_open +from dataclasses import dataclass + from text_generation_server.utils.import_utils import SYSTEM @@ -84,7 +84,7 @@ def get_linear(self, bias: torch.Tensor): @dataclass -class UnquantizedWeight: +class UnquantizedWeight(Weight): weight: torch.Tensor def get_linear(self, bias: torch.Tensor): @@ -99,7 +99,7 @@ def get_linear(self, bias: torch.Tensor): class DefaultWeightsLoader(WeightsLoader): """Weight loader that loads (unquantized) Torch tensors.""" - def __init__(self, weight_class): + def __init__(self, weight_class: Type[UnquantizedWeight]): """Create a loader. Weights will be wrapped using the given `weights_class`, normally this will be `UnquantizedWeight`, but a quantizer-specific class such as `Fp8Weight` can be used to quantize the weights during loading. @@ -208,20 +208,29 @@ def _has_tensor(self, tensor_name: str): def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device=True): + def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) # Special case for gptq which shouldn't convert # u4 which are disguised as int32. Exl2 uses int16 - # as well. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + # as well. FP8 uses torch.float8_e4m3fn + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) if to_device: tensor = tensor.to(device=self.device) return tensor - def get_partial_sharded(self, tensor_name: str, dim: int): + def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -241,12 +250,16 @@ def get_partial_sharded(self, tensor_name: str, dim: int): raise NotImplementedError("Let's make that generic when needed") # Special case for gptq which shouldn't convert # u4 which are disguised as int32. exl2 uses int16. - if tensor.dtype not in (torch.int16, torch.int32): + # FP8 uses torch.float8_e4m3fn. + if ( + tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) return tensor - def get_sharded(self, tensor_name: str, dim: int): + def get_sharded(self, tensor_name: str, dim: int, to_dtype=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) slice_ = f.get_slice(tensor_name) @@ -255,10 +268,14 @@ def get_sharded(self, tensor_name: str, dim: int): assert ( size % world_size == 0 ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" - return self.get_partial_sharded(tensor_name, dim) + return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype) def get_packed_sharded( - self, tensor_name: str, dim: int, block_sizes: Union[int, List[int]] + self, + tensor_name: str, + dim: int, + block_sizes: Union[int, List[int]], + to_dtype=True, ) -> torch.Tensor: """ Get a shard from a tensor that packs multiple tensors. @@ -304,7 +321,16 @@ def get_packed_sharded( tensor = tensor.to(device=self.device) # Avoid casting quantizer dtypes. - if tensor.dtype not in [torch.int16, torch.int32, torch.int64]: + if ( + tensor.dtype + not in [ + torch.float8_e4m3fn, + torch.int16, + torch.int32, + torch.int64, + ] + and to_dtype + ): tensor = tensor.to(dtype=self.dtype) return tensor From f3435bab8c047d7e6fb8c03f1bf63f6f0702e3e1 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Sun, 21 Jul 2024 16:48:04 +0000 Subject: [PATCH 14/14] fix(server): fix deepseekv2 loading (#2266) --- .../models/custom_modeling/flash_deepseek_v2_modeling.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py index 8ffbd143c02..f5b2ba0eb68 100644 --- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py @@ -34,7 +34,6 @@ from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.utils.import_utils import SYSTEM -from text_generation_server.utils.log import log_once from text_generation_server.utils.weights import Weights from torch import nn from transformers.activations import ACT2FN @@ -240,7 +239,6 @@ def __init__( if config.attention_bias else None ), - quantize=config.quantize, ) self.q_a_layernorm = FastRMSNorm.load( prefix=f"{prefix}.q_a_layernorm", @@ -261,7 +259,6 @@ def __init__( if config.attention_bias else None ), - quantize=config.quantize, ) self.kv_a_layernorm = FastRMSNorm.load(