Skip to content

Commit

Permalink
More tensor cores. (#2558)
Browse files Browse the repository at this point in the history
* More tensor cores.

* Fixing the logic.

* Gemma is modified by this.
  • Loading branch information
Narsil authored Sep 24, 2024
1 parent c032280 commit dd8691b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
"tokens": [
{
"id": 1736,
"logprob": -2.03125,
"logprob": -2.109375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.8671875,
"logprob": -1.90625,
"special": false,
"text": "\n\n"
},
Expand All @@ -42,48 +42,48 @@
},
{
"id": 2121,
"logprob": -1.8125,
"logprob": -1.796875,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.24121094,
"logprob": -0.24511719,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.100097656,
"logprob": -0.09326172,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.9453125,
"logprob": -0.95703125,
"special": false,
"text": " is"
},
{
"id": 476,
"logprob": -1.703125,
"id": 1671,
"logprob": -1.5859375,
"special": false,
"text": " a"
"text": " used"
},
{
"id": 4551,
"logprob": -2.453125,
"id": 577,
"logprob": -0.39257812,
"special": false,
"text": " document"
"text": " to"
},
{
"id": 674,
"logprob": -0.796875,
"id": 3853,
"logprob": -1.25,
"special": false,
"text": " that"
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is a document that"
"generated_text": " form\n\nThe test request form is used to request"
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,32 @@
},
{
"id": 2015,
"logprob": -9.640625,
"logprob": -9.6484375,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.375,
"logprob": -10.3671875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 604,
"logprob": -0.2824707,
"logprob": -0.28271484,
"special": false,
"text": " for"
},
{
"id": 573,
"logprob": -0.19030762,
"logprob": -0.18493652,
"special": false,
"text": " the"
},
{
"id": 16819,
"logprob": -1.4892578,
"logprob": -1.4804688,
"special": false,
"text": " detection"
},
Expand All @@ -47,43 +47,43 @@
"text": " of"
},
{
"id": 573,
"logprob": -2.0195312,
"id": 671,
"logprob": -2.1738281,
"special": false,
"text": " the"
"text": " an"
},
{
"id": 8566,
"logprob": 0.0,
"id": 24646,
"logprob": -3.0449219,
"special": false,
"text": " presence"
"text": " RNA"
},
{
"id": 689,
"logprob": -0.16491699,
"id": 12369,
"logprob": -0.19299316,
"special": false,
"text": " or"
"text": " virus"
},
{
"id": 14862,
"logprob": 0.0,
"id": 575,
"logprob": -0.10632324,
"special": false,
"text": " absence"
"text": " in"
},
{
"id": 576,
"logprob": -0.9946289,
"id": 6022,
"logprob": -0.98095703,
"special": false,
"text": " of"
"text": " patients"
},
{
"id": 671,
"logprob": -0.5263672,
"id": 1064,
"logprob": -1.3095703,
"special": false,
"text": " an"
"text": " who"
}
],
"top_tokens": null
},
"generated_text": "Test request for the detection of the presence or absence of an"
"generated_text": "Test request for the detection of an RNA virus in patients who"
}
8 changes: 6 additions & 2 deletions server/text_generation_server/layers/attention/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,13 @@ def create_decode_state(
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
use_tensor_cores=num_heads // num_kv_heads > 4,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)


Expand All @@ -175,14 +177,16 @@ def create_decode_state_cuda_graphs(
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=True,
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
use_tensor_cores=num_heads // num_kv_heads > 4,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)


Expand Down

0 comments on commit dd8691b

Please sign in to comment.