Add ServerlessLLM Support #1
Changes from all commits: 1817c84, dccd493, aee2969, a478560, ff2b521, c8fda58, fc6a6e7, 89fb6d2
@@ -0,0 +1,56 @@
""" | ||
Saves each worker's model state dict directly to a checkpoint, which enables a | ||
fast load path for large tensor-parallel models where each worker only needs to | ||
read its own shard rather than the entire checkpoint. | ||
|
||
Example usage: | ||
|
||
python save_sharded_state.py \ | ||
--model /path/to/load \ | ||
--quantization deepspeedfp \ | ||
--tensor-parallel-size 8 \ | ||
--output /path/to/save | ||
|
||
Then, the model can be loaded with | ||
|
||
llm = LLM( | ||
model="/path/to/save", | ||
load_format="sharded_state", | ||
quantization="deepspeedfp", | ||
tensor_parallel_size=8, | ||
) | ||
""" | ||
import argparse | ||
import dataclasses | ||
import os | ||
import shutil | ||
from pathlib import Path | ||
|
||
from vllm import LLM, EngineArgs | ||
|
||
parser = argparse.ArgumentParser() | ||
EngineArgs.add_cli_args(parser) | ||
parser.add_argument("--output", | ||
"-o", | ||
required=True, | ||
type=str, | ||
help="path to output checkpoint") | ||
|
||
if __name__ == "__main__": | ||
args = parser.parse_args() | ||
# main(args) | ||
|
||
llm = LLM( | ||
model=args.output, | ||
load_format="serverless_llm", | ||
# load_format="sharded_state", | ||
gpu_memory_utilization=0.9, | ||
distributed_executor_backend="mp", | ||
max_model_len = 512, | ||
tensor_parallel_size=args.tensor_parallel_size, | ||
# num_gpu_blocks_override=128, | ||
) | ||
|
||
input_text = "Explain thread and process in python." | ||
|
||
print(llm.generate(input_text)) |
@@ -0,0 +1,9 @@
CUDA_VISIBLE_DEVICES=0,1 python save_sllm_state.py \
    --model /mnt/raid0sata1/huggingface/hub/models--facebook--opt-125m/snapshots/27dcfa74d334bc871f3234de431e71c6eeba5dd6 \
    --tensor-parallel-size 4 \
    --output /mnt/raid0nvme1/xly/test_data/vllm/opt-125m

CUDA_VISIBLE_DEVICES=0,1 python load_sllm_state.py \
    --model /home/fuji/.cache/huggingface/hub/models--facebook--opt-1.3b/snapshots/3f5c25d0bc631cb57ac65913f76e22c2dfb61d62 \
    --tensor-parallel-size 2 \
    --output /home/fuji/sllm_models/opt-1.3b
@@ -0,0 +1,92 @@
""" | ||
Saves each worker's model state dict directly to a checkpoint, which enables a | ||
fast load path for large tensor-parallel models where each worker only needs to | ||
read its own shard rather than the entire checkpoint. | ||
|
||
Example usage: | ||
|
||
python save_sharded_state.py \ | ||
--model /path/to/load \ | ||
--quantization deepspeedfp \ | ||
--tensor-parallel-size 8 \ | ||
--output /path/to/save | ||
|
||
Then, the model can be loaded with | ||
|
||
llm = LLM( | ||
model="/path/to/save", | ||
load_format="sharded_state", | ||
quantization="deepspeedfp", | ||
tensor_parallel_size=8, | ||
) | ||
""" | ||
import argparse | ||
import dataclasses | ||
import os | ||
import shutil | ||
from pathlib import Path | ||
|
||
from vllm import LLM, EngineArgs | ||
|
||
parser = argparse.ArgumentParser() | ||
EngineArgs.add_cli_args(parser) | ||
parser.add_argument("--output", | ||
"-o", | ||
required=True, | ||
type=str, | ||
help="path to output checkpoint") | ||
parser.add_argument("--file-pattern", | ||
type=str, | ||
help="string pattern of saved filenames") | ||
parser.add_argument("--max-file-size", | ||
type=str, | ||
default=5 * 1024**3, | ||
help="max size (in bytes) of each safetensors file") | ||
|
||
|
||
def main(args): | ||
engine_args = EngineArgs.from_cli_args(args) | ||
engine_args.distributed_executor_backend = "mp" | ||
engine_args.gpu_memory_utilization = 0.4 | ||
engine_args.max_seq_len_to_capture = 512 | ||
engine_args.max_model_len = 512 | ||
engine_args.max_num_seqs = 1 | ||
engine_args.num_gpu_blocks_override = 128 | ||
if engine_args.enable_lora: | ||
raise ValueError("Saving with enable_lora=True is not supported!") | ||
model_path = engine_args.model | ||
if not Path(model_path).is_dir(): | ||
raise ValueError("model path must be a local directory") | ||
# Create LLM instance from arguments | ||
print(dataclasses.asdict(engine_args)) | ||
llm = LLM(**dataclasses.asdict(engine_args)) | ||
# Prepare output directory | ||
Path(args.output).mkdir(exist_ok=True) | ||
# Dump worker states to output directory | ||
model_executor = llm.llm_engine.model_executor | ||
model_executor.save_serverless_llm_state(path=args.output, | ||
pattern=args.file_pattern, | ||
max_size=args.max_file_size) | ||
# Copy metadata files to output directory | ||
for file in os.listdir(model_path): | ||
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): | ||
if os.path.isdir(os.path.join(model_path, file)): | ||
shutil.copytree(os.path.join(model_path, file), | ||
os.path.join(args.output, file)) | ||
else: | ||
shutil.copy(os.path.join(model_path, file), args.output) | ||
|
||
from vllm.distributed import get_tensor_model_parallel_rank | ||
if __name__ == "__main__": | ||
args = parser.parse_args() | ||
main(args) | ||
|
||
# llm = LLM( | ||
# model=args.output, | ||
# load_format="serverless_llm", | ||
# tensor_parallel_size=2, | ||
# ) | ||
|
||
# input_text = "Hello, world!" | ||
|
||
# print(llm.generate(input_text)) |
@@ -14,6 +14,7 @@
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
import gc

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
                         LoRAConfig, ModelConfig, ParallelConfig,
@@ -418,7 +419,6 @@ def save_model(
            tensorizer_config=tensorizer_config,
        )


class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
@@ -576,6 +576,128 @@ def save_model(
                    os.path.join(path, filename),
                )

class ServerlessLLMLoader(BaseModelLoader):
    # DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        # self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{load_config.model_loader_extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        same_storage_groups = collections.defaultdict(list)
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result
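    # Illustrative note (not part of the diff): when two parameters are tied
    # and share exactly the same storage, such as tied input and output
    # embeddings, only the lexicographically smaller key survives this filter,
    # and a tensor that is a strict sub-view of another tensor is dropped in
    # favor of the covering tensor, so each underlying buffer is saved once.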

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        from serverless_llm_store import load_dict_single_device
        from vllm.distributed import get_tensor_model_parallel_rank

        assert os.path.isdir(model_config.model)

        rank = get_tensor_model_parallel_rank()

        local_model_path = model_config.model
        local_model_path = os.path.join(local_model_path, f"rank_{rank}")

        # model name is everything after "models/"
        model_name = local_model_path.split("models/")[1]
        storage_path = local_model_path.split("models/")[0]
        if storage_path.endswith("/"):
            storage_path = os.path.join(storage_path, "models")
        else:
            storage_path = storage_path + "models"

        with set_default_torch_dtype(model_config.dtype):
            # with torch.device(device_config.device):
            with torch.device("cpu"):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)
                model = model.eval()
            # Release the CPU weights by pointing each parameter at a tiny
            # placeholder CUDA tensor before loading the real shards.
            state_dict = self._filter_subtensors(model.state_dict())
            key_list = list(state_dict.keys())

            for key, param in model.named_parameters(recurse=True):
                if key in key_list:
                    param.data = torch.empty(1, device="cuda")
            gc.collect()

            sllm_state_dict = load_dict_single_device(model_name, storage_path)

            for key, param in model.named_parameters(recurse=True):
                if key in key_list:
                    tensor = sllm_state_dict[key]
                    param.data = tensor
                    state_dict.pop(key)
            if state_dict:
                raise ValueError(
                    f"Missing keys {tuple(state_dict)} in loaded state!")

        return model

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from vllm.distributed import get_tensor_model_parallel_rank
        from serverless_llm_store import save_dict

        rank = get_tensor_model_parallel_rank()
        state_dict = ServerlessLLMLoader._filter_subtensors(model.state_dict())

        # move all tensors to CPU
        for key, tensor in state_dict.items():
            state_dict[key] = tensor.cpu().contiguous()

Review comment: We may need to add ... here, or it may fail with a "failed to open file" error.
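The reviewer's exact snippet is not visible in the captured thread; one plausible reading, offered here as an assumption rather than as code from this PR, is that the per-rank output directory should be created before save_dict writes into it:

# Hypothetical sketch of the reviewer's point, not part of the diff: make sure
# the per-rank directory exists before save_dict opens files inside it.
rank_path = os.path.join(path, f"rank_{rank}")
os.makedirs(rank_path, exist_ok=True)
save_dict(state_dict, rank_path)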
        save_dict(state_dict, os.path.join(path, f"rank_{rank}"))
Review comment: I didn't find the save_dict function under serverless_llm_store; where can I build the latest version with ...
Reply: Use the latest xly/fix-docker-build from serverlessllm.

class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization."""
@@ -826,6 +948,9 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:

    if load_config.load_format == LoadFormat.SHARDED_STATE:
        return ShardedStateLoader(load_config)

    if load_config.load_format == LoadFormat.SERVERLESS_LLM:
        return ServerlessLLMLoader(load_config)

    if load_config.load_format == LoadFormat.BITSANDBYTES:
        return BitsAndBytesModelLoader(load_config)

Review comment: Will this create another copy of the model parameters?
Reply: This is copied exactly from the vllm implementation; we inherit the save behaviour.
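As a side note on the exchange above, here is a small PyTorch illustration (not part of the diff, and it makes no claim about vllm internals beyond standard tensor semantics): calling .cpu() on a CUDA tensor allocates a separate host buffer, so during save_model the GPU weights and their CPU copies briefly coexist, while a tensor that already lives on the CPU and is contiguous is returned unchanged.

import torch

# Illustration only: .cpu() on a CUDA tensor allocates a new host-side copy,
# so the GPU weights and the CPU copies coexist until one side is released.
if torch.cuda.is_available():
    gpu_tensor = torch.randn(4, 4, device="cuda")
    cpu_copy = gpu_tensor.cpu().contiguous()
    print(gpu_tensor.device, cpu_copy.device)  # cuda:0 cpu

# A contiguous tensor that is already on the CPU is returned as-is by
# .cpu().contiguous(), so no additional copy is made in that case.
host_tensor = torch.randn(4, 4)
print(host_tensor.cpu().contiguous() is host_tensor)  # True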