Closed
Description
We tested sglang with flashinfer 0.0.2 and flashinfer 0.0.3-dev (238563f). Both crash inside flashinfer with the following stack trace on an A100.
Model: Yi-34B
OS: Ubuntu 22.04
GPU: A100 80GB
Yi-6B and Yi-9B do not have this issue. Yi uses a Llama 2-based architecture, if I am not mistaken.
@yzh119 The stack trace is vague to me (it only says BatchPrefillWithPagedKVCache failed to dispatch with dtype Half), so I am reporting the bug here first. If you think this is sglang-related, I will move it to the sglang repository. Thanks!
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 184, in exposed_step
    self.forward_step()
  File "/root/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 199, in forward_step
    self.forward_fill_batch(new_batch)
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_rpc.py", line 412, in forward_fill_batch
    ) = self.model_runner.forward(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_runner.py", line 506, in forward
    return self.forward_extend(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/router/model_runner.py", line 411, in forward_extend
    return self.model.forward(input_ids, input_metadata.positions, input_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/models/llama2.py", line 269, in forward
    hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/models/llama2.py", line 239, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/models/llama2.py", line 191, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/models/llama2.py", line 140, in forward
    attn_output = self.attn(q, k, v, input_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/layers/radix_attention.py", line 115, in forward
    return self.extend_forward(q, k, v, input_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/layers/radix_attention.py", line 91, in prefill_forward_flashinfer
    o = input_metadata.prefill_wrapper.forward(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/flashinfer/prefill.py", line 507, in forward
    return self._wrapper.forward(
           ^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: BatchPrefillWithPagedKVCache failed to dispatch with dtype Half
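Since the error names only the dtype, it may help to compare the attention shapes of a working and a failing model. The sketch below is one possible way to do that. It assumes a Hugging Face style config.json with the fields hidden_size, num_attention_heads and num_key_value_heads; the helper names attention_shape_summary and load_config are hypothetical, and the numbers in the example are placeholders rather than values read from any Yi checkpoint. Whether one of these values is what actually makes the dispatch fail is an assumption the trace does not confirm.

```python
import json
from pathlib import Path


def attention_shape_summary(config: dict) -> dict:
    """Summarise the attention-shape fields of a Hugging Face style config dict."""
    num_q_heads = config["num_attention_heads"]
    # Configs without num_key_value_heads imply plain multi-head attention.
    num_kv_heads = config.get("num_key_value_heads", num_q_heads)
    return {
        "num_q_heads": num_q_heads,
        "num_kv_heads": num_kv_heads,
        "head_dim": config["hidden_size"] // num_q_heads,
        # Number of query heads that share one KV head.
        "group_size": num_q_heads // num_kv_heads,
        "evenly_grouped": num_q_heads % num_kv_heads == 0,
    }


def load_config(model_dir: str) -> dict:
    """Read config.json from a local model directory (hypothetical helper)."""
    return json.loads((Path(model_dir) / "config.json").read_text())


if __name__ == "__main__":
    # Placeholder configs for illustration only; replace with
    # load_config("<path to model>") for the models being compared.
    config_a = {"hidden_size": 4096, "num_attention_heads": 32, "num_key_value_heads": 4}
    config_b = {"hidden_size": 7168, "num_attention_heads": 56, "num_key_value_heads": 8}
    for name, cfg in (("config_a", config_a), ("config_b", config_b)):
        print(name, attention_shape_summary(cfg))
```

Posting the two summaries side by side, together with the dtype, would show at a glance which parameter differs between the model that runs and the one that crashes.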