Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] [Core] feat: Add model loading using tensorizer #3476

Merged
merged 102 commits into from
Apr 14, 2024

Conversation

sangstar
Copy link
Contributor

@sangstar sangstar commented Mar 18, 2024

Feature Request Issue

Tensorizer Support

This PR allows models used for the OpenAI-compatible API server to be loaded using
Coreweave's Tensorizer, enabling extremely fast (faster than cached Safetensors) model loads from HTTP/HTTPS,
Redis, and S3 endpoints.

The key changes involved are:

  1. Adds TensorizerConfig to the set of configs.
  2. Adds tensorizer_loader.py to vllm/model_executor that provides utility functions
    for tensorizer.
  3. Adds multiple args to the vLLM's OpenAI inference service entrypoint that allows
    the user to specify the path to serialized-by-tensorizer model tensors, as well as
    arguments for tensorizer's deserializer.
  4. Allows deserialization of serialized model tensors in HuggingFace's model format,
    as well as supporting deserializing serialized vLLM-formatted models, allowing the
    use of loading with plaid_mode, which can allow Llama 2 13B non-locally to start serving requests
    in as little as 10 seconds. Also supports encrypting and decrypting model tensors.
  5. Adds a tensorize_vllm_model.py script to examples/ that allows vLLM models to be serialized and
    deserialized with tensorizer.
  6. Adds tensorizer as an optional dependency.

Credentialing for S3 is supported by passing a user's access and secret key to
S3_ACCESS_KEY_ID and S3_SECRET_ACCESS_KEY environment variables respectively. It can also be specified as CLI args to the api server entrypoint.

Model loading benchmarks

Tensorizer can load models like Llama 2 13B in as little as 10 seconds. In order to do so, a model must be
serialized using TensorSerializer to a .tensors file located either locally or through a S3, HTTP/HTTPS, or Redis
endpoint. --tensorizer-uri must be specified with the serialized tensors location when invoking the API server.

Example usage:

python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--model EleutherAI/pythia-6.9b \
--load-format tensorizer \
--tensorizer-uri s3://tensorized/EleutherAI/pythia-6.9b/fp16/model.tensors

If a vLLM model is serialized, plaid_mode can be used, which loads much faster. The following plot demonstrates model loading time benchmarks for vLLM's OpenAI-compatible inference server on a Nvidia A40 GPU.

Tensorizer is so fast that it loads models faster than Safetensors even locally.

Benchmark (10)

@sangstar
Copy link
Contributor Author

@cadedaniel @rkooo567 Pinging for an assigned reviewer from someone on the team when possible!

@sangstar sangstar force-pushed the sangstar/integrate-tensorizer branch 4 times, most recently from 64e8637 to 23a6c03 Compare April 3, 2024 18:39
@sangstar
Copy link
Contributor Author

sangstar commented Apr 3, 2024

@cadedaniel @rkooo567 @Yard1 @WoosukKwon @zhuohan123 @ywang96

All tests are passing. Can I get eyes on this please? Cheers!

sangstar added 18 commits April 4, 2024 09:35
This feature allows vLLM models to be loaded extremely fast using
`tensorizer`. `tensorizer` loads serialized model tensors from
HTTP/HTTPS, Redis, S3 endpoints, or locally, typically on the scale
of multiple GB/s.
This allows the deserializer to access S3 credentials when reading.
It extracts the access key from env variables `S3_ACCESS_KEY_ID`
and secret key from `S3_SECRET_ACCESS_KEY`.
Previous commit wasn't able to pass S3
credentials through `TensorDeserializer`, so
it is instead passed to `stream_io.openstream`
which is used to instantiate the
`TensorDeserializer`.
Removed functionality to allow for tensorizing without `plaid_mode`,
and updating to `tensorizer==2.8.0`
Replaces `download_dir` with `tensorizer_uri` as a `TensorizerArgs`
param. This is due to discussions on `download_dir` being a confusing
and ultimately not helpful parameter for the location of model tensors.
Instead, `download_dir` is back to the definition coinciding with
the convention set by HuggingFace; as a location to download weights
for caching. A new parameter takes its place: `tensorizer_uri`. This
specifically deals with locating model tensors for `tensorizer`.
Also slight formatting fix for warning when loading weights with
`download_dir` not set to `None`.
Integrated changes from
ssteel/tensorizer-support branch that allowed
for deserializing vLLM models.
Integrated previous support for vLLM-formatted
model loading with `tensorizer` that makes full
use of loading to the GPU with `plaid-mode`, as
well as falling back on being able to load
HuggingFace models for serving using the CPU so
that vLLM can perform its manual GPU loading.
Fixed some unnecessary formatting changes in `arg_utils.py`,
`weight_utils.py` and `model_loader.py` fixed improperly passing
`force_http` to `TensorDeserializer` rather than `open_stream`.
Misc. fixes from now resolved conversations. Mostly consisting of
changes to syntax, style, adding docstrings, and versioning.
`examples/tensorize_vllm_model.py` now correctly instantiates
vLLM-formatted models.
Replaced the model initialization process using `LLMEngine`, allowing
vLLM to handle and therefore optimize the initial model loading
process. Added testing for quantization.
@sangstar
Copy link
Contributor Author

