[wip/spmd] Serialization Optimization #6903

Closed
rkooo567 wants to merge 14 commits
rkooo567 committed Jul 30, 2024
commit dc7c4459d6d24221b8f227eb5b1441fba67b57d0
58 changes: 49 additions & 9 deletions a.py
@@ -1,5 +1,7 @@
import time
import sys
from array import array
from vllm.sequence import ExecuteModelRequest, SequenceData
from vllm.sequence import ExecuteModelRequest, SequenceData, SequenceDataDelta, SequenceStage
import msgspec

with open('example.bin', 'rb') as file:
@@ -11,19 +13,57 @@ def dec_hook(type, obj):
deserialized = array('l')
deserialized.frombytes(obj)
return deserialized

# decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, dec_hook=dec_hook)


# print(decoder.decode(data))

def enc_hook(obj):
if isinstance(obj, array):
        # serialize the array as its raw byte buffer
return obj.tobytes()

class Timer:
def __init__(self, msg):
self.msg = msg

def __enter__(self):
self.start = time.time()
return self # This allows access to the instance in the 'as' part of the context manager

def __exit__(self, exc_type, exc_val, exc_tb):
self.end = time.time()
self.elapsed_us = (self.end - self.start) * 1000 * 1000
print(f"{self.msg=}. Elapsed time: {self.elapsed_us:.2f} us")

# encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
# decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, dec_hook=dec_hook)

# with Timer("Serialization"):
# serialized = encoder.encode(data)
# print(f"{sys.getsizeof(data)=}")
# with Timer("Deserialization original"):
# decoder.decode(data)
# with Timer("Deserialization original"):
# data = decoder.decode(data)

# with Timer("Serialization, big block tables"):
# data = encoder.encode(data)
# with Timer("Deserialization, big block tables"):
# data = decoder.decode(data)

# for i, metadata in enumerate(data.seq_group_metadata_list):
# for key, value in metadata.block_tables.items():
# metadata.block_tables[key] = [i]

# with Timer("Serialization, small block tables"):
# data = encoder.encode(data)
# with Timer("Deserialization, small block tables"):
# data = decoder.decode(data)

# print(decoder.decode(encoder.encode(data)))

encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(SequenceData, dec_hook=dec_hook)
decoder = msgspec.msgpack.Decoder(SequenceDataDelta, dec_hook=dec_hook)

data = SequenceData([1, 2, 3])
print(decoder.decode(encoder.encode(data)))
data = SequenceDataDelta([i for i in range(2048)], 0, 0, SequenceStage.DECODE)
with Timer("Serialization, big block tables"):
data = encoder.encode(data)
with Timer("Deserialization, big block tables"):
data = decoder.decode(data)
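
For context, the round-trip that a.py exercises relies on msgspec's extension hooks, since msgpack has no native array.array support. A minimal standalone sketch of that pattern (names here are illustrative, not part of this PR's diff):

import msgspec
from array import array

def enc_hook(obj):
    if isinstance(obj, array):
        return obj.tobytes()  # one contiguous byte blob instead of per-element ints
    raise NotImplementedError(f"unsupported type: {type(obj)}")

def dec_hook(typ, obj):
    if typ is array:
        out = array('l')  # typecode must match the one used when encoding
        out.frombytes(obj)
        return out
    raise NotImplementedError(f"unsupported type: {typ}")

encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(array, dec_hook=dec_hook)

tokens = array('l', range(2048))
assert decoder.decode(encoder.encode(tokens)) == tokens

Encoding the buffer wholesale is a single memcpy-style operation on both ends, which is what makes large token-ID payloads cheap to serialize relative to packing each integer individually.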
72 changes: 72 additions & 0 deletions b.py
@@ -0,0 +1,72 @@
import time
from array import array

def t():
l = [i for i in range(256)]
s = time.time()
a = array('l')
a.fromlist(l)
print((time.time() - s) * 1000 * 1000, "us")

t()


import msgspec

def dec_hook(type, obj):
# `type` here is the value of the custom type annotation being decoded.
if type is array:
deserialized = array('l')
deserialized.frombytes(obj)
return deserialized

def enc_hook(obj):
if isinstance(obj, array):
        # serialize the array as its raw byte buffer
return obj.tobytes()

class Timer:
def __init__(self, msg):
self.msg = msg

def __enter__(self):
self.start = time.time()
return self # This allows access to the instance in the 'as' part of the context manager

def __exit__(self, exc_type, exc_val, exc_tb):
self.end = time.time()
self.elapsed_us = (self.end - self.start) * 1000 * 1000
print(f"{self.msg=}. Elapsed time: {self.elapsed_us:.2f} us")

encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(dec_hook=dec_hook)

l = [i for i in range(256)]
d = {"1": l}


a = array('l')
a.fromlist(l)
with Timer("Serialization array"):
    data = encoder.encode(a)
with Timer("Deserialization"):
data = decoder.decode(data)

l = [i for i in range(256)]
a = array('l')
a.fromlist(l)


with Timer("Serialization bigger array"):
# a = array('l')
# a.fromlist(l)
data = encoder.encode(a)
with Timer("Deserialization"):
data = decoder.decode(data)


# for _ in range(5):
# with Timer("Serialization list"):
# data = encoder.encode(l)
# with Timer("Deserialization"):
# data = decoder.decode(data)
1 change: 1 addition & 0 deletions benchmarks/benchmark_throughput.py
@@ -106,6 +106,7 @@ def run_vllm(
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
max_num_seqs=32,
)

# Add the requests to the engine.
Binary file added example.bin
Binary file not shown.
18 changes: 15 additions & 3 deletions vllm/core/scheduler.py
@@ -332,6 +332,8 @@ def __init__(
if self.enable_artificial_preemption
else 0)
self.num_cumulative_preemption: int = 0
from collections import defaultdict
self._block_table_cache: Dict[int, Dict[int, List[int]]] = defaultdict(dict)

@property
def lora_enabled(self) -> bool:
@@ -993,18 +995,29 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {}

is_prompt = seq_group.is_prefill()
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt or not envs.VLLM_USE_RAY_SPMD_WORKER:
block_table = self.block_manager.get_block_table(seq)
block_tables[seq_id] = block_table
self._block_table_cache[seq_group.request_id][seq_id] = block_table
else:
block_table = self.block_manager.get_block_table(seq)
if len(self._block_table_cache[seq_group.request_id][seq_id]) < len(block_table):
block_tables[seq_id] = [block_table[-1]]
self._block_table_cache[seq_group.request_id][seq_id].append(block_table[-1])
else:
block_tables[seq_id] = []
self.block_manager.access_all_blocks_in_seq(seq, now)

common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))

do_sample = True
if seq_group.is_prefill():
if is_prompt:
seqs = seq_group.get_seqs()
# Prefill has only 1 sequence.
assert len(seqs) == 1
@@ -1019,7 +1032,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:

# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt = seq_group.is_prefill()
if is_prompt or not envs.VLLM_USE_RAY_SPMD_WORKER:
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
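
The scheduler change above ships full block tables only on prefill (or when the SPMD worker is disabled) and, on decode steps, only the block IDs appended since the previous step, with _block_table_cache acting as the sender-side mirror of what each worker already holds. A simplified sender-side sketch of that delta scheme (hypothetical names, generalized to allow more than one new block per step):

from collections import defaultdict
from typing import Dict, List

class BlockTableDeltaSender:
    def __init__(self) -> None:
        # request_id -> seq_id -> full block table as last sent
        self._cache: Dict[str, Dict[int, List[int]]] = defaultdict(dict)

    def on_prefill(self, request_id: str, seq_id: int, table: List[int]) -> List[int]:
        self._cache[request_id][seq_id] = list(table)
        return table  # send the whole table once

    def on_decode(self, request_id: str, seq_id: int, table: List[int]) -> List[int]:
        cached = self._cache[request_id][seq_id]
        new_blocks = table[len(cached):]  # [] when no block was allocated
        cached.extend(new_blocks)
        return new_blocks  # send only the tail

sender = BlockTableDeltaSender()
assert sender.on_prefill("req-0", 0, [7, 8]) == [7, 8]
assert sender.on_decode("req-0", 0, [7, 8, 9]) == [9]
assert sender.on_decode("req-0", 0, [7, 8, 9]) == []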
8 changes: 5 additions & 3 deletions vllm/executor/ray_gpu_executor.py
@@ -292,9 +292,11 @@ def execute_model(

serialized_data = self.encoder.encode(execute_model_req)
# # Open a file in binary write mode
# with open('example.bin', 'wb') as file:
# # Write bytes to the file
# file.write(serialized_data)
import sys
if sys.getsizeof(serialized_data) > 60000:
with open('example.bin', 'wb') as file:
# Write bytes to the file
file.write(serialized_data)

# print(f"SANG-TODO input serialization takes {(time.time() - s) * 1000} ms index: {self.i}")

13 changes: 7 additions & 6 deletions vllm/sequence.py
@@ -109,7 +109,7 @@ class SequenceDataDelta(msgspec.Struct, array_like=True, omit_defaults=True):
new_stage: SequenceStage


class SequenceData(msgspec.Struct, array_like=False, omit_defaults=True):
class SequenceData(msgspec.Struct, omit_defaults=True):
"""Data associated with a sequence.

Args:
@@ -125,7 +125,7 @@ class SequenceData(msgspec.Struct, array_like=False, omit_defaults=True):
_prompt_token_ids: array
_output_token_ids: Optional[array] = None

## The below fields should not be passed as an argument ##
### The below fields should not be passed as an argument ###
cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: Optional[Tuple[int, ...]] = None
# The number of tokens that are computed (that run against the model).
@@ -661,7 +661,6 @@ def __repr__(self) -> str:

class SequenceGroupMetadataDecode(msgspec.Struct, tag=True, array_like=True, omit_defaults=True):
"""Delta sequence group metadata."""

seq_data_delta: Dict[int, SequenceDataDelta]
request_id: str
block_tables: Dict[int, List[int]]
@@ -718,7 +717,7 @@ class SequenceGroupMetadata(msgspec.Struct, tag=True, array_like=True, omit_defaults=True):
# prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None

## Stateful fields that are lazily defined. ##
### Stateful fields that are lazily defined. ###
# The number of speculative tokens adopted in this request.
    # None means speculative decoding is not used.
# Zero means speculative decoding is disabled for some reasons.
@@ -746,13 +745,15 @@ def prompt_adapter_num_virtual_tokens(self) -> int:
return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
if self.prompt_adapter_request else 0


def apply_delta(
self, sequence_group_metadata_decode: SequenceGroupMetadataDecode):
for id, delta in sequence_group_metadata_decode.seq_data_delta.items():
self.seq_data[id].apply_delta(delta)
self.request_id = sequence_group_metadata_decode.request_id
self.block_tables = sequence_group_metadata_decode.block_tables
for seq_id, block_table in sequence_group_metadata_decode.block_tables.items():
if len(block_table) > 0:
self.block_tables[seq_id].append(block_table[0])
# self.block_tables = sequence_group_metadata_decode.block_tables
self.token_chunk_size = sequence_group_metadata_decode.token_chunk_size
self.do_sample = sequence_group_metadata_decode.do_sample
self.is_prompt = False
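
apply_delta above is the receiving half of the same scheme: the worker keeps the metadata from the prefill step and folds each decode delta into it, appending the newly allocated block ID(s) rather than replacing the whole table. A minimal sketch of that receive path (illustrative only, not this PR's code):

def apply_block_table_delta(full_table, delta):
    # delta is [] when the step allocated no new block,
    # else the newly appended block ID(s) from the scheduler
    full_table.extend(delta)

worker_table = [7, 8]                       # cached from the prefill metadata
apply_block_table_delta(worker_table, [9])  # decode step allocated block 9
apply_block_table_delta(worker_table, [])   # steady-state decode step
assert worker_table == [7, 8, 9]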
3 changes: 0 additions & 3 deletions vllm/worker/worker_base.py
@@ -282,9 +282,6 @@ def execute_model(
# output is List[SamplerOutput]
return output

# 1. spmd -> general & used by other backend easier
# 2.

def _execute_model_spmd(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]: