Use NCCL instead of ray for control-plane communication to remove serialization overhead #2221

Merged · 35 commits · Jan 3, 2024

Commits (35)
7265829
small test
zhuohan123 Dec 18, 2023
20274cc
test ray_pg
zhuohan123 Dec 19, 2023
1b73dd7
update ray test
zhuohan123 Dec 19, 2023
0d89354
implement driver worker
zhuohan123 Dec 20, 2023
e0c4c4e
broadcast swap info
zhuohan123 Dec 20, 2023
1baf87b
Broadcast inputmetadata as well
zhuohan123 Dec 20, 2023
c947fa0
fix bugs
zhuohan123 Dec 20, 2023
761584b
fix comments
zhuohan123 Dec 25, 2023
19110fb
remove unused files
zhuohan123 Dec 25, 2023
7b05ec6
fix async llm engine
zhuohan123 Dec 26, 2023
5f90351
fix format
zhuohan123 Dec 26, 2023
6f7ea32
Merge branch 'main' into remove-serialization-overhead
zhuohan123 Dec 26, 2023
966e366
[BUGFIX] Fix API server test
zhuohan123 Dec 26, 2023
fe2c29a
fix and remove print
zhuohan123 Dec 26, 2023
5557cdb
fix test_cache
zhuohan123 Dec 26, 2023
d92b38d
Merge branch 'fix-test-api-server' into remove-serialization-overhead
zhuohan123 Dec 26, 2023
c7f6c21
fix api test
zhuohan123 Dec 26, 2023
332d370
[BUGFIX] Fix the path of test prompts
zhuohan123 Dec 26, 2023
9a8c16f
Merge branch 'fix-test-prompt-path' into remove-serialization-overhead
zhuohan123 Dec 26, 2023
6ea2a42
fix test_model_runner
zhuohan123 Dec 26, 2023
0434a76
Merge branch 'main' into remove-serialization-overhead
zhuohan123 Dec 27, 2023
95bb1d3
Fix async llm engine
zhuohan123 Dec 27, 2023
de4c8d2
[BUGFIX] Fix communication test
zhuohan123 Dec 27, 2023
89d7cfd
Merge branch 'fix-comm-test-2' into remove-serialization-overhead
zhuohan123 Dec 27, 2023
2b4863a
style
zhuohan123 Dec 27, 2023
3096c56
Fix smaller review comments
zhuohan123 Dec 28, 2023
dc4a4c2
fix
zhuohan123 Dec 28, 2023
f2b8e88
remove unused files
zhuohan123 Dec 28, 2023
83c2735
fix review comments
zhuohan123 Dec 28, 2023
3d3a547
allgather -> gather
zhuohan123 Jan 3, 2024
680c8d9
fix
zhuohan123 Jan 3, 2024
5280a61
fix and revert unnecessary changes
zhuohan123 Jan 3, 2024
03b2734
fix
zhuohan123 Jan 3, 2024
0ca5e07
fix
zhuohan123 Jan 3, 2024
ddb0795
fix review comments
zhuohan123 Jan 3, 2024
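
The commit trail above ("implement driver worker", "broadcast swap info", "Broadcast inputmetadata as well") points at the core idea named in the title: scheduling metadata is pushed from a driver worker to the other workers over an existing torch.distributed process group (e.g. NCCL-backed) instead of being pickled into every Ray remote call. Below is a minimal sketch of that broadcast pattern, assuming the process group is already initialized; the helper name and payload keys are illustrative, not vLLM's actual API.

# Sketch only: broadcast control-plane metadata from the driver rank to all
# worker ranks over an already-initialized torch.distributed process group,
# instead of serializing it into each Ray remote-call argument list.
# `broadcast_control_data` is a hypothetical helper, not part of vLLM.
from typing import Any, Dict, Optional

import torch.distributed as dist


def broadcast_control_data(data: Optional[Dict[str, Any]] = None,
                           src: int = 0) -> Dict[str, Any]:
    """Driver rank passes the payload; the other ranks pass None and receive it."""
    container = [data]
    dist.broadcast_object_list(container, src=src)
    return container[0]


# Driver rank:
#     broadcast_control_data({"blocks_to_swap_in": ..., "blocks_to_copy": ...})
# Worker ranks:
#     payload = broadcast_control_data()

The diffs below show the corresponding plumbing on the engine side: the driver worker receives the full arguments via driver_kwargs, while the Ray workers receive only the method name.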
7 changes: 3 additions & 4 deletions docs/source/models/adding_model.rst
@@ -58,11 +58,10 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
 +    positions: torch.Tensor,
 +    kv_caches: List[KVCache],
 +    input_metadata: InputMetadata,
-+    cache_events: Optional[List[torch.cuda.Event]],
-+) -> SamplerOutput:
++) -> Optional[SamplerOutput]:
 
-3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
-4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
+1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
+2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
 
 .. note::
     Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
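
As a companion to the documentation change above, a hedged sketch of the signature the updated doc asks for: the cache_events parameter is gone and the return type becomes Optional[SamplerOutput] (only the driver worker returns sampler output, per the engine changes further down). The class name and body are placeholders, not code from this PR; KVCache is assumed to be a (key, value) tensor pair, and InputMetadata / SamplerOutput are the vLLM types named in the doc, referenced here only as string annotations.

# Sketch only: target forward signature after the doc update.
from typing import List, Optional, Tuple

import torch

KVCache = Tuple[torch.Tensor, torch.Tensor]  # assumed (key cache, value cache)


class MyModel(torch.nn.Module):  # hypothetical model class

    def forward(
        self,
        input_ids: torch.Tensor,        # flattened token ids (step 1)
        positions: torch.Tensor,        # flattened position ids
        kv_caches: List[KVCache],       # paged KV-cache blocks
        input_metadata: "InputMetadata",
    ) -> Optional["SamplerOutput"]:     # may be None on non-driver workers
        ...  # placeholder body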
2 changes: 0 additions & 2 deletions requirements-rocm.txt
@@ -3,8 +3,6 @@ typing-extensions>=4.8.0
 starlette
 psutil
 ray >= 2.5.1
-pandas # Required for Ray data.
-pyarrow # Required for Ray data.
 sentencepiece # Required for LLaMA tokenizer.
 numpy
 tokenizers>=0.15.0
2 changes: 0 additions & 2 deletions requirements.txt
@@ -1,8 +1,6 @@
 ninja # For faster builds.
 psutil
 ray >= 2.5.1
-pandas # Required for Ray data.
-pyarrow # Required for Ray data.
 sentencepiece # Required for LLaMA tokenizer.
 numpy
 torch == 2.1.2
13 changes: 9 additions & 4 deletions tests/async_engine/test_api_server.py
@@ -8,18 +8,22 @@
 import requests


-def _query_server(prompt: str) -> dict:
+def _query_server(prompt: str, max_tokens: int = 5) -> dict:
     response = requests.post("http://localhost:8000/generate",
                              json={
                                  "prompt": prompt,
-                                 "max_tokens": 100,
+                                 "max_tokens": max_tokens,
                                  "temperature": 0,
                                  "ignore_eos": True
                              })
     response.raise_for_status()
     return response.json()


+def _query_server_long(prompt: str) -> dict:
+    return _query_server(prompt, max_tokens=500)
+
+
 @pytest.fixture
 def api_server():
     script_path = Path(__file__).parent.joinpath(
@@ -68,10 +72,11 @@ def test_api_server(api_server):
         for result in pool.map(_query_server, prompts):
             assert result

     with Pool(32) as pool:
         # Cancel requests
         prompts = ["canceled requests"] * 100
-        pool.map_async(_query_server, prompts)
-        time.sleep(0.001)
+        pool.map_async(_query_server_long, prompts)
+        time.sleep(0.01)
         pool.terminate()
         pool.join()

2 changes: 1 addition & 1 deletion tests/distributed/test_comm_ops.py
@@ -8,7 +8,7 @@
 import torch

 from vllm.config import ParallelConfig
-from vllm.engine.ray_utils import get_open_port
+from vllm.utils import get_open_port
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce,
     tensor_model_parallel_all_gather,
24 changes: 12 additions & 12 deletions tests/kernels/test_cache.py
@@ -46,12 +46,13 @@ def test_copy_blocks(
     src_blocks = random.sample(range(num_blocks), num_mappings)
     remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
     dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
-    block_mapping = {}
+    copy_src = []
+    copy_dst = []
     for i in range(num_mappings):
-        src = src_blocks[i]
-        dst1 = dst_blocks[2 * i]
-        dst2 = dst_blocks[2 * i + 1]
-        block_mapping[src] = [dst1, dst2]
+        copy_src.append(src_blocks[i])
+        copy_dst.append(dst_blocks[2 * i])
+        copy_src.append(src_blocks[i])
+        copy_dst.append(dst_blocks[2 * i + 1])

     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
@@ -63,15 +64,14 @@
     cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

     # Call the copy blocks kernel.
-    cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+    cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst)

     # Run the reference implementation.
-    for src, dsts in block_mapping.items():
-        for dst in dsts:
-            for cloned_key_cache in cloned_key_caches:
-                cloned_key_cache[dst].copy_(cloned_key_cache[src])
-            for cloned_value_cache in cloned_value_caches:
-                cloned_value_cache[dst].copy_(cloned_value_cache[src])
+    for src, dst in zip(copy_src, copy_dst):
+        for cloned_key_cache in cloned_key_caches:
+            cloned_key_cache[dst].copy_(cloned_key_cache[src])
+        for cloned_value_cache in cloned_value_caches:
+            cloned_value_cache[dst].copy_(cloned_value_cache[src])

     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
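
The updated test encodes block copies as two flat lists (copy_src, copy_dst) instead of a {src: [dst, ...]} mapping. A small, self-contained sketch (plain Python, independent of the cache_ops kernel) showing that the two encodings carry the same copy pairs:

# Sketch only: converting the old dict-of-destinations mapping into the
# parallel source/destination lists built by the updated test.
from typing import Dict, List, Tuple


def flatten_block_mapping(
        block_mapping: Dict[int, List[int]]) -> Tuple[List[int], List[int]]:
    # Expand each source block's destination list into parallel src/dst lists.
    copy_src: List[int] = []
    copy_dst: List[int] = []
    for src, dsts in block_mapping.items():
        for dst in dsts:
            copy_src.append(src)
            copy_dst.append(dst)
    return copy_src, copy_dst


# {3: [7, 9]} and ([3, 3], [7, 9]) describe the same two copies:
# block 3 -> block 7 and block 3 -> block 9.
assert flatten_block_mapping({3: [7, 9]}) == ([3, 3], [7, 9])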
5 changes: 3 additions & 2 deletions tests/worker/test_model_runner.py
@@ -33,8 +33,9 @@ def test_prepare_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, _ = model_runner._prepare_prompt(
-        seq_group_metadata_list)
+    input_tokens, input_positions, _, return_prompt_lens = (
+        model_runner._prepare_prompt(seq_group_metadata_list))
+    assert return_prompt_lens == prompt_lens
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens)
     assert input_tokens.shape == (batch_size, max_seq_len)
61 changes: 33 additions & 28 deletions vllm/engine/async_llm_engine.py
@@ -185,45 +185,51 @@ async def step_async(self) -> List[RequestOutput]:
         """
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

-        # Execute the model.
-        output = (await self._run_workers_async(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )) if not scheduler_outputs.is_empty() else []
+        if not scheduler_outputs.is_empty():
+            # Execute the model.
+            all_outputs = await self._run_workers_async(
+                "execute_model",
+                driver_kwargs={
+                    "seq_group_metadata_list": seq_group_metadata_list,
+                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
+                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
+                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
+                })
+
+            # Only the driver worker returns the sampling results.
+            output = all_outputs[0]
+        else:
+            output = []

         return self._process_model_outputs(output, scheduler_outputs)

     async def _run_workers_async(
         self,
         method: str,
         *args,
-        get_all_outputs: bool = False,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
         coros = []
-        for worker in self.workers:
-            if self.parallel_config.worker_use_ray:
-                coros.append(
-                    worker.execute_method.remote(method, *args, **kwargs))
-            else:
-                executor = getattr(worker, method)
-                coros.append(asyncio.get_event_loop().run_in_executor(
-                    None, partial(executor, *args, **kwargs)))

-        all_outputs = await asyncio.gather(*coros)
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs

-        if get_all_outputs:
-            return all_outputs
+        # Run the driver worker asynchronously.
+        driver_executor = getattr(self.driver_worker, method)
+        coros.append(asyncio.get_event_loop().run_in_executor(
+            None, partial(driver_executor, *driver_args, **driver_kwargs)))

-        # Make sure all workers have the same results.
-        output = all_outputs[0]
-        for other_output in all_outputs[1:]:
-            assert output == other_output
-        return output
+        # Run the ray workers asynchronously.
+        for worker in self.workers:
+            coros.append(worker.execute_method.remote(method, *args, **kwargs))
+
+        all_outputs = await asyncio.gather(*coros)
+        return all_outputs


class AsyncLLMEngine:
@@ -488,13 +494,12 @@ def from_engine_args(cls,
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, placement_group = initialize_cluster(
-            parallel_config, engine_args.engine_use_ray)
+        placement_group = initialize_cluster(parallel_config,
+                                             engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(parallel_config.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
-                     distributed_init_method,
                      placement_group,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats,