
Commit

Use NCCL instead of ray for control-plane communication to remove serialization overhead (vllm-project#2221)
zhuohan123 authored Jan 3, 2024
1 parent 1066cbd commit fd4ea8e
Showing 34 changed files with 519 additions and 257 deletions.
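For context on the title: the control-plane data here is the per-step scheduling metadata (sequence groups, blocks to swap in/out and copy) that previously had to be serialized and shipped to every worker through Ray on each engine step. The pattern this commit moves toward is broadcasting that metadata from the driver worker over a torch.distributed group so the group's backend (NCCL) carries it instead. A minimal, illustrative sketch of that broadcast pattern, assuming MASTER_ADDR/MASTER_PORT are set and one process per GPU; the helper names below are hypothetical, not taken from this PR:

# Illustrative sketch only: broadcast control-plane metadata from the driver
# (rank 0) to all workers over torch.distributed instead of passing it as
# serialized Ray task arguments. Helper names are hypothetical.
import torch
import torch.distributed as dist


def init_control_plane(rank: int, world_size: int) -> None:
    torch.cuda.set_device(rank)
    # NCCL backend; broadcast_object_list pickles Python objects into a
    # byte tensor and broadcasts that tensor under the hood.
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


def broadcast_control_data(obj=None, src: int = 0):
    # On the driver, `obj` is the scheduling metadata; on other workers it is
    # None and gets filled in by the broadcast.
    container = [obj]
    dist.broadcast_object_list(container, src=src)
    return container[0]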
7 changes: 3 additions & 4 deletions docs/source/models/adding_model.rst
@@ -58,11 +58,10 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
+ positions: torch.Tensor,
+ kv_caches: List[KVCache],
+ input_metadata: InputMetadata,
-+ cache_events: Optional[List[torch.cuda.Event]],
-+) -> SamplerOutput:
++) -> Optional[SamplerOutput]:
-3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
-4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
+1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
+2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.

.. note::
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
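For reference, the forward signature the docs describe after this change drops cache_events and returns Optional[SamplerOutput], which fits the engine change later in this commit where only the driver worker returns sampling results. A sketch of the resulting skeleton; MyModelForCausalLM is a hypothetical example, the vLLM-internal types are referenced by name only, and the body is a placeholder:

# Sketch of the documented signature after this change; the class is a
# hypothetical example and the body is a placeholder.
from typing import List, Optional

import torch
import torch.nn as nn


class MyModelForCausalLM(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List["KVCache"],
        input_metadata: "InputMetadata",
    ) -> Optional["SamplerOutput"]:
        # Run the decoder layers with PagedAttention and return the sampler
        # output (None on workers that do not perform sampling).
        ...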
2 changes: 0 additions & 2 deletions requirements-rocm.txt
@@ -3,8 +3,6 @@ typing-extensions>=4.8.0
starlette
psutil
ray >= 2.5.1
-pandas # Required for Ray data.
-pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
numpy
tokenizers>=0.15.0
2 changes: 0 additions & 2 deletions requirements.txt
@@ -1,8 +1,6 @@
ninja # For faster builds.
psutil
ray >= 2.5.1
-pandas # Required for Ray data.
-pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
numpy
torch == 2.1.2
13 changes: 9 additions & 4 deletions tests/async_engine/test_api_server.py
@@ -8,18 +8,22 @@
import requests


-def _query_server(prompt: str) -> dict:
+def _query_server(prompt: str, max_tokens: int = 5) -> dict:
response = requests.post("http://localhost:8000/generate",
json={
"prompt": prompt,
"max_tokens": 100,
"max_tokens": max_tokens,
"temperature": 0,
"ignore_eos": True
})
response.raise_for_status()
return response.json()


+def _query_server_long(prompt: str) -> dict:
+return _query_server(prompt, max_tokens=500)


@pytest.fixture
def api_server():
script_path = Path(__file__).parent.joinpath(
@@ -68,10 +72,11 @@ def test_api_server(api_server):
for result in pool.map(_query_server, prompts):
assert result

with Pool(32) as pool:
# Cancel requests
prompts = ["canceled requests"] * 100
-pool.map_async(_query_server, prompts)
-time.sleep(0.001)
+pool.map_async(_query_server_long, prompts)
+time.sleep(0.01)
pool.terminate()
pool.join()

24 changes: 12 additions & 12 deletions tests/kernels/test_cache.py
@@ -49,12 +49,13 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
-block_mapping = {}
+copy_src = []
+copy_dst = []
for i in range(num_mappings):
-src = src_blocks[i]
-dst1 = dst_blocks[2 * i]
-dst2 = dst_blocks[2 * i + 1]
-block_mapping[src] = [dst1, dst2]
+copy_src.append(src_blocks[i])
+copy_dst.append(dst_blocks[2 * i])
+copy_src.append(src_blocks[i])
+copy_dst.append(dst_blocks[2 * i + 1])

# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
@@ -66,15 +67,14 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

# Call the copy blocks kernel.
-cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
+cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst)

# Run the reference implementation.
-for src, dsts in block_mapping.items():
-for dst in dsts:
-for cloned_key_cache in cloned_key_caches:
-cloned_key_cache[dst].copy_(cloned_key_cache[src])
-for cloned_value_cache in cloned_value_caches:
-cloned_value_cache[dst].copy_(cloned_value_cache[src])
+for src, dst in zip(copy_src, copy_dst):
+for cloned_key_cache in cloned_key_caches:
+cloned_key_cache[dst].copy_(cloned_key_cache[src])
+for cloned_value_cache in cloned_value_caches:
+cloned_value_cache[dst].copy_(cloned_value_cache[src])

# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
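As the updated test shows, copy_blocks now takes the block mapping as two flat, paired lists (copy_src[i] is copied to copy_dst[i]) rather than a {src: [dsts]} dict. A small self-contained sketch of the equivalence, using made-up block numbers:

# Flattening a {src_block: [dst_blocks]} mapping into the paired src/dst
# lists that the updated copy_blocks interface expects (example data only).
block_mapping = {0: [3, 4], 1: [5, 6]}

copy_src, copy_dst = [], []
for src, dsts in block_mapping.items():
    for dst in dsts:
        copy_src.append(src)
        copy_dst.append(dst)

assert copy_src == [0, 0, 1, 1]
assert copy_dst == [3, 4, 5, 6]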
5 changes: 3 additions & 2 deletions tests/worker/test_model_runner.py
@@ -33,8 +33,9 @@ def test_prepare_prompt():
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
-input_tokens, input_positions, _ = model_runner._prepare_prompt(
-seq_group_metadata_list)
+input_tokens, input_positions, _, return_prompt_lens = (
+model_runner._prepare_prompt(seq_group_metadata_list))
+assert return_prompt_lens == prompt_lens
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
assert input_tokens.shape == (batch_size, max_seq_len)
61 changes: 33 additions & 28 deletions vllm/engine/async_llm_engine.py
@@ -185,45 +185,51 @@ async def step_async(self) -> List[RequestOutput]:
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

-# Execute the model.
-output = (await self._run_workers_async(
-"execute_model",
-seq_group_metadata_list=seq_group_metadata_list,
-blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-blocks_to_copy=scheduler_outputs.blocks_to_copy,
-)) if not scheduler_outputs.is_empty() else []
+if not scheduler_outputs.is_empty():
+# Execute the model.
+all_outputs = await self._run_workers_async(
+"execute_model",
+driver_kwargs={
+"seq_group_metadata_list": seq_group_metadata_list,
+"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
+"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
+"blocks_to_copy": scheduler_outputs.blocks_to_copy,
+})

+# Only the driver worker returns the sampling results.
+output = all_outputs[0]
+else:
+output = []

return self._process_model_outputs(output, scheduler_outputs)

async def _run_workers_async(
self,
method: str,
*args,
-get_all_outputs: bool = False,
+driver_args: Optional[List[Any]] = None,
+driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []
-for worker in self.workers:
-if self.parallel_config.worker_use_ray:
-coros.append(
-worker.execute_method.remote(method, *args, **kwargs))
-else:
-executor = getattr(worker, method)
-coros.append(asyncio.get_event_loop().run_in_executor(
-None, partial(executor, *args, **kwargs)))

-all_outputs = await asyncio.gather(*coros)
+if driver_args is None:
+driver_args = args
+if driver_kwargs is None:
+driver_kwargs = kwargs

-if get_all_outputs:
-return all_outputs
+# Run the driver worker asynchronously.
+driver_executor = getattr(self.driver_worker, method)
+coros.append(asyncio.get_event_loop().run_in_executor(
+None, partial(driver_executor, *driver_args, **driver_kwargs)))

-# Make sure all workers have the same results.
-output = all_outputs[0]
-for other_output in all_outputs[1:]:
-assert output == other_output
-return output
+# Run the ray workers asynchronously.
+for worker in self.workers:
+coros.append(worker.execute_method.remote(method, *args, **kwargs))

+all_outputs = await asyncio.gather(*coros)
+return all_outputs


class AsyncLLMEngine:
@@ -488,13 +494,12 @@ def from_engine_args(cls,
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
-distributed_init_method, placement_group = initialize_cluster(
-parallel_config, engine_args.engine_use_ray)
+placement_group = initialize_cluster(parallel_config,
+engine_args.engine_use_ray)
# Create the async LLM engine.
engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
-distributed_init_method,
placement_group,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
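The new shape of _run_workers_async in this diff: the driver worker runs in-process (in an executor thread), the remaining Ray workers run remotely, and only the driver's entry in the gathered outputs carries the sampled results. A condensed sketch of that control flow, with hypothetical stand-ins for the worker objects (not vLLM's actual classes):

# Condensed sketch of the driver-first dispatch pattern shown in this diff;
# driver_worker and ray_workers are hypothetical stand-ins.
import asyncio
from functools import partial
from typing import Any, Dict, List, Optional


async def run_workers_async(
    driver_worker,   # local object with regular methods
    ray_workers,     # remote handles whose execute_method.remote(...) is awaitable
    method: str,
    *args,
    driver_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> List[Any]:
    driver_kwargs = kwargs if driver_kwargs is None else driver_kwargs
    coros = []

    # The driver runs in-process, so its inputs never pass through Ray.
    driver_fn = getattr(driver_worker, method)
    coros.append(asyncio.get_event_loop().run_in_executor(
        None, partial(driver_fn, *args, **driver_kwargs)))

    # Non-driver workers are invoked through Ray; per the commit title they
    # receive the bulky control-plane data over NCCL rather than as arguments.
    for worker in ray_workers:
        coros.append(worker.execute_method.remote(method, *args, **kwargs))

    # Index 0 is the driver's output, which is the one the engine consumes.
    return await asyncio.gather(*coros)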