
Commit 40f0d62

SiyangShao (Shao Siyang FYP PDCL) and lrq619 committed

Dev/llama3 (vllm-project#7)

* llama support
* flash_attention
* sharded
* expand
* fix: remove redundant info
* change main
* llama and opt model supported

---------

Co-authored-by: Shao Siyang FYP PDCL <shaosy@scsehg.cm.cluster>
Co-authored-by: lairuiqi <lrq619@outlook.com>
Co-authored-by: LaiRuiqi <58351056+lrq619@users.noreply.github.com>
1 parent ec6642c commit 40f0d62

File tree

11 files changed: +674 additions, -105 deletions

main.py

Lines changed: 4 additions & 4 deletions
@@ -16,11 +16,12 @@ def main():
         enforce_eager=True,
         # load_format="auto",
         # tensor_parallel_size=2,
+        # liquid_gpu_range = [0,1,2,3],
         liquid_gpu_range = [0,1,2,3],
         liquid_gpu_space = 32,
         liquid_driver_gpu_id = 0,
         liquid_total_num_shards = 4,
-        # gpu_memory_utilization=0.8,
+
     )
     sampling_params = SamplingParams(temperature=0, min_tokens=128, max_tokens=128)
     request_num = 1
@@ -37,8 +38,7 @@ def main():
     llm.do_liquid(liquid_request)
     liquid_request = LiquidRequest(LiquidType.LIQUID_2_1)
     llm.do_liquid(liquid_request)
-    # liquid_request = LiquidRequest(LiquidType.LIQUID_1_2)
-    # llm.do_liquid(liquid_request)
+
 
 
     output = llm.generate(inputs, sampling_params=sampling_params)
@@ -53,4 +53,4 @@ def main():
     main()
     # torch.cuda.memory._dump_snapshot(f"./torch_mem_dump.pickle")
     # torch.cuda.memory._record_memory_history(enabled=None)
-    # print(f"dumped finished!")
+    # print(f"dumped finished!")
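For reference, the sequence exercised above scales the engine across GPUs and back before generating. A minimal sketch of that flow, assuming this fork's liquid extensions (the liquid_* constructor arguments, LiquidRequest, LiquidType, and llm.do_liquid are specific to this repository, not upstream vLLM; the enum naming is read as source -> target shard count):

from vllm import LLM, SamplingParams
from vllm.liquid.request import LiquidRequest, LiquidType

llm = LLM(
    "meta-llama/Meta-Llama-3-8B",
    enforce_eager=True,
    liquid_gpu_range=[0, 1, 2, 3],
    liquid_gpu_space=32,
    liquid_driver_gpu_id=0,
    liquid_total_num_shards=4,
)
llm.do_liquid(LiquidRequest(LiquidType.LIQUID_1_2))  # scale out: 1 -> 2 shards
llm.do_liquid(LiquidRequest(LiquidType.LIQUID_2_1))  # scale back: 2 -> 1 shard
outputs = llm.generate(["what is LLM?"],
                       sampling_params=SamplingParams(temperature=0, max_tokens=128))
print(outputs[0].outputs[0].text)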

vanilla.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+
+from vllm import LLM, SamplingParams
+from vllm.liquid.request import LiquidRequest, LiquidType
+# from vllm import EngineArgs, LLMEngine
+import asyncio
+import torch
+
+import os
+
+model = "meta-llama/Meta-Llama-3-8B"
+# model = "facebook/opt-6.7b"
+# model_path = os.path.join("./models", model)
+
+def main():
+    llm = LLM(
+        model,
+        enforce_eager=True,
+        # load_format="auto",
+        tensor_parallel_size=2,
+        # liquid_gpu_range = [0,1,2,3],
+        # liquid_gpu_space = 32,
+        # liquid_driver_gpu_id = 0,
+        # liquid_total_num_shards = 4,
+        gpu_memory_utilization=0.8,
+    )
+    sampling_params = SamplingParams(temperature=0, min_tokens=128, max_tokens=128)
+    request_num = 1
+    word = "what is LLM?"
+    prompt = word
+    inputs = [prompt for _ in range(request_num)]
+
+    # for i in range(1):
+    #     print(f"i: {i}")
+    #     liquid_request = LiquidRequest(LiquidType.LIQUID_1_2)
+    #     llm.do_liquid(liquid_request)
+    #     # liquid_request = LiquidRequest(LiquidType.LIQUID_2_4)
+    #     # llm.do_liquid(liquid_request)
+    #     # liquid_request = LiquidRequest(LiquidType.LIQUID_4_2)
+    #     # llm.do_liquid(liquid_request)
+    #     liquid_request = LiquidRequest(LiquidType.LIQUID_2_1)
+    #     llm.do_liquid(liquid_request)
+
+    # print("liquid done")
+
+
+    output = llm.generate(inputs, sampling_params=sampling_params)
+    print(f"output: {output[0].outputs[0].text}")
+
+
+
+
+
+if __name__ == '__main__':
+    # torch.cuda.memory._record_memory_history(context="all", stacks="all")
+    main()
+    # torch.cuda.memory._dump_snapshot(f"./torch_mem_dump.pickle")
+    # torch.cuda.memory._record_memory_history(enabled=None)
+    # print(f"dumped finished!")
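vanilla.py is the tensor-parallel baseline: the same prompt and sampling setup as main.py, but with plain tensor_parallel_size/gpu_memory_utilization and every liquid argument commented out. As a rough sketch of how the two configurations differ, the helper below (hypothetical, not part of this commit) builds either engine from the arguments that appear in the two scripts:

from vllm import LLM

def build_engine(model: str, liquid: bool) -> LLM:
    # Hypothetical convenience wrapper, for illustration only.
    if liquid:
        # Arguments used by main.py; specific to this fork's liquid scaling path.
        return LLM(model, enforce_eager=True,
                   liquid_gpu_range=[0, 1, 2, 3], liquid_gpu_space=32,
                   liquid_driver_gpu_id=0, liquid_total_num_shards=4)
    # Plain vLLM tensor parallelism, as in vanilla.py.
    return LLM(model, enforce_eager=True,
               tensor_parallel_size=2, gpu_memory_utilization=0.8)

llm = build_engine("meta-llama/Meta-Llama-3-8B", liquid=False)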

vllm/attention/backends/flash_attn_liquid.py

Lines changed: 2 additions & 2 deletions
@@ -265,15 +265,15 @@ def __init__(
 
     def delete_shard(self, shard_id: int):
         assert shard_id in self.shard_ids
-        self.num_heads -= self.num_kv_heads_per_shard
+        self.num_heads -= self.num_heads_per_shard
         self.num_kv_heads -= self.num_kv_heads_per_shard
 
         index = self.shard_ids.index(shard_id)
         self.shard_ids.pop(index)
 
     def append_shard(self, shard_id: int):
         assert shard_id not in self.shard_ids
-        self.num_heads += self.num_kv_heads_per_shard
+        self.num_heads += self.num_heads_per_shard
         self.num_kv_heads += self.num_kv_heads_per_shard
         self.shard_ids.append(shard_id)
 
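The fix above matters for grouped-query attention models such as Llama 3, where the number of query heads per shard differs from the number of KV heads per shard; the old code shrank and grew num_heads by the KV count. A standalone sketch of the corrected bookkeeping (illustrative class, not the backend itself; the 32/8 head counts are Llama-3-8B's):

class ShardCounter:
    """Tracks per-shard attention-head counts the way delete_shard/append_shard do."""

    def __init__(self, num_heads: int, num_kv_heads: int, total_num_shards: int):
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_heads_per_shard = num_heads // total_num_shards
        self.num_kv_heads_per_shard = num_kv_heads // total_num_shards
        self.shard_ids = list(range(total_num_shards))

    def delete_shard(self, shard_id: int) -> None:
        assert shard_id in self.shard_ids
        self.num_heads -= self.num_heads_per_shard        # query heads
        self.num_kv_heads -= self.num_kv_heads_per_shard  # KV heads
        self.shard_ids.remove(shard_id)

    def append_shard(self, shard_id: int) -> None:
        assert shard_id not in self.shard_ids
        self.num_heads += self.num_heads_per_shard
        self.num_kv_heads += self.num_kv_heads_per_shard
        self.shard_ids.append(shard_id)

counter = ShardCounter(num_heads=32, num_kv_heads=8, total_num_shards=4)
counter.delete_shard(3)
assert (counter.num_heads, counter.num_kv_heads) == (24, 6)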

vllm/engine/llm_engine.py

Lines changed: 5 additions & 3 deletions
@@ -221,7 +221,8 @@ def __init__(
         self.liquid_config = liquid_config
         self.liquid_request_queue: Queue[LiquidRequest] = Queue()
         self.execution_lock: threading.Lock = threading.Lock()
-        self.auto_scaler = AutoScaler(liquid_config=liquid_config)
+        if liquid_config is not None:
+            self.auto_scaler = AutoScaler(liquid_config=liquid_config)
         self.request_output_queue: Queue[RequestOutput] = Queue()
 
         if not self.model_config.skip_tokenizer_init:
@@ -832,8 +833,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """
         # self.model_executor.delete_kv_cache()
         cache_usage = self.get_latest_metrics().gpu_cache_usage
-        # liquid_request = None
-        liquid_request = self.auto_scaler.step(cache_usage)
+        liquid_request = None
+        if self.liquid_config is not None:
+            liquid_request = self.auto_scaler.step(cache_usage)
         if liquid_request is not None:
             self.liquid_request_queue.put(liquid_request)
 
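Both hunks apply the same guard: the AutoScaler only exists when the engine was started with a liquid config (vanilla.py, for example, passes none), so it must neither be constructed nor polled otherwise. A simplified sketch of the pattern, not the engine itself:

from typing import Optional

class AutoScalerSketch:
    def __init__(self, liquid_config: dict):
        self.liquid_config = liquid_config

    def step(self, cache_usage: float):
        # A real scaler would compare cache usage against thresholds and
        # return a scaling request; None means "no action".
        return None

class EngineSketch:
    def __init__(self, liquid_config: Optional[dict]):
        self.liquid_config = liquid_config
        if liquid_config is not None:          # mirrors the __init__ change
            self.auto_scaler = AutoScalerSketch(liquid_config)

    def step(self, cache_usage: float):
        liquid_request = None
        if self.liquid_config is not None:     # mirrors the step() change
            liquid_request = self.auto_scaler.step(cache_usage)
        return liquid_request

assert EngineSketch(liquid_config=None).step(0.9) is None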

vllm/liquid/model_executor/layers/linear.py

Lines changed: 201 additions & 10 deletions
@@ -4,7 +4,7 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
-from vllm.liquid.sharded_parameter import ShardedParameter, QKVShardedParameter
+from vllm.liquid.sharded_parameter import ShardedParameter, QKVShardedParameter,GateUpShardedParameter
 
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -94,15 +94,28 @@ def create_weights(self, layer: torch.nn.Module,
                        shard_dim: int = -1,
                        param_class = ShardedParameter,
                        **extra_weight_attrs):
-        weight = param_class(
-            data=torch.empty(sum(output_partition_sizes),
-                             input_size_per_partition,
-                             dtype=params_dtype),
-            num_shards=len(shard_ids),
-            shard_dim=shard_dim,
-            shard_ids=shard_ids,
-            requires_grad=False,
-        )
+        if param_class == QKVShardedParameter:
+            weight = QKVShardedParameter(
+                data=torch.empty(sum(output_partition_sizes),
+                                 input_size_per_partition,
+                                 dtype=params_dtype),
+                num_shards=len(shard_ids),
+                shard_dim=shard_dim,
+                shard_ids=shard_ids,
+                requires_grad=False,
+                num_heads_ratio=extra_weight_attrs['num_heads_ratio'],
+                num_kv_heads_ratio=extra_weight_attrs['num_kv_heads_ratio'],
+            )
+        else:
+            weight = param_class(
+                data=torch.empty(sum(output_partition_sizes),
+                                 input_size_per_partition,
+                                 dtype=params_dtype),
+                num_shards=len(shard_ids),
+                shard_dim=shard_dim,
+                shard_ids=shard_ids,
+                requires_grad=False,
+            )
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
@@ -276,6 +289,8 @@ def __init__(self,
                  shard_ids: List[int] = [0],
                  total_num_shards: int = 1,
                  param_class = ShardedParameter,
+                 num_heads_ratio: int=1,
+                 num_kv_heads_ratio: int=1,
                  ):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config)
@@ -310,6 +325,8 @@ def __init__(self,
             shard_ids=shard_ids,
             shard_dim=shard_dim,
             param_class=param_class,
+            num_heads_ratio=num_heads_ratio,
+            num_kv_heads_ratio=num_kv_heads_ratio,
         )
         if bias:
             self.bias = param_class(
@@ -446,6 +463,8 @@ def __init__(self,
             shard_ids=shard_ids,
             total_num_shards=total_num_shards,
             param_class=QKVShardedParameter,
+            num_heads_ratio=self.num_heads,
+            num_kv_heads_ratio=self.num_kv_heads,
         )
 
     def weight_loader(self,
@@ -737,3 +756,175 @@ def extra_repr(self) -> str:
         s += f", tp_size={self.tp_size}"
         s += f", reduce_results={self.reduce_results}"
         return s
+
+
+class MergedColumnParallelLinear(ColumnParallelLinear):
+    """Packed linear layers with column parallelism.
+
+    Similar to ColumnParallelLinear, but the weight matrix is concatenated
+    along the output dimension. When the weight matrix is loaded, the
+    different partitions are sharded separately.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_sizes: list of output dimensions of the linear layer.
+        bias: If true, add bias.
+        gather_output: If true, call all-gather on output and make the output
+                       available to all GPUs, otherwise, every GPU will have
+                       its own output.
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. we
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+    """
+
+    def __init__(self,
+                 input_size: int,
+                 output_sizes: List[int],
+                 bias: bool = True,
+                 gather_output: bool = False,
+                 skip_bias_add: bool = False,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 shard_ids: List[int] = [0],
+                 total_num_shards: int = 1,):
+        self.output_sizes = output_sizes
+        # tp_size = get_tensor_model_parallel_world_size()
+        # assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        super().__init__(input_size=input_size,
+                         output_size=sum(output_sizes),
+                         bias=bias,
+                         gather_output=gather_output,
+                         skip_bias_add=skip_bias_add,
+                         params_dtype=params_dtype,
+                         quant_config=quant_config,
+                         shard_ids=shard_ids,
+                         total_num_shards=total_num_shards,
+                         param_class=GateUpShardedParameter,
+                         )
+
+    def weight_loader(self,
+                      param: Parameter,
+                      loaded_weight: torch.Tensor,
+                      loaded_shard_id: Optional[int] = None):
+
+        param_data = param.data
+        output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
+        is_metadata = getattr(param, "is_metadata", False)
+
+        param_shard_splitter = getattr(param, "shard_splitter", None)
+
+        if output_dim is not None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support output_dim != None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+        # If a parameter has defined a shard_splitter to be used for
+        # the weight, it should be applied before the weight is
+        # loaded/copied to the parameter. The shard_splitter applies
+        # logic by using the loaded_shard_id to ensure that the loaded
+        # param is loaded to the correct location
+        # within the parameter defined by the linear method.
+        if loaded_shard_id is None and param_shard_splitter is not None:
+            raise NotImplementedError(
+                "We do not currently support loaded_shard_id == None and "
+                "shard_splitter != None for a parameter. Please open an issue."
+            )
+
+        # Special case for Fp8 scales.
+        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
+                                           None)
+
+        if loaded_shard_id is None:
+            # Loaded weight is already packed.
+            if output_dim is None:
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            current_shard_offset = 0
+            shard_offsets = []
+            for i, output_size in enumerate(self.output_sizes):
+                shard_offsets.append((i, current_shard_offset, output_size))
+                current_shard_offset += output_size
+            packed_dim = getattr(param, "packed_dim", None)
+            for shard_id, shard_offset, shard_size in shard_offsets:
+                # Special case for Quantization.
+                # If quantized, we need to adjust the offset and size to account
+                # for the packing.
+                if packed_dim == output_dim:
+                    shard_size = shard_size // param.pack_factor
+                    shard_offset = shard_offset // param.pack_factor
+                    # Special case for Marlin.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
+                loaded_weight_shard = loaded_weight.narrow(
+                    output_dim, shard_offset, shard_size)
+                self.weight_loader(param, loaded_weight_shard, shard_id)
+            return
+
+        assert loaded_shard_id < len(self.output_sizes)
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        if output_dim is not None:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            # Special case for quantization.
+            # If quantized, we need to adjust the offset and size to account
+            # for the packing.
+            packed_dim = getattr(param, "packed_dim", None)
+            if packed_dim == output_dim:
+                shard_size = shard_size // param.pack_factor
+                shard_offset = shard_offset // param.pack_factor
+                # Special case for Marlin.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
+            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            if use_bitsandbytes:
+                shard_size = loaded_weight.shape[output_dim]
+                shard_offset = loaded_weight.shape[output_dim] * \
+                    loaded_shard_id
+
+            param_data = param_data.narrow(output_dim, shard_offset,
+                                           shard_size)
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+        # Special case for AQLM codebooks.
+        elif is_metadata:
+            # metadata indicates fixed size concatenated along dim 0
+            shard_size = loaded_weight.shape[0]
+            shard_offset = loaded_shard_id * shard_size
+            param_data = param_data.narrow(0, shard_offset, shard_size)
+
+        # If a param_shard_splitter is defined by the LinearMethod, use it.
+        elif param_shard_splitter is not None:
+            logical_widths = getattr(param, "logical_widths", None)
+            param_data, loaded_weight = param_shard_splitter(
+                param_data, loaded_weight, loaded_shard_id, logical_widths)
+
+        # Special case for Fp8 scales.
+        elif fp8_scales_shard_indexer is not None:
+            param_data, loaded_weight = fp8_scales_shard_indexer(
+                param_data, loaded_weight, loaded_shard_id)
+
+        else:
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "MergedColumnParallelLinear, assume the weight is "
+                    "the same for all partitions.")
+
+        if fp8_scales_shard_indexer is None:
+            if len(param_data.shape) == 0:
+                param_data = param_data.reshape(1)
+
+            if len(loaded_weight.shape) == 0:
+                loaded_weight = loaded_weight.reshape(1)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
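MergedColumnParallelLinear packs two column-parallel projections (Llama's gate_proj and up_proj) into one weight, and GateUpShardedParameter is expected to shard each half separately. A hedged sketch of just the packing/slicing arithmetic, in plain tensor code rather than the vLLM classes (the sizes below are made-up toy values):

import torch

in_size, out_size, tp_size, tp_rank = 16, 8, 2, 0
gate = torch.randn(out_size, in_size)
up = torch.randn(out_size, in_size)

# The packed weight stacks both projections along the output dimension.
packed = torch.cat([gate, up], dim=0)
per_rank = out_size // tp_size

# Row slices this rank would own for each half of the packed weight.
gate_shard = packed.narrow(0, tp_rank * per_rank, per_rank)
up_shard = packed.narrow(0, out_size + tp_rank * per_rank, per_rank)
assert torch.equal(gate_shard, gate[tp_rank * per_rank:(tp_rank + 1) * per_rank])
assert torch.equal(up_shard, up[tp_rank * per_rank:(tp_rank + 1) * per_rank])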

vllm/liquid/model_executor/layers/vocab_parallel_embedding.py

Lines changed: 4 additions & 2 deletions
@@ -394,9 +394,11 @@ def __init__(self,
                  bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+                 shard_ids: List[int] = [0],
+                 total_num_shards: int = 1,):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size)
+                         org_num_embeddings, padding_size, shard_ids, total_num_shards)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
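This change threads shard_ids and total_num_shards through to the VocabParallelEmbedding base class so the output head can take part in resharding. A minimal sketch of the forwarding pattern with stand-in classes (the real base class does considerably more, e.g. vocab padding; 128256 is Llama 3's vocabulary size):

from typing import List

class EmbeddingBase:
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 shard_ids: List[int], total_num_shards: int):
        self.shard_ids = list(shard_ids)
        self.total_num_shards = total_num_shards
        # Each shard owns an equal slice of the vocabulary (padding ignored here).
        self.num_embeddings_per_partition = (
            num_embeddings // total_num_shards * len(shard_ids))

class LMHeadSketch(EmbeddingBase):
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 shard_ids: List[int] = [0], total_num_shards: int = 1):
        # Forward the shard bookkeeping to the base class, as the diff now does.
        super().__init__(num_embeddings, embedding_dim,
                         shard_ids, total_num_shards)

head = LMHeadSketch(128256, 4096, shard_ids=[0], total_num_shards=4)
assert head.num_embeddings_per_partition == 128256 // 4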