sangstar commented Apr 13, 2024

@ywang96 @Yard1 @rkooo567

Thank you all very much for your reviews! I've implemented the changes from @ywang96 's comments. To summarize:

  • An error is raised if the tensor parallel size exceeds 1 and attempting to use Tensorizer (test added)
  • The serialization step in examples/tensorize_vllm_model.py now instantiates the model to serialize using LLMEngine
  • Meta tensors found when deserializing will raise an error
  • Removed forcing float16 from the parser for examples/tensorize_vllm_model.py
  • I've also additionally added a PerformanceWarning when trying to load a tensorized model with quantization, as that is a bit unstable at the moment (I may try to look in to this in another PR) (test added).
  • Added the Tensorizer testing folder for the CI suite

@sangstar
Copy link
Contributor Author

Some minor fixes to ensure the testing suite can run the tensorizer tests. All passing! Thanks very much again for the reviews @rkooo567 @Yard1 @ywang96 let me know if anything else needed! :)

@sangstar sangstar requested a review from ywang96 April 14, 2024 00:07
Copy link
Member

@ywang96 ywang96 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thank you again @sangstar for all the work and test coverage on this PR to add this feature!

@ywang96 ywang96 merged commit 711a000 into vllm-project:main Apr 14, 2024
46 checks passed
@sangstar sangstar deleted the sangstar/integrate-tensorizer branch April 16, 2024 12:57
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request Apr 22, 2024
@bbrowning
Copy link

I was able to successfully test this in a vLLM 0.4.1 container running on OpenShift, both with models serialized with the Tensorizer library directly and for vLLM-serialized models. Once I cranked up the Pod's CPU and increased the num_readers parameter, I got about an 8x speedup in my case when loading the same model via vLLM-serialized tensorize files compared to not using tensorizer at all and just downloading the safetensors from S3 to local disk then loading with vLLM. This took my overall cold start time of this Pod from a bit over 4 minutes to 30 seconds. There may be even more performance available in my setup with additional tweaking, but this is already a great win.

INFO 05-03 11:11:38 tensorizer.py:337] Deserialized 14.5 GB in 15.21s, 953.1 MB/s

That's an awesome improvement, and thank you!

@sangstar
Copy link
Contributor Author

sangstar commented May 3, 2024

I was able to successfully test this in a vLLM 0.4.1 container running on OpenShift, both with models serialized with the Tensorizer library directly and for vLLM-serialized models. Once I cranked up the Pod's CPU and increased the num_readers parameter, I got about an 8x speedup in my case when loading the same model via vLLM-serialized tensorize files compared to not using tensorizer at all and just downloading the safetensors from S3 to local disk then loading with vLLM. This took my overall cold start time of this Pod from a bit over 4 minutes to 30 seconds. There may be even more performance available in my setup with additional tweaking, but this is already a great win.

INFO 05-03 11:11:38 tensorizer.py:337] Deserialized 14.5 GB in 15.21s, 953.1 MB/s

That's an awesome improvement, and thank you!

I'm thrilled to hear that! I currently actually have a new PR up #4208 that uses the full 2.9.0 release, has better usage documentation, and automated inferring a vLLM-serialized model.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants