Supporting embedding models #3187
base: main
Conversation
Thank you for the PR. A couple comments:
b1e7351 to f5810d5
I've removed param_dict and replaced it with the existing tokenization infrastructure. I think I've changed the inference layer to use vLLM's inference layer as well, but I may have missed something, so let me know if I need to fix anything. Thanks!
@jc9123 I will take a look.
I left a few tactical nits above which should be addressed directly. However, there is significant work to do to get this PR into a place that is ready to merge. I would suggest we split up the work into a series of PRs.

**Milestone 1: Basic functionality using Hugging Face models**

In this step, we will get basic functionality working, meaning:
See below for a detailed workplan to address Milestone 1. Additionally, make sure there are checks that turn off incompatible features for embedding models (a rough sketch of such a guard appears after the milestones below):
**Milestone 2: Implement the model properly**

In this step, we will implement the model using the vLLM layer primitives rather than the transformers implementation of the model. This will enable us to support tensor parallelism and have more optimized inference. The goal is to remove the need to import transformers and instead use the vLLM layer primitives that are optimized for inference (a rough sketch appears after the milestones below), like:
**Milestone 3: Support sparse vectors, ColBERT, normalization, etc.**

In this step, we will add more examples of the types of embeddings that can be handled. See here for examples of what this means.

**Milestone 4: Add the OpenAI Server Front End**
**Milestone 5: Expand feature set [optional]**
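To make the Milestone 1 note about feature checks concrete, here is a minimal sketch of the kind of guard that could run at engine construction time. All attribute names below are hypothetical placeholders, not the actual vLLM config API.

```python
# Hypothetical sketch of a compatibility guard for embedding models.
# The attribute names are placeholders, not vLLM's real config fields.
def verify_embedding_compatibility(engine_args) -> None:
    if not getattr(engine_args, "embedding_model", False):
        return
    # Features that (as an assumption) do not apply to a single-pass
    # embedding workload and should be rejected or forced off.
    incompatible = {
        "enable_prefix_caching": False,
        "use_speculative_decoding": False,
        "enable_lora": False,
    }
    for name, allowed_value in incompatible.items():
        if getattr(engine_args, name, allowed_value) != allowed_value:
            raise ValueError(
                f"{name} is not supported when running an embedding model; "
                f"set it to {allowed_value!r}.")
```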
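For Milestone 2, a rough sketch of what swapping transformers modules for vLLM's tensor-parallel layers might look like in an encoder feed-forward block. The import path and constructor signatures are assumptions based on the vLLM tree at the time of this PR and should be verified against the targeted version.

```python
# Sketch only: an encoder MLP built from vLLM's parallel linear layers
# instead of torch.nn.Linear / the Hugging Face implementation.
# Import path and signatures are assumptions; verify before use.
import torch
from torch import nn

from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)


class EncoderMLP(nn.Module):

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # Column-parallel up-projection, row-parallel down-projection,
        # so the intermediate activations stay sharded across GPUs.
        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size)
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size)
        self.act = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # vLLM's parallel linear layers return (output, bias).
        hidden_states, _ = self.up_proj(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.down_proj(hidden_states)
        return hidden_states
```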
examples/offline_inference.py
Outdated
```diff
@@ -19,4 +19,4 @@
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
This file should not be touched
requirements.txt
Outdated
```diff
@@ -13,5 +13,6 @@ pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0
 pynvml == 11.5.0
 triton >= 2.1.0
+datasets >= 2.0.0
```
This should not be a dependency. If you need it for testing, add it to requirements-dev
tests/spec_decode/utils.py
Outdated
```diff
@@ -128,6 +128,7 @@ def create_worker(cls: type,

     cache_config.num_gpu_blocks = num_gpu_blocks
+    cache_config.num_cpu_blocks = 0
```
This file should not be touched
vllm/core/scheduler.py
Outdated
```diff
@@ -370,7 +370,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:

             seq_data: Dict[int, SequenceData] = {}
-            block_tables: Dict[int, List[int]] = {}

             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
```
This file should not be touched
vllm/engine/arg_utils.py
Outdated
```diff
@@ -52,6 +52,7 @@ class EngineArgs:
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
     ray_workers_use_nsight: bool = False
+    embedded_model : bool = False
```
Let's use the term `embedding_model` to match the terminology used by OAI.
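For illustration, a minimal sketch of the suggested naming in `EngineArgs` and its CLI flag (fields abridged; this is not the exact diff from this PR):

```python
# Illustrative sketch of the rename suggested above; not the PR's actual diff.
import argparse
from dataclasses import dataclass


@dataclass
class EngineArgs:
    # ... other engine arguments elided ...
    embedding_model: bool = False  # was: embedded_model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser.add_argument(
            "--embedding-model",
            action="store_true",
            help="Treat the loaded model as an embedding model "
            "(terminology follows the OpenAI API).")
        return parser
```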
vllm/worker/model_runner.py
Outdated
```diff
@@ -2,6 +2,8 @@
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union

+import inspect
```
remove
vllm/worker/model_runner.py
Outdated
```python
if self.lora_config:
    self.set_active_loras(lora_requests, lora_mapping)
```
nit
vllm/worker/model_runner.py
Outdated
```diff
@@ -559,16 +561,19 @@ def execute_model(
         (input_tokens, input_positions, input_metadata, sampling_metadata,
          lora_requests,
          lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
```
nit
vllm/engine/llm_engine.py
Outdated
```diff
@@ -536,18 +536,23 @@ def _process_model_outputs(
         # Update the scheduled sequence groups with the model outputs.
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups

         # If prefix caching is enabled, mark all blocks in the sequence groups
         # as completed so that future requests don't attempt to recompute them
+        if self.embedded_model :
```
All of the logic for making stopping decisions occurs in `llm_engine.check_stop`. This is correct logic (in that each sequence only needs to run once), but please move it into the `check_stop` function.
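As a hedged sketch of what that move could look like inside a `check_stop`-style method (names such as `self.embedding_model` are assumptions for illustration, not the code in this PR):

```python
# Sketch: finish embedding sequences after a single pass inside the engine's
# stop-checking logic instead of in _process_model_outputs.
# Names here are illustrative assumptions.
from vllm.sequence import SequenceStatus


def _check_stop(self, seq, sampling_params) -> None:
    if self.embedding_model:
        # An embedding request only needs one forward pass over the prompt,
        # so the sequence is finished as soon as it has been executed once.
        seq.status = SequenceStatus.FINISHED_STOPPED
        return

    # ... existing stopping criteria for autoregressive generation ...
```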
vllm/model_executor/models/bgem3.py
Outdated
```diff
@@ -0,0 +1,288 @@
+from dataclasses import dataclass
```
See my feedback below on how this file needs to be reworked
@jc9123 I made a PR on your branch to:
If you want to chat live, ping me at [ rshaw at neuralmagic dot com ].

**Milestone 1 Workplan**

A) Accept my PR to address the nits
B) Finish the
Rs/embedding step 1
Hi Robert, thank you for the detailed feedback! I will begin working on milestone 1 and let you know if I have any questions.
@jc9123 updated comment above
Thanks! I will check in with you after finishing steps a-d.
vllm/outputs.py
Outdated
```diff
@@ -75,6 +75,7 @@ def __init__(
         finished: bool,
         metrics: Optional[RequestMetrics] = None,
         lora_request: Optional[LoRARequest] = None,
+        embed = None
```
IMO it's a better idea to create a separate class for the output of embedding models. Many attributes of `RequestOutput` are tied to language models but not relevant to embedding models. We can discuss this in the later stages of this PR series.
Agree @ywang96. I think we will want to refactor `RequestOutput` and `SamplingParams` to be abstract, with separate versions for Completion vs Embedding. It will be tricky to do this without breaking changes to the UX.
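For discussion, a minimal sketch of what a dedicated output type might look like (field names are placeholders, not a settled design):

```python
# Placeholder sketch of a dedicated output class for embedding requests.
# Field names are illustrative only; the final design is still under
# discussion in this thread.
from dataclasses import dataclass, field
from typing import List


@dataclass
class EmbeddingRequestOutput:
    request_id: str
    prompt: str
    prompt_token_ids: List[int]
    embedding: List[float] = field(default_factory=list)  # pooled vector
    finished: bool = True  # embedding requests finish after a single pass
```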
@jc9123 any update?
Sorry for the delay, but I will need a bit more time to finish milestone 1. I should be able to finish it by early next week.
I don't want to step on @jc9123's toes, but I'd be happy to help out on this issue. For example, I could already start work on the front end bits.
Hi Robert, sorry for the delay, but I've finished checkpoints 1.a-1.d. Could you take a look? Right now, I've mostly copied `EmbeddingParams` and `EmbeddingOutput` from their Sampling counterparts to ensure that everything still works, but I was wondering if there is something specific I should add/change. Edit: Found some errors in the implementation, will fix by EOD.
@jc9123 @robertgshaw2-neuralmagic, it seems this PR has a few milestones that overlap with my PR #3734. I have also separated `EmbeddingOutput` from `CompletionOutput`. Should we collaborate so we don't do duplicate work?
Hi Catherine, we should definitely collaborate to avoid duplicate work. Between our PRs, I think we should have the same implementation for the vLLM front end UX and also the `EmbeddingOutput` and `EmbeddingParams` classes.
Any update? Can we use other embedding models now?
LGTM, but any update?
BART has been added, and cross attention as well. I believe there should be some small incremental work in hooking these models up to the embedding API. If @jc9123 is not actively working on this anymore, please feel free to open a new PR based off this one.
This pull request has merge conflicts that must be resolved before it can be merged.
Added support for BGE-M3, which is an embedding model. The embedding vector is stored in the `.embed` attribute of the returned request output object. An example test case can be found in tests/models/test_embedded.py.
FIX #9847
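A hedged usage sketch of the behavior described above; the `embedding_model` flag and the `.embed` attribute follow this PR's description and may differ from what eventually merges:

```python
# Usage sketch based on this PR's description. The flag name and the
# .embed attribute are taken from the text above and may change.
from vllm import LLM

llm = LLM(model="BAAI/bge-m3", embedding_model=True)
outputs = llm.generate(["What is the capital of France?"])

for output in outputs:
    vector = output.embed  # embedding vector, per the PR description
    print(len(vector), vector[:8])
```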