
[wip/spmd] Serialization Optimization #6903

Closed
wants to merge 14 commits
working e2e
rkooo567 committed Aug 3, 2024
commit 1e6196bd5bc500186cef8b1a110fdd3042cbbe32
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -98,6 +98,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
39 changes: 20 additions & 19 deletions b.py
@@ -1,5 +1,6 @@
import time
from array import array
from vllm.sequence import SequenceData


def t():
@@ -47,29 +48,29 @@ def __exit__(self, exc_type, exc_val, exc_tb):
encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(dec_hook=dec_hook)

l = [i for i in range(256)]
d = {"1": l}
# l = [i for i in range(256)]
# d = {"1": l}

with Timer("Serialization array"):
# a = array('l')
# a.fromlist(l)
data = encoder.encode(a)
with Timer("Deserialization"):
data = decoder.decode(data)
# with Timer("Serialization array"):
# # a = array('l')
# # a.fromlist(l)
# data = encoder.encode(a)
# with Timer("Deserialization"):
# data = decoder.decode(data)

l = [i for i in range(256)]
a = array('l')
l = [i for i in range(64 * 256)]
a = array('I')
a.fromlist(l)
# a = SequenceData(a)

with Timer("Serialization bigger array"):
# a = array('l')
# a.fromlist(l)
# with Timer("Serialization sequence data"):
# # a = array('l')
# # a.fromlist(l)
# data = encoder.encode(a)
# with Timer("Deserialization"):
# data = decoder.decode(data)

with Timer("Serialization array"):
data = encoder.encode(a)
with Timer("Deserialization"):
data = decoder.decode(data)

# for _ in range(5):
# with Timer("Serialization list"):
# data = encoder.encode(l)
# with Timer("Deserialization"):
# data = decoder.decode(data)
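Pulled out of the scratch-file noise, the round trip this benchmark measures is the following (a minimal, self-contained sketch assuming only msgspec and the standard library; it is not part of the commit):

# --- illustrative sketch, not part of the diff ---
import time
from array import array
from typing import Any, Type

import msgspec


def enc_hook(obj: Any) -> Any:
    # Serialize an array as its raw bytes.
    if isinstance(obj, array):
        return obj.tobytes()
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")


def dec_hook(type: Type, obj: Any) -> Any:
    # Rebuild an unsigned-int array from its raw bytes.
    if type is array:
        deserialized = array("I")
        deserialized.frombytes(obj)
        return deserialized
    return obj


encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(array, dec_hook=dec_hook)

a = array("I", range(64 * 256))

start = time.time()
data = encoder.encode(a)
print(f"Serialization array: {(time.time() - start) * 1e6:.2f} us")

start = time.time()
restored = decoder.decode(data)
print(f"Deserialization: {(time.time() - start) * 1e6:.2f} us")
assert restored == a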
20 changes: 20 additions & 0 deletions c.py
@@ -0,0 +1,20 @@
import time
import numpy as np

class Timer:

def __init__(self, msg):
self.msg = msg

def __enter__(self):
self.start = time.time()
return self # This allows access to the instance in the 'as' part of the context manager

def __exit__(self, exc_type, exc_val, exc_tb):
self.end = time.time()
self.elapsed_us = (self.end - self.start) * 1000 * 1000
print(f"{self.msg=}. Elapsed time: {self.elapsed_us:.2f} us")
l = [i for i in range(4096)]
from array import array
with Timer("converesion"):
arr = array("I", l)
9 changes: 1 addition & 8 deletions tests/prompts/example.txt
@@ -1,8 +1 @@
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
Compare and contrast artificial intelligence with human intelligence in terms of processing information.
Describe the basic components of a neural network and how it can be trained.
Write a short story about a robot that dreams for the first time.
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
1 change: 0 additions & 1 deletion vllm/adapter_commons/request.py
@@ -1,5 +1,4 @@
from abc import abstractmethod
from dataclasses import dataclass


class AdapterRequest:
59 changes: 25 additions & 34 deletions vllm/core/scheduler.py
@@ -14,7 +14,7 @@
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDecode,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)

logger = init_logger(__name__)
@@ -333,9 +333,6 @@ def __init__(
if self.enable_artificial_preemption
else 0)
self.num_cumulative_preemption: int = 0
from collections import defaultdict
self._block_table_cache: Dict[int, Dict[int,
List[int]]] = defaultdict(dict)

@property
def lora_enabled(self) -> bool:
@@ -998,80 +995,74 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {}

is_prompt = seq_group.is_prefill()
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
if is_prompt or not envs.VLLM_USE_RAY_SPMD_WORKER or True:
block_table = self.block_manager.get_block_table(seq)
block_tables[seq_id] = block_table
self._block_table_cache[
seq_group.request_id][seq_id] = block_table
else:
block_table = self.block_manager.get_block_table(seq)
if len(self._block_table_cache[seq_group.request_id]
[seq_id]) < len(block_table):
block_tables[seq_id] = [block_table[-1]]
self._block_table_cache[
seq_group.request_id][seq_id].append(
block_table[-1])
else:
block_tables[seq_id] = []
block_table = self.block_manager.get_block_table(seq)
block_tables[seq_id] = block_table
self.block_manager.access_all_blocks_in_seq(seq, now)

common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))

do_sample = True
is_prompt = seq_group.is_prefill()
# We should send the metadata to workers when the first prefill
# is sent. Subsequent requests could be chunked prefill or decode.
is_first_prefill = False
if is_prompt:
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
num_computed_tokens = seqs[0].data.get_num_computed_tokens()
is_first_prefill = num_computed_tokens == 0
# If not all prompt tokens will have been computed by the next
# iteration, the prefill is chunked and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + seqs[0].data.get_num_computed_tokens() <
if (token_chunk_size + num_computed_tokens <
seqs[0].data.get_len()):
do_sample = False

# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
if is_prompt or not envs.VLLM_USE_RAY_SPMD_WORKER:
# When SPMD mode is enabled, we only send delta data except for
# the first request to reduce serialization cost.
if is_first_prefill or not envs.VLLM_USE_RAY_SPMD_WORKER:
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
do_sample=do_sample,
# pooling_params=seq_group.pooling_params,
pooling_params=seq_group.pooling_params,
token_chunk_size=token_chunk_size,
# lora_request=seq_group.lora_request,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
# state=seq_group.state,
# # `multi_modal_data` will only be present for the 1st comm
# # between engine and worker.
# # the subsequent comms can still use delta, but
# # `multi_modal_data` will be None.
# multi_modal_data=seq_group.multi_modal_data
# if scheduler_outputs.num_prefill_groups > 0 else None,
# prompt_adapter_request=seq_group.prompt_adapter_request,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
else:
# Delta is used only for spmd workers.
seq_data_delta = {}
for id, data in seq_data.items():
seq_data_delta[id] = data.get_delta()

seq_group_metadata = SequenceGroupMetadataDecode(
seq_group_metadata = SequenceGroupMetadataDelta(
seq_data_delta,
seq_group.request_id,
block_tables,
is_prompt,
do_sample=do_sample,
token_chunk_size=token_chunk_size,
computed_block_nums=common_computed_block_nums,
)
seq_group_metadata_list.append(seq_group_metadata)
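
The split above between full and delta metadata is the heart of the optimization: after the first prefill, each decode step appends only a token or two per sequence, so resending the entire seq_data and block tables every step is wasted serialization work. A simplified sketch of the two payload shapes (field names here are illustrative, not the actual vLLM definitions):

# --- illustrative sketch, not part of the diff ---
from typing import Dict, List

import msgspec


class FullMetadata(msgspec.Struct):
    # First prefill only: everything the worker needs to start a request.
    request_id: str
    seq_data: Dict[int, List[int]]      # all token ids per sequence
    block_tables: Dict[int, List[int]]  # full block table per sequence


class DeltaMetadata(msgspec.Struct):
    # Every later step: only what changed, applied by the worker to its
    # cached copy of the full metadata.
    request_id: str
    new_token_ids: Dict[int, List[int]]   # typically one token per sequence
    new_block_ids: Dict[int, List[int]]   # usually empty, or one new block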

34 changes: 17 additions & 17 deletions vllm/executor/ray_gpu_executor.py
@@ -3,13 +3,9 @@
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import time
import msgspec
import pickle
from array import array

from regex import P

import vllm.envs as envs
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
@@ -68,14 +64,14 @@ def _init_executor(self) -> None:
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)

self.forward_dag: Optional["ray.dag.CompiledDAG"] = None

def enc_hook(obj: Any) -> Any:
if isinstance(obj, array):
# Serialize the array as its raw bytes.
return obj.tobytes()

self.encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
self.output_decoder = msgspec.msgpack.Decoder(
Optional[List[SamplerOutput]])

def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
@@ -290,22 +286,26 @@ def execute_model(
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)

s = time.time()
# s = time.time()
# import pickle
# serialized_data = pickle.dumps(execute_model_req)

serialized_data = self.encoder.encode(execute_model_req)
serialized_data = self.input_encoder.encode(execute_model_req)
# # Open a file in binary write mode
import sys
if sys.getsizeof(serialized_data) > 60000:
with open('example.bin', 'wb') as file:
# Write bytes to the file
file.write(serialized_data)
# import sys
# if sys.getsizeof(serialized_data) > 60000:
# with open('example.bin', 'wb') as file:
# # Write bytes to the file
# file.write(serialized_data)

# print(f"SANG-TODO input serialization takes {(time.time() - s) * 1000} ms index: {self.i}")
# print("SANG-TODO input serialization takes "
# f"{(time.time() - s) * 1000} ms index: {self.i}")

outputs = ray.get(self.forward_dag.execute(serialized_data))
output = pickle.loads(outputs[0])
# print(f"SANG-TODO e2e takes {(time.time() - s) * 1000} ms index: {self.i}")
# output = pickle.loads(outputs[0])
output = self.output_decoder.decode(outputs[0])
# print(f"SANG-TODO e2e takes {(time.time() - s) * 1000} "
# f"ms index: {self.i}")
self.i += 1
return output
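
On the return path, the typed msgspec decoder replaces pickle: the worker encodes its result and the driver decodes it against a declared type. A minimal sketch of that handshake (List[int] stands in for the diff's Optional[List[SamplerOutput]]):

# --- illustrative sketch, not part of the diff ---
from typing import List

import msgspec

output_encoder = msgspec.msgpack.Encoder()           # worker side (ray_utils.py)
output_decoder = msgspec.msgpack.Decoder(List[int])  # driver side

wire = output_encoder.encode([1, 2, 3])
assert output_decoder.decode(wire) == [1, 2, 3]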

33 changes: 19 additions & 14 deletions vllm/executor/ray_utils.py
@@ -1,6 +1,4 @@
from typing import List, Optional, Tuple, Type, Any
import time
import pickle
import msgspec
from array import array

@@ -30,12 +28,13 @@ def __init__(self, *args, **kwargs) -> None:

def dec_hook(type: Type, obj: Any) -> Any:
if type is array:
deserialized = array('l')
deserialized = array('I')
deserialized.frombytes(obj)
return deserialized

self.decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=dec_hook)
self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
dec_hook=dec_hook)
self.output_encoder = msgspec.msgpack.Encoder()

def get_node_ip(self) -> str:
return get_ip()
@@ -45,25 +44,31 @@ def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids

def execute_model_spmd(self, execute_model_req: bytes):
def execute_model_spmd(self, serialized_execute_model_req: bytes):
"""Used only when SPMD worker and compiled DAG are both
enabled."""
s = time.time()

execute_model_req: ExecuteModelRequest = self.decoder.decode(
execute_model_req)
# execute_model_req: ExecuteModelRequest = pickle.loads(execute_model_req)
# print(f"SANG-TODO input deserialization takes {(time.time() - s) * 1000} ms index: {self.i}")
# s = time.time()

execute_model_req: ExecuteModelRequest = self.input_decoder.decode(
serialized_execute_model_req)
# import pickle
# execute_model_req: ExecuteModelRequest = (
# pickle.loads(execute_model_req))
# print("SANG-TODO input deserialization takes "
# f"{(time.time() - s) * 1000} ms index: {self.i}")
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
import torch
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True

output = self.worker._execute_model_spmd(execute_model_req)
output = pickle.dumps(output)
# print(f"SANG-TODO worker takes {(time.time() - s) * 1000} ms index: {self.i}")
# output = pickle.dumps(output)
output = self.output_encoder.encode(output)
# print("SANG-TODO worker takes "
# f"{(time.time() - s) * 1000} ms index: {self.i}")
self.i += 1
return output
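
Note the typecode change in dec_hook from 'l' to 'I': token ids fit comfortably in an unsigned 32-bit int, and on a typical 64-bit Linux build array('l') stores 8 bytes per element versus 4 for array('I'), so the switch roughly halves the serialized token payload. A quick check (sizes are platform-dependent):

# --- illustrative check, not part of the diff ---
from array import array

print(array("l").itemsize)  # typically 8 on 64-bit Linux (C signed long)
print(array("I").itemsize)  # typically 4 (C unsigned int)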

3 changes: 2 additions & 1 deletion vllm/inputs/registry.py
@@ -1,3 +1,4 @@
from array import array
import functools
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
@@ -106,7 +107,7 @@ def _default_dummy_data_factory(
# Avoid circular import
from vllm.sequence import SequenceData

dummy_seq_data = SequenceData([0] * seq_len)
dummy_seq_data = SequenceData(array("I", [0] * seq_len))
dummy_multi_modal_data = None

return dummy_seq_data, dummy_multi_modal_data
6 changes: 4 additions & 2 deletions vllm/lora/request.py
@@ -1,12 +1,14 @@
import warnings
from dataclasses import dataclass, field
from typing import Optional

from vllm.adapter_commons.request import AdapterRequest
import msgspec


class LoRARequest(msgspec.Struct, AdapterRequest):
class LoRARequest(msgspec.Struct,
AdapterRequest,
omit_defaults=True,
array_like=True):
"""
Request for a LoRA adapter.

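The LoRARequest change opts into two msgspec.Struct options that shrink the wire format: array_like=True encodes the struct as a positional array instead of a field-name map, and omit_defaults=True drops trailing fields that still hold their defaults. A minimal sketch with a stand-in struct (Point is hypothetical, not the real LoRARequest):

# --- illustrative sketch, not part of the diff ---
import msgspec


class Point(msgspec.Struct, omit_defaults=True, array_like=True):
    x: int
    y: int = 0


compact = msgspec.msgpack.encode(Point(1))               # encodes as [1]
restored = msgspec.msgpack.decode(compact, type=Point)   # defaults restored
assert restored == Point(1, 0)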