[Frontend] [Core] feat: Add model loading using tensorizer
#3476
Merged: ywang96 merged 102 commits into vllm-project:main from coreweave:sangstar/integrate-tensorizer on Apr 14, 2024.

Commits (changes from all 102 commits):
All commits authored by sangstar:

dfe2f2f  feat: Support loading model tensors using `tensorizer`
097f297  fix: Remove unnecessary files
24e8657  fix(vllm-tensorizer): Allow providing S3 credentials
6192ff3  fix: Fix passing S3 auth vars through stream
fbc847b  fix: Disallowing `plaid_mode = False` and updating `tensorizer` version
f4d57d8  refactor: Retire use of `download_dir` as `TensorizerArgs` param
cf42149  fix: Remove `store_true` action for `--tensorizer-uri`
c1839f4  refactor: No 2x copying for `tensorizer` (WIP)
b28b26e  chore: Omit commandeering weight loaders for merging layers (WIP)
fad72a4  feat: Re-add deserializing vLLM models
8d421b4  chore: Harmonize CPU and GPU deserializing
8225c32  perf: Add `force_http=True` for faster loading speeds
f7c9cc7  chore: Reformat code with `format.sh`, cleanup debugging code
44b05ba  chore: Fix formatting, some misc. changes
17977b0  fix: Correct logging for loading tensorizer with cpu
68f2a51  chore: Implement changes from feedback
0c72c2c  fix: Correctly instantiate vLLM-formatted models
af10594  chore: Reformat and delete deprecated comment from `.ipynb`
550983a  perf: Allow passing of deserializer args from `TensorizerArgs`
6273266  style: Reformat with new formatting changes
f6a695b  Run yapf and ruff
f30f4e0  fix: Fix incorrect `TensorizerArgs` import in `config.py`
c539880  perf: Multiple misc. improvements from code review
1632381  pref: More misc. fixes to complete initial code review
4085cb5  fix: Remove `print(tensorizer_args)`
81a752a  Run yapf and ruff
aa8d8b4  fix: Add specific category for warnings with `PerformanceWarning`
d8e71df  chore: Multiple fixes from final code review
5132dd7  fix: Add `s3_endpoint` as attr for `TensorizerArgs`
7dd43f5  chore: Remove `filter_func` from CLI args, some doc fixes
ad68ff5  chore: Allow env var or CLI arg specification for S3 credentials
f965730  fix: Disallow using `force_http`
2605a33  chore: Remove unnecessary print statement in example script
71c2cb0  Run yapf and ruff
35b29e8  Run yapf and ruff
117feec  Run yapf and ruff
6192e9d  docs: Update `tensorizer` as a `--load-format` in `engine_args.rst`
407b32e  fix: Restore `tensorizer_args` as instance attr to `EngineArgs`
88e209d  Run yapf and ruff
6e23dcd  chore: Move testing out of own test folder
05c0bbe  fix: Add `tensorizer >= 2.8.1` to `requirements-rocm.txt` for CI
af11a53  fix: Add version of `tensorizer` that will pass testing suite
8ece4f8  chore: Add notice that `requirements-dev` dep can be removed `>2.8.1`
d4a46a5  fix: Resolve double `HfFileSystem` import
12b1f12  style: Run `isort`
445ab28  Run yapf and ruff
6c286ed  fix: Add `tensorizer` to mock imports
37348f9  perf: Add newest `tensorizer` version that will not init CUDA
82da7a5  fix: Adjust `tensorizer` version for `requirements-dev.txt`
310dd68  chore: Rebase and fix carrying over changes to `arg_utils` typing
cf56513  fix: Add `tensorizer` to `requirements-cpu.txt`
9c8db87  perf: Add concurrent reading to `TensorDeserializer`
8ca0cb1  docs: Add `num_readers` docstring
21bca06  chore: Replace `PerformanceWarning` after rebase
0c82446  Run yapf and ruff
06cd26d  fix: Fix model output on deserialization and add e2e output test
f19ee64  fix: Properly ensure test outputs are deterministic, add HF model test
f1f2e16  fix: Make vLLM tensorizing specification less hacky
71a9f79  docs: Add tensorizer link in `engine_args.rst`, docstring to example
9e5456a  chore: Resolve comments
74a8642  fix: Affirm mandatory `vllm_tensorized` argument change
f82b25a  perf: Allow preliminary support deserializing with LoRA adapters
3ec85e0  fix: Fix requirements.txt passing import tensorizer only if installed
dfb7a11  fix: Properly ensure import fail if tensorizer not used nor installed
81196ed  perf: Move test location and add testing for LoRA
5ecf4ee  perf: Add some testing changes, introduce `TensorizerConfig`
b267cbd  chore: Add `__init__.py` for `tests/tensorizer`
6bff0c7  tests: Fix `test_tensorizer.py` to account for new changes
e0b7184  tests: Remove `test_tensorizer_api_server.py`
3ec105d  Run yapf and ruff; fix tests
9d568fc  fix: Revert change to `examples/multilora_inference.py`
55d2a41  Merge remote-tracking branch 'upstream/main' into sangstar/integrate-…
65bc7bb  Merge remote-tracking branch 'upstream/main' into sangstar/integrate-…
b1b5653  perf: Update code to reflect change in #3977
5f27722  chore: Remove accidental syntax error
8240af9  docs: Elaborate on S3 credentialing
7f5eada  fix: Properly passing `tensorizer_config` to hf weight loader
e0d9cc7  fix: Fix, test tensorizer uri passing without tensorizer load format
de54538  docs: Note example script in docs for more information
1feab4e  chore: Run yapf and ruff, as well as doc edits
a9b0241  fix: Fix `initialize_model_parallel` import
a297a62  tests: Add test for `examples/tensorize_vllm_model.py`
2d07568  tests: Fix lora test
1bddfe6  Run yapf and ruff
852f0ad  fix: Move `tensorize_loader` imports to pass CPU test
aef7442  refactor: Pass `TensorizerArgs` direct to `EngineArgs.add_cli_args`
ff0a528  tests: Add api_server test using tensorizer
4551b84  fix: Add `tensorizer_config` to `RayGPUExecutor`
d51b0bc  tests: Formatting and add test to ensure `tensorizer` load format
64178e4  style: Run yapf on `examples/tensorize_vllm_model.py`
2f4dcb3  style: Run isort on `examples/tensorize_vllm_model.py`
3df1945  style: Fix yapf and isort conflict
eb925f0  fix: Remove `tensorizer_args` from `ModelConfig`
ba6927d  fix: Add error for device scattering and initial handling for quant
bd461cc  perf: Multiple changes in response to comments
ca2a3fb  perf: Final changes to resolve comments
428f53d  fix: Skip tests if cURL not installed, add example script for testing
88f1a67  Run yapf and ruff
d2491ac  tests: Install cURL for tensorizer tests for testing suite
d77215f  tests: Install libsodium23 for CI tensorizer tests
9de338c  fix: Fix testing import path
95251d7  Run yapf and ruff
sangstar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sphinx docs configuration (`autodoc_mock_imports`), @@ -82,6 +82,7 @@:

     "vllm._C",
     "numpy",
     "tqdm",
+    "tensorizer",
 ]

 for mock_target in autodoc_mock_imports:
New file: `examples/tensorize_vllm_model.py` (@@ -0,0 +1,254 @@)
import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type

import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
                        TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig

from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs

# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
deserialize vLLM models. These models can be loaded using tensorizer directly
to the GPU extremely quickly. Tensor encryption and decryption are also
supported, although libsodium must be installed to use them. Install
vllm with tensorizer support using `pip install vllm[tensorizer]`.

To serialize a model, run something like this:

python tensorize_vllm_model.py \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   serialize \
   --serialized-directory s3://my-bucket/ \
   --suffix vllm

This downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used.

You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.

To deserialize a model, run something like this:

python tensorize_vllm_model.py \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   deserialize \
   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors

This downloads the model tensors from your S3 bucket and deserializes them.

To provide S3 credentials, you can pass `--s3-access-key-id` and
`--s3-secret-access-key`, as well as `--s3-endpoint`, as CLI args to this
script, to the OpenAI entrypoint, as arguments to LLM(), or as the
environment variables `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and
`S3_ENDPOINT`.

You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.

For more information on the available arguments, run
`python tensorize_vllm_model.py --help`.
"""
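The credential lookup described above (CLI flag first, then environment variable, then nothing) can be sketched as a small helper. `resolve_credential` is a name invented for this illustration, not part of the script:

```python
import os

def resolve_credential(cli_value, env_name):
    """Return the CLI-supplied value if given, else fall back to the
    named environment variable, else None -- mirroring how the script
    resolves S3_ACCESS_KEY_ID and friends."""
    return cli_value or os.environ.get(env_name) or None

os.environ["S3_ACCESS_KEY_ID"] = "from-env"
print(resolve_credential(None, "S3_ACCESS_KEY_ID"))        # prints "from-env"
print(resolve_credential("from-cli", "S3_ACCESS_KEY_ID"))  # prints "from-cli"
```

Note that because `or` treats empty strings as falsy, an explicitly empty CLI value also falls through to the environment variable.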
def parse_args():
    parser = argparse.ArgumentParser(
        description="An example script that can be used to serialize and "
        "deserialize vLLM models. These models "
        "can be loaded using tensorizer directly to the GPU "
        "extremely quickly. Tensor encryption and decryption is "
        "also supported, although libsodium must be installed to "
        "use it.")
    parser = EngineArgs.add_cli_args(parser)
    subparsers = parser.add_subparsers(dest='command')

    serialize_parser = subparsers.add_parser(
        'serialize', help="Serialize a model to `--serialized-directory`")

    serialize_parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help=(
            "The suffix to append to the serialized model directory, which is "
            "used to construct the location of the serialized model tensors, "
            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
            "`--suffix` is `v1`, the serialized model tensors will be "
            "saved to "
            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
            "If none is provided, a random UUID will be used."))
    serialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=True,
        help="The directory to serialize the model to. "
        "This can be a local directory or S3 URI. The path to where the "
        "tensors are saved is a combination of the supplied `dir` and model "
        "reference ID. For instance, if `dir` is the serialized directory, "
        "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
        "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
        "where `suffix` is given by `--suffix` or a random UUID if not "
        "provided.")

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Encrypt the model weights with a randomly-generated binary key,"
              " and save the key at this path"))

    deserialize_parser = subparsers.add_parser(
        'deserialize',
        help=("Deserialize a model from `--path-to-tensors`"
              " to verify it can be loaded and used."))

    deserialize_parser.add_argument(
        "--path-to-tensors",
        type=str,
        required=True,
        help="The local path or S3 URI to the model tensors to deserialize.")

    deserialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Path to a binary key to use to decrypt the model weights,"
              " if the model was serialized with encryption"))

    return parser.parse_args()
def make_model_contiguous(model):
    # Ensure tensors are saved in memory contiguously
    for param in model.parameters():
        param.data = param.data.contiguous()


def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
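`_get_vllm_model_architecture` returns the model class for the first architecture name the registry recognizes, and raises if none match. A simplified stand-in for that lookup, with a plain dict in place of vLLM's `ModelRegistry` (names here are invented for illustration):

```python
# A plain dict standing in for vLLM's ModelRegistry in this sketch.
FAKE_REGISTRY = {"GPTJForCausalLM": "GPTJModelClass"}

def pick_architecture(architectures, registry):
    """Return the entry for the first registered architecture name,
    as _get_vllm_model_architecture does; raise ValueError if none match."""
    for arch in architectures:
        model_cls = registry.get(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {sorted(registry)}")

# First match wins, even when earlier names are unknown:
print(pick_architecture(["Unknown", "GPTJForCausalLM"], FAKE_REGISTRY))
# prints "GPTJModelClass"
```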
def serialize():
    eng_args_dict = {f.name: getattr(args, f.name) for f in
                     dataclasses.fields(EngineArgs)}
    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
    engine = LLMEngine.from_engine_args(engine_args)

    model = (engine.model_executor.driver_worker.
             model_runner.model)

    encryption_params = EncryptionParams.random() if keyfile else None
    if keyfile:
        with _write_stream(keyfile) as stream:
            stream.write(encryption_params.key)

    with _write_stream(model_path) as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
        serializer.close()

    print("Serialization complete. Model tensors saved to", model_path)
    if keyfile:
        print("Key saved to", keyfile)


def deserialize():
    config = AutoConfig.from_pretrained(model_ref)

    with no_init_or_tensor():
        model_class = _get_vllm_model_architecture(config)
        model = model_class(config)

    before_mem = get_mem_usage()
    start = time.time()

    if keyfile:
        with _read_stream(keyfile) as stream:
            key = stream.read()
            decryption_params = DecryptionParams.from_key(key)
            tensorizer_args.deserializer_params['encryption'] = \
                decryption_params

    with (_read_stream(model_path)) as stream, TensorDeserializer(
            stream, **tensorizer_args.deserializer_params) as deserializer:
        deserializer.load_into_module(model)
        end = time.time()

    # Brag about how fast we are.
    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    print(
        f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
    )
    print(f"Memory usage before: {before_mem}")
    print(f"Memory usage after: {after_mem}")

    return model
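The throughput report in `deserialize` relies on `tensorizer.utils.convert_bytes` to render byte counts as human-readable strings. A simplified stand-in with binary-prefix units (not tensorizer's exact formatting, just the idea):

```python
def humanize_bytes(n: float) -> str:
    """Render a byte count with binary-prefix units, similar in spirit
    to tensorizer's convert_bytes (exact formatting may differ)."""
    units = ("B", "KiB", "MiB", "GiB", "TiB")
    for unit in units:
        if n < 1024 or unit == units[-1]:
            return f"{n:.1f} {unit}"
        n /= 1024

# A 6B-parameter fp16 model is roughly 12e9 bytes:
print(humanize_bytes(12e9))      # prints "11.2 GiB"
print(humanize_bytes(12e9 / 8))  # the rate if those bytes loaded in 8 s
```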
args = parse_args()

s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
                    or None)
s3_secret_access_key = (args.s3_secret_access_key
                        or os.environ.get("S3_SECRET_ACCESS_KEY") or None)

s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)

_read_stream, _write_stream = (partial(
    stream_io.open_stream,
    mode=mode,
    s3_access_key_id=s3_access_key_id,
    s3_secret_access_key=s3_secret_access_key,
    s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))

model_ref = args.model

model_name = model_ref.split("/")[1]

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"

torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()

keyfile = args.keyfile if args.keyfile else None

if args.command == "serialize":
    input_dir = args.serialized_directory.rstrip('/')
    suffix = args.suffix if args.suffix else uuid.uuid4().hex
    base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
    model_path = f"{base_path}/model.tensors"
    serialize()
elif args.command == "deserialize":
    tensorizer_args = TensorizerArgs.from_cli_args(args)
    model_path = args.path_to_tensors
    deserialize()
else:
    raise ValueError("Either serialize or deserialize must be specified.")
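The serialize branch above derives the tensor path from the serialized directory, the model reference, and the suffix. That path scheme can be sketched as a standalone helper (`build_tensor_path` is a name made up here, not part of the script):

```python
import uuid
from typing import Optional

def build_tensor_path(serialized_directory: str, model_ref: str,
                      suffix: Optional[str] = None) -> str:
    """Mirror the script's layout:
    {dir}/vllm/{model_ref}/{suffix}/model.tensors, falling back to a
    random UUID hex suffix when none is supplied."""
    input_dir = serialized_directory.rstrip("/")
    suffix = suffix if suffix else uuid.uuid4().hex
    return f"{input_dir}/vllm/{model_ref}/{suffix}/model.tensors"

print(build_tensor_path("s3://my-bucket/", "EleutherAI/gpt-j-6B", "v1"))
# prints "s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors"
```

This matches the example in the script's docstring, where `--serialized-directory s3://my-bucket/` and `--suffix vllm` yield `s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors`.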