Use NCCL instead of ray for control-plane communication to remove serialization overhead #2221
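The gist of the change: per-step control-plane data (block swap/copy lists, input metadata) is broadcast from the driver worker to the other workers over a torch.distributed process group, instead of being passed through Ray remote-call arguments that Ray has to serialize on every step. Below is a minimal sketch of that pattern only; the function name and metadata fields are invented for illustration and are not this PR's actual code.

# Sketch of the broadcast pattern (assumptions, not this PR's implementation).
import os

import torch.distributed as dist


def broadcast_control_data(data, src: int = 0):
    """Broadcast plain Python control metadata from rank `src` to every rank.
    Assumes the default process group is already initialized (NCCL for the
    multi-GPU case; gloo is enough for this CPU-only demo)."""
    buf = [data]  # non-source ranks can pass [None]; it is filled in place
    dist.broadcast_object_list(buf, src=src)
    return buf[0]


if __name__ == "__main__":
    # Single-process demo with a gloo group, just to show the call shape.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29555")
    dist.init_process_group("gloo", rank=0, world_size=1)
    step_metadata = {
        "blocks_to_swap_in": [3, 7],
        "blocks_to_swap_out": [],
        "blocks_to_copy": [],
    }
    print(broadcast_control_data(step_metadata))
    dist.destroy_process_group()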

Merged: 35 commits, Jan 3, 2024

Changes shown below are from 7 commits.

Commits (all by zhuohan123):
7265829  small test (Dec 18, 2023)
20274cc  test ray_pg (Dec 19, 2023)
1b73dd7  update ray test (Dec 19, 2023)
0d89354  implement driver worker (Dec 20, 2023)
e0c4c4e  broadcast swap info (Dec 20, 2023)
1baf87b  Broadcast inputmetadata as well (Dec 20, 2023)
c947fa0  fix bugs (Dec 20, 2023)
761584b  fix comments (Dec 25, 2023)
19110fb  remove unused files (Dec 25, 2023)
7b05ec6  fix async llm engine (Dec 26, 2023)
5f90351  fix format (Dec 26, 2023)
6f7ea32  Merge branch 'main' into remove-serialization-overhead (Dec 26, 2023)
966e366  [BUGFIX] Fix API server test (Dec 26, 2023)
fe2c29a  fix and remove print (Dec 26, 2023)
5557cdb  fix test_cache (Dec 26, 2023)
d92b38d  Merge branch 'fix-test-api-server' into remove-serialization-overhead (Dec 26, 2023)
c7f6c21  fix api test (Dec 26, 2023)
332d370  [BUGFIX] Fix the path of test prompts (Dec 26, 2023)
9a8c16f  Merge branch 'fix-test-prompt-path' into remove-serialization-overhead (Dec 26, 2023)
6ea2a42  fix test_model_runner (Dec 26, 2023)
0434a76  Merge branch 'main' into remove-serialization-overhead (Dec 27, 2023)
95bb1d3  Fix async llm engine (Dec 27, 2023)
de4c8d2  [BUGFIX] Fix communication test (Dec 27, 2023)
89d7cfd  Merge branch 'fix-comm-test-2' into remove-serialization-overhead (Dec 27, 2023)
2b4863a  style (Dec 27, 2023)
3096c56  Fix smaller review comments (Dec 28, 2023)
dc4a4c2  fix (Dec 28, 2023)
f2b8e88  remove unused files (Dec 28, 2023)
83c2735  fix review comments (Dec 28, 2023)
3d3a547  allgather -> gather (Jan 3, 2024)
680c8d9  fix (Jan 3, 2024)
5280a61  fix and revert unnecessary changes (Jan 3, 2024)
03b2734  fix (Jan 3, 2024)
0ca5e07  fix (Jan 3, 2024)
ddb0795  fix review comments (Jan 3, 2024)
6 changes: 4 additions & 2 deletions csrc/cache.h
@@ -8,12 +8,14 @@
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
const std::vector<int64_t>& src_block_numbers,
const std::vector<int64_t>& dst_block_numbers);

void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
const std::vector<int64_t>& src_block_numbers,
const std::vector<int64_t>& dst_block_numbers);

void reshape_and_cache(
torch::Tensor& key,
27 changes: 14 additions & 13 deletions csrc/cache_kernels.cu
@@ -12,7 +12,9 @@
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping) {
const std::vector<int64_t>& src_block_numbers,
const std::vector<int64_t>& dst_block_numbers) {
assert(src_block_numbers.size() == dst_block_numbers.size());
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
@@ -35,9 +37,9 @@ void swap_blocks(
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
for (int64_t i = 0; i < src_block_numbers.size(); ++i) {
int64_t src_block_number = src_block_numbers[i];
int64_t dst_block_number = dst_block_numbers[i];
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync(
@@ -85,7 +87,8 @@ __global__ void copy_blocks_kernel(
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
const std::vector<int64_t>& src_block_numbers,
const std::vector<int64_t>& dst_block_numbers) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
@@ -104,12 +107,10 @@
}
// Create block mapping array.
std::vector<int64_t> block_mapping_vec;
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
assert(src_block_numbers.size() == dst_block_numbers.size());
for (int i = 0; i < src_block_numbers.size(); ++i) {
block_mapping_vec.push_back(src_block_numbers[i]);
block_mapping_vec.push_back(dst_block_numbers[i]);
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;
@@ -252,12 +253,12 @@ __global__ void gather_cached_kv_kernel(
for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
const int tgt_key_idx = token_idx * key_stride + i;
const int tgt_value_idx = token_idx * value_stride + i;

const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
const int x_offset = head_offset % x;

const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
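Since both kernels now take two parallel block-number vectors instead of a std::map, the Python-side caller has to flatten its block mapping before invoking the op. A small sketch of that flattening for a copy_blocks-style {src: [dst, ...]} mapping follows; the helper name is invented and this is not the PR's caller code. Parallel integer lists like these are also straightforward to broadcast to the other workers along with the rest of the control-plane data.

from typing import Dict, List, Tuple


def flatten_block_mapping(
    block_mapping: Dict[int, List[int]],
) -> Tuple[List[int], List[int]]:
    # Turn {src_block: [dst_block, ...]} into the parallel
    # (src_block_numbers, dst_block_numbers) lists the new kernels expect.
    src_block_numbers: List[int] = []
    dst_block_numbers: List[int] = []
    for src, dsts in block_mapping.items():
        for dst in dsts:
            src_block_numbers.append(src)
            dst_block_numbers.append(dst)
    return src_block_numbers, dst_block_numbers


if __name__ == "__main__":
    # Copy block 0 into blocks 5 and 6, and block 2 into block 9.
    print(flatten_block_mapping({0: [5, 6], 2: [9]}))  # ([0, 0, 2], [5, 6, 9])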
2 changes: 1 addition & 1 deletion docs/source/models/adding_model.rst
@@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
+ kv_caches: List[KVCache],
+ input_metadata: InputMetadata,
+ cache_events: Optional[List[torch.cuda.Event]],
+) -> SamplerOutput:
+) -> Optional[SamplerOutput]:

3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
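For reference, a hedged skeleton of the rewritten :code:`forward` with the updated :code:`Optional[SamplerOutput]` return type; the type aliases and the module attributes (self.model, self.sampler, self.lm_head) are illustrative stand-ins rather than a specific vLLM model.

from typing import List, Optional

import torch

# Illustrative stand-ins; only the signature and the Optional return value are
# taken from the docs change above.
KVCache = tuple        # (key_cache, value_cache)
InputMetadata = object
SamplerOutput = object


def forward(
    self,
    input_ids: torch.Tensor,       # flattened token ids across all sequences
    positions: torch.Tensor,       # flattened positions of the same length
    kv_caches: List[KVCache],
    input_metadata: InputMetadata,
    cache_events: Optional[List[torch.cuda.Event]],
) -> Optional[SamplerOutput]:
    hidden_states = self.model(input_ids, positions, kv_caches,
                               input_metadata, cache_events)
    # Sampling happens only where sampling metadata is available (the driver
    # worker); other workers return None, hence the Optional return type.
    return self.sampler(self.lm_head.weight, hidden_states, input_metadata)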
17 changes: 17 additions & 0 deletions examples/offline_inference_long.py
@@ -0,0 +1,17 @@
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
"hi" * 90000,
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
llm = LLM(model="mistralai/Mistral-7B-v0.1", max_model_len=160000)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
print(f"Prompt len: {len(output.prompt_token_ids)}, Generated text: {output.outputs[0].text!r}")
65 changes: 65 additions & 0 deletions playground/test_ray_placement_group.py
@@ -0,0 +1,65 @@
import time
import os

# Import placement group APIs.
from ray.util.placement_group import (
placement_group,
placement_group_table,
remove_placement_group,
)
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

# Initialize Ray.
import ray

class NormalActor:
def __init__(self, index):
self.index = index
pass

def log_message(self):
import torch
print("NormalActor", self.index, os.getpid(), torch.cuda.is_available(), ray.get_gpu_ids())

class AllocationActor:
def __init__(self, pg):
self.placement_group = pg
self.a2 = ray.remote(num_cpus=1)(NormalActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=1,
)
).remote(1)
self.a3 = ray.remote(num_gpus=1, num_cpus=0)(NormalActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=2,
)
).remote(2)

def log_message(self):
print("AllocationActor", os.getpid())
ray.get([self.a2.log_message.remote(), self.a3.log_message.remote()])


def main():
# Create a single-node Ray cluster with 2 CPUs and 1 GPU.
ray.init(num_cpus=2, num_gpus=1)

print(ray.cluster_resources())

# Reserve a placement group with three bundles: two 1-CPU bundles and one 1-GPU bundle pinned to the head node.
pg = placement_group([{"CPU": 1}, {"CPU": 1}, {"GPU": 1, "CPU": 0, "node:__internal_head__": 1e-2}])

ray.get(pg.ready())
a1 = ray.remote(num_cpus=1)(AllocationActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=0,
)
).remote(pg)

ray.get(a1.log_message.remote())
print(ray.available_resources())

main()