
[Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model #4799

Merged: 101 commits into vllm-project:main on May 25, 2024

Conversation

@linxihui (Contributor) commented on May 14, 2024

  • Support for the Microsoft Phi-3-Small-8K and Phi-3-Small-128K models, which use blocksparse flash attention
  • A prefill Triton kernel for block-sparse attention (see the sketch after this description)
  • A modified paged-attention CUDA kernel supporting block-sparse attention, which allows a hybrid sparsity pattern per attention head
  • A fallback to torch SDPA in the prefill phase on V100 or older GPUs, as well as on CPU

This is joint work between Microsoft GenAI (@linxihui, @beagleski) and vLLM (@zhuohan123, @simon-mo, @youkaichao).
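
For illustration, here is a minimal PyTorch sketch of the local-plus-vertical-stride block mask this kind of kernel evaluates, together with the torch SDPA fallback path mentioned above. The parameter names (local_blocks, vert_stride) and the exact stride convention are assumptions for the sketch, not the PR's actual Triton/CUDA code.

import torch
import torch.nn.functional as F

def blocksparse_block_mask(n_blocks, local_blocks, vert_stride):
    # Block-level causal mask: query block q attends to key block k iff
    # k <= q (causal) and k is either inside the local window or on the
    # assumed vertical stride.
    q = torch.arange(n_blocks)[:, None]
    k = torch.arange(n_blocks)[None, :]
    return (k <= q) & ((q - k < local_blocks) | ((k + 1) % vert_stride == 0))

# Fallback on V100-or-older GPUs and CPU: expand the block mask to token
# granularity and pass it to torch SDPA as a boolean attn_mask.
block_size, n_blocks = 64, 8
mask = blocksparse_block_mask(n_blocks, local_blocks=2, vert_stride=4)
token_mask = mask.repeat_interleave(block_size, 0).repeat_interleave(block_size, 1)

q = torch.randn(1, 4, n_blocks * block_size, 32)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=token_mask)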

@simon-mo (Collaborator) commented:

I made a pass. I think once this PR adds unit tests for both the Triton and PagedAttention kernels it should be good to go. You might also need to run clang-format to fix the merge conflict.

@linxihui (Contributor, Author) replied:

> I made a pass. I think once this PR adds unit tests for both the Triton and PagedAttention kernels it should be good to go. You might also need to run clang-format to fix the merge conflict.

Thanks @simon-mo for the review. I'll add the missing unit tests today.

@simon-mo (Collaborator) commented:

I have tested the PR locally as well.

simon-mo merged commit 8e192ff into vllm-project:main on May 25, 2024
63 of 65 checks passed
@AllenDou (Contributor) commented on May 27, 2024:

Phi-3-small's special tokens ('<|******|>') cause guided_grammar to crash:

  File "/root/vllm/vllm/engine/async_llm_engine.py", line 39, in _raise_exception_on_finish
    task.result()
  File "/root/vllm/vllm/engine/async_llm_engine.py", line 517, in run_engine_loop
    has_requests_in_progress = await asyncio.wait_for(
  File "/usr/lib/python3.10/asyncio/tasks.py", line 445, in wait_for
    return fut.result()
  File "/root/vllm/vllm/engine/async_llm_engine.py", line 491, in engine_step
    request_outputs = await self.engine.step_async()
  File "/root/vllm/vllm/engine/async_llm_engine.py", line 225, in step_async
    output = await self.model_executor.execute_model_async(
  File "/root/vllm/vllm/executor/gpu_executor.py", line 117, in execute_model_async
    output = await make_async(self.driver_worker.execute_model
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/vllm/vllm/worker/worker.py", line 272, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/vllm/vllm/worker/model_runner.py", line 709, in execute_model
    logits = self.model.compute_logits(hidden_states, sampling_metadata)
  File "/root/vllm/vllm/model_executor/models/phi3_small.py", line 403, in compute_logits
    logits = self.logits_processor(self.lm_head.weight, hidden_states,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/vllm/vllm/model_executor/layers/logits_processor.py", line 58, in forward
    logits = _apply_logits_processors(logits, sampling_metadata)
  File "/root/vllm/vllm/model_executor/layers/logits_processor.py", line 115, in _apply_logits_processors
    logits_row = logits_processor(past_tokens_ids,
  File "/root/vllm/vllm/model_executor/guided_decoding/outlines_logits_processors.py", line 53, in __call__
    allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
  File "/usr/local/lib/python3.10/dist-packages/outlines/fsm/fsm.py", line 329, in allowed_token_ids
    self.regex_fsm = RegexFSM(regex_string, self.tokenizer)
  File "/usr/local/lib/python3.10/dist-packages/outlines/fsm/fsm.py", line 123, in __init__
    regex_string, tuple(sorted(tokenizer.vocabulary.items()))
TypeError: '<' not supported between instances of 'str' and 'bytes'
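
The root cause is visible in the last frame: outlines sorts tokenizer.vocabulary.items(), and Python cannot order a vocabulary whose keys mix str and bytes. A minimal reproduction, with hypothetical token values:

# Mixed bytes and str keys, as produced by the model's special tokens.
vocab = {b"<|endoftext|>": 1, "hello": 2}
sorted(vocab.items())
# TypeError: '<' not supported between instances of 'str' and 'bytes'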

Server:
python3 -m vllm.entrypoints.openai.api_server --model microsoft/Phi-3-small-8k-instruct --tensor-parallel-size 1 --served-model-name modelx --disable-log-stats --trust-remote-code

Client:

curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "modelx",
        "prompt": ["Generate a sql state that select col_1 from table_1 where it is equals to 1"],
        "max_tokens": 20,
        "temperature": 0,
        "guided_grammar": "start: select_statement\r\nselect_statement: \"SELECT\" column \"from\" table \"where\" condition\r\ncolumn: \"col_1\" | \"col_2\"\r\ntable: \"table_1\" | \"table_2\"\r\ncondition: column \"=\" number\r\nnumber: \"1\" | \"2\""
    }'

#5068 adds a test case.
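
One possible mitigation, sketched below, is to normalize the vocabulary to a single key type before it reaches outlines. This is a hypothetical illustration (normalize_vocab is not a real vLLM or outlines function), not the fix adopted in #5068.

def normalize_vocab(vocab):
    # Decode bytes keys to str so that sorted(vocab.items()) is
    # well-defined. Illustrative only; the real fix may belong in
    # outlines or in vLLM's tokenizer adapter.
    return {
        k.decode("utf-8", errors="replace") if isinstance(k, bytes) else k: v
        for k, v in vocab.items()
    }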

dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request on May 31, 2024:

[Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model (vllm-project#4799)

Co-authored-by: beagleski <yunanzhang@microsoft.com>
Co-authored-by: bapatra <bapatra@microsoft.com>
Co-authored-by: Barun Patra <codedecde@users.noreply.github.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
The same commit was later mirrored to neuralmagic/nm-vllm (Jun 8 and Jul 14, 2024), joerunde/vllm (Jun 17, 2024), and Temirulan/vllm-whisper (Sep 6, 2024), each carrying the same co-author list.