
Commit

[Frontend] [Core] Support for sharded tensorized models (vllm-project#4990)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Sanger Steel <sangersteel@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
3 people authored and jimpang committed Jun 27, 2024
1 parent 99c2f33 commit 30adfe7
Showing 6 changed files with 261 additions and 107 deletions.
125 changes: 60 additions & 65 deletions examples/tensorize_vllm_model.py
@@ -3,18 +3,12 @@
import json
import os
import uuid
from functools import partial

from tensorizer import stream_io

from vllm import LLM
from vllm.distributed import (init_distributed_environment,
initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig,
serialize_vllm_model)
tensorize_vllm_model)

# yapf conflicts with isort for this docstring
# yapf: disable
@@ -61,6 +55,12 @@
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
To support distributed tensor-parallel models, each model shard will be
serialized to a separate file. The tensorizer_uri is then specified as a string
template with a format specifier such as '%03d' that will be rendered with the
shard's rank. Sharded models serialized with this script will be named as
model-rank-%03d.tensors
For more information on the available arguments for serializing, run
`python -m examples.tensorize_vllm_model serialize --help`.
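A minimal sketch of how a rank-templated tensorizer_uri is rendered, assuming a hypothetical bucket path and a tensor_parallel_size of 2:

# Illustrative only: each tensor-parallel rank gets its own file.
template = "s3://my-bucket/vllm/facebook/opt-125m/abc123/model-rank-%03d.tensors"
for rank in range(2):
    print(template % rank)
# .../model-rank-000.tensors
# .../model-rank-001.tensors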
@@ -168,77 +168,72 @@ def parse_args():
def deserialize():
llm = LLM(model=args.model,
load_format="tensorizer",
tensor_parallel_size=args.tensor_parallel_size,
model_loader_extra_config=tensorizer_config
)
return llm


if __name__ == '__main__':
args = parse_args()

args = parse_args()

s3_access_key_id = (getattr(args, 's3_access_key_id', None)
or os.environ.get("S3_ACCESS_KEY_ID", None))
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
or os.environ.get("S3_SECRET_ACCESS_KEY", None))
s3_endpoint = (getattr(args, 's3_endpoint', None)
or os.environ.get("S3_ENDPOINT_URL", None))

credentials = {
"s3_access_key_id": s3_access_key_id,
"s3_secret_access_key": s3_secret_access_key,
"s3_endpoint": s3_endpoint
}
s3_access_key_id = (getattr(args, 's3_access_key_id', None)
or os.environ.get("S3_ACCESS_KEY_ID", None))
s3_secret_access_key = (getattr(args, 's3_secret_access_key', None)
or os.environ.get("S3_SECRET_ACCESS_KEY", None))
s3_endpoint = (getattr(args, 's3_endpoint', None)
or os.environ.get("S3_ENDPOINT_URL", None))

_read_stream, _write_stream = (partial(
stream_io.open_stream,
mode=mode,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))
credentials = {
"s3_access_key_id": s3_access_key_id,
"s3_secret_access_key": s3_secret_access_key,
"s3_endpoint": s3_endpoint
}

model_ref = args.model
model_ref = args.model

model_name = model_ref.split("/")[1]
model_name = model_ref.split("/")[1]

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"
keyfile = args.keyfile if args.keyfile else None

init_distributed_environment(world_size=1, rank=0, local_rank=0)
initialize_model_parallel()
if args.model_loader_extra_config:
config = json.loads(args.model_loader_extra_config)
tensorizer_args = \
TensorizerConfig(**config)._construct_tensorizer_args()
tensorizer_args.tensorizer_uri = args.path_to_tensors
else:
tensorizer_args = None

keyfile = args.keyfile if args.keyfile else None
if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}

engine_args = EngineArgs.from_cli_args(
argparse.Namespace(**eng_args_dict)
)

if args.model_loader_extra_config:
config = json.loads(args.model_loader_extra_config)
tensorizer_args = TensorizerConfig(**config)._construct_tensorizer_args()
tensorizer_args.tensorizer_uri = args.path_to_tensors
else:
tensorizer_args = None

if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}

engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)
input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
if engine_args.tensor_parallel_size > 1:
model_path = f"{base_path}/model-rank-%03d.tensors"
else:
model_path = f"{base_path}/model.tensors"

input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
model_path = f"{base_path}/model.tensors"
tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
**credentials)
serialize_vllm_model(engine, tensorizer_config, keyfile)
elif args.command == "deserialize":
if not tensorizer_args:
tensorizer_config = TensorizerConfig(
tensorizer_uri=args.path_to_tensors,
encryption_keyfile = keyfile,
**credentials
)
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")
tensorizer_uri=model_path,
encryption_keyfile=keyfile,
**credentials)

tensorize_vllm_model(engine_args, tensorizer_config)

elif args.command == "deserialize":
if not tensorizer_args:
tensorizer_config = TensorizerConfig(
tensorizer_uri=args.path_to_tensors,
encryption_keyfile = keyfile,
**credentials
)
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")
99 changes: 90 additions & 9 deletions tests/tensorizer_loader/test_tensorizer.py
@@ -1,21 +1,27 @@
import json
import os
import pathlib
import subprocess
from unittest.mock import MagicMock, patch

import openai
import pytest
import ray
import torch
from tensorizer import EncryptionParams

from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
# yapf: disable
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
TensorSerializer,
is_vllm_tensorized,
load_with_tensorizer,
open_stream,
serialize_vllm_model)
serialize_vllm_model,
tensorize_vllm_model)

from ..conftest import VllmRunner, cleanup
from ..utils import ServerRunner

# yapf conflicts with isort for this docstring
@@ -42,6 +48,20 @@ def is_curl_installed():
except (subprocess.CalledProcessError, FileNotFoundError):
return False

def get_torch_model(vllm_runner: VllmRunner):
return vllm_runner \
.model \
.llm_engine \
.model_executor \
.driver_worker \
.model_runner \
.model

def write_keyfile(keyfile_path: str):
encryption_params = EncryptionParams.random()
pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True)
with open(keyfile_path, 'wb') as f:
f.write(encryption_params.key)

@pytest.fixture(autouse=True)
def tensorizer_config():
@@ -88,12 +108,17 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
write_keyfile(key_path)

outputs = vllm_model.generate(prompts, sampling_params)

config_for_serializing = TensorizerConfig(tensorizer_uri=model_path)
serialize_vllm_model(vllm_model.model.llm_engine,
config_for_serializing,
encryption_key_path=key_path)
config_for_serializing = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path
)
serialize_vllm_model(get_torch_model(vllm_model),
config_for_serializing)


config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path)
@@ -145,7 +170,7 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")

serialize_vllm_model(vllm_model.model.llm_engine,
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))

with vllm_runner(
@@ -180,7 +205,7 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")

serialize_vllm_model(vllm_model.model.llm_engine,
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))

model_loader_extra_config = {
@@ -224,7 +249,9 @@ def test_raise_value_error_on_invalid_load_format(vllm_runner):
model_loader_extra_config=TensorizerConfig(tensorizer_uri="test"))


def test_tensorizer_with_tp(vllm_runner):
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs")
def test_tensorizer_with_tp_path_without_template(vllm_runner):
with pytest.raises(ValueError):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
@@ -238,8 +265,62 @@ def test_tensorizer_with_tp(vllm_runner):
s3_endpoint="object.ord1.coreweave.com",
),
tensor_parallel_size=2,
disable_custom_all_reduce=True,
)

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Requires 2 GPUs")
def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
tmp_path):
model_ref = "EleutherAI/pythia-1.4b"
# record outputs from un-sharded un-tensorized model
base_model = vllm_runner(
model_ref,
disable_custom_all_reduce=True,
enforce_eager=True,
)
outputs = base_model.generate(prompts, sampling_params)

base_model.model.llm_engine.model_executor.shutdown()
del base_model
cleanup()
ray.shutdown()

# load model with two shards and serialize with encryption
model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
key_path = tmp_path / (model_ref + ".key")

tensorizer_config = TensorizerConfig(
tensorizer_uri=model_path,
encryption_keyfile=key_path,
)

tensorize_vllm_model(
engine_args=EngineArgs(
model=model_ref,
tensor_parallel_size=2,
disable_custom_all_reduce=True,
enforce_eager=True,
),
tensorizer_config=tensorizer_config,
)
assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
cleanup()
ray.shutdown()

loaded_vllm_model = vllm_runner(
model_ref,
tensor_parallel_size=2,
load_format="tensorizer",
disable_custom_all_reduce=True,
enforce_eager=True,
model_loader_extra_config=tensorizer_config)

deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)

assert outputs == deserialized_outputs


def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
model_ref = "facebook/opt-125m"
@@ -248,7 +329,7 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):

with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params)
serialize_vllm_model(vllm_model.model.llm_engine, config)
serialize_vllm_model(get_torch_model(vllm_model), config)

assert is_vllm_tensorized(config)

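The new tensor-parallel test above also exercises encryption for sharded serialization; a condensed, hedged sketch of that flow (key location, output path, and model are illustrative) could look like:

# Sketch: generate an encryption key, then serialize two encrypted shards.
import pathlib

from tensorizer import EncryptionParams
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                         tensorize_vllm_model)

key_path = "/tmp/pythia.key"  # assumed location
pathlib.Path(key_path).write_bytes(EncryptionParams.random().key)

tensorize_vllm_model(
    engine_args=EngineArgs(model="EleutherAI/pythia-1.4b",
                           tensor_parallel_size=2,
                           enforce_eager=True),
    tensorizer_config=TensorizerConfig(
        tensorizer_uri="/tmp/pythia-%02d.tensors",  # one file per rank
        encryption_keyfile=key_path))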
18 changes: 17 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -24,7 +24,7 @@
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
tensorizer_weights_iterator)
serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
@@ -392,6 +392,12 @@ def load_model(self, *, model_config: ModelConfig,
cache_config: CacheConfig) -> nn.Module:
self._verify_config(model_config, parallel_config)

if parallel_config.tensor_parallel_size > 1:
from vllm.distributed import get_tensor_model_parallel_rank
self.tensorizer_config.tensorizer_uri = \
self.tensorizer_config.tensorizer_uri \
% get_tensor_model_parallel_rank()

if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config,
Expand All @@ -402,6 +408,16 @@ def load_model(self, *, model_config: ModelConfig,
vision_language_config,
cache_config)

@staticmethod
def save_model(
model: torch.nn.Module,
tensorizer_config: TensorizerConfig,
) -> None:
serialize_vllm_model(
model=model,
tensorizer_config=tensorizer_config,
)


class ShardedStateLoader(BaseModelLoader):
"""
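Loading the shards back goes through the load_model path shown above: with load_format="tensorizer", each rank substitutes its own tensor-parallel rank into the URI template before deserializing. A hedged sketch of the corresponding user-facing call (model and path are illustrative):

# Sketch: deserialize a sharded tensorized model across two ranks.
from vllm import LLM
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

llm = LLM(
    model="EleutherAI/pythia-1.4b",
    load_format="tensorizer",
    tensor_parallel_size=2,
    model_loader_extra_config=TensorizerConfig(
        tensorizer_uri="/tmp/pythia-%02d.tensors"))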