
Supporting embedding models #3187

Open · wants to merge 26 commits into main

Conversation

@jc9123 commented Mar 4, 2024

Added support for BGE-M3, which is an embedded model. The embedded vector is stored in the .embed attribute of the returned RequestOutput object. An example test case can be found in tests/models/test_embedded.py.

FIX #9847

@simon-mo self-assigned this Mar 5, 2024
@robertgshaw2-neuralmagic (Collaborator)

Thank you for the PR.

A couple of comments:

  • The major thing that should be done to add BGE to vLLM is that the model should be re-implemented using the layers from vLLM. vLLM has optimized inference kernels for most layers in transformers; we should use these in the implementation. Ideally, you would also use the ColumnParallel and RowParallel layers to support multi-GPU inference. Currently, it looks like you are using the transformers model inside vLLM in the definition of BGE. Take a look at the implementations of some of the other models for inspiration; you will need to modify attention to be bidirectional, but they should be instructive (a minimal sketch follows this list).

  • I do not think that we should be adding a new input params_dict to vLLM. Additionally, we should not be tokenizing the input inside the forward method. The inputs to encoder models are tokens, just like for LLMs. We should be able to leverage the existing tokenization / sequence group / input_metadata infrastructure for encoder models.
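
For reference, a minimal sketch of the bidirectional (encoder-style) attention change, using plain PyTorch modules as placeholders; the nn.Linear projections are where vLLM's ColumnParallelLinear / RowParallelLinear (or QKVParallelLinear) would eventually be swapped in. The names and shapes are illustrative, not the final vLLM implementation.

import torch
import torch.nn as nn

class BidirectionalSelfAttention(nn.Module):
    """Encoder-style self-attention: no causal mask, only a padding mask."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        # In the vLLM implementation these projections would become the
        # tensor-parallel linear layers mentioned above.
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

    def forward(self, hidden_states: torch.Tensor,
                padding_mask: torch.Tensor) -> torch.Tensor:
        # hidden_states: [batch, seq, hidden]; padding_mask: [batch, seq], True = real token.
        bsz, seq_len, _ = hidden_states.shape
        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
        q, k, v = (t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        # Every token may attend to every other real token (no causal mask);
        # this is the key difference from the decoder models vLLM serves today.
        attn_mask = padding_mask[:, None, None, :]
        out = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        out = out.transpose(1, 2).reshape(bsz, seq_len, -1)
        # Outputs at padded query positions are meaningless and should be ignored.
        return self.out_proj(out)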

@jc9123 force-pushed the supporting_embedded_models branch from b1e7351 to f5810d5 on March 19, 2024 21:09
@simon-mo changed the title from "Supporting embedded models" to "Supporting embedding models" Mar 19, 2024
@jc9123 (Author) commented Mar 19, 2024

I've removed param_dict and replaced it with the existing tokenization infrastructure. I think I've also changed the model to use vLLM's inference layers, but I may have missed something, so let me know if I need to fix anything. Thanks!

@robertgshaw2-neuralmagic (Collaborator)

@jc9123 I will take a look.

@robertgshaw2-neuralmagic (Collaborator) commented Mar 24, 2024

I left a few tactical nits above which should be addressed directly.

However, there is significant work to do to get this PR into a place that is ready to merge. I would suggest we split up the work into a series of PRs.

Milestone 1: Basic functionality using Hugging Face models

In this step, we will get basic functionality working, meaning:

  • Model is implemented using HF transformers
  • Update the LLM front end to support LLM.encode(). This will receive input text and return an embedding.

See below for a detailed workplan to address Milestone 1

Additionally, make sure there are checks that incompatible features are turned off for embedding models (see the sketch after this list):

  • Tensor parallelism > 1
  • Speculative decoding
  • Automatic prefix caching
  • LORA
  • Quantization
  • CUDAGraphs
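
A rough sketch of what these guards could look like, assuming a single embedding-model flag; the attribute names on engine_args below are illustrative and would need to match the real EngineArgs fields.

def verify_embedding_model_compat(engine_args) -> None:
    """Reject configurations that embedding models cannot use yet (sketch)."""
    incompatible = {
        "tensor parallelism > 1": getattr(engine_args, "tensor_parallel_size", 1) > 1,
        "speculative decoding": getattr(engine_args, "speculative_model", None) is not None,
        "automatic prefix caching": getattr(engine_args, "enable_prefix_caching", False),
        "LoRA": getattr(engine_args, "enable_lora", False),
        "quantization": getattr(engine_args, "quantization", None) is not None,
        "CUDA graphs": not getattr(engine_args, "enforce_eager", False),
    }
    enabled = [name for name, on in incompatible.items() if on]
    if enabled:
        raise ValueError(
            f"These features are not supported for embedding models yet: {enabled}")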

Milestone 2: Implement the model properly

In this step, we will implement the model using the vLLM layer primitives rather than using the transformers implementation of the model. This will enable us to support tensor parallelism and have more optimized inference.

The goal is to remove the need to import transformers and instead use the vLLM layer primitives that are optimized for inference like:

  • ColumnParallelLinear, RowParallelLinear, Attention, etc.

Milestone 3: Support sparse vectors, ColBERT vectors, normalization, etc.

In this step, we will add more examples of types of embeddings that can be handled. See here for examples of what this means

Milestone 4: Add the OpenAI Server Front End

Milestone 5: Expand feature set [optional]

  • Multimodal embedding models (e.g. CLIP)
  • Support quantized models
  • Support quantized embeddings
  • Support variable length embeddings
  • [ others that I am missing ]

@@ -19,4 +19,4 @@
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Collaborator:

This file should not be touched

requirements.txt Outdated
@@ -13,5 +13,6 @@ pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
datasets >= 2.0.0

Collaborator:

This should not be a dependency. If you need it for testing, add it to requirements-dev

@@ -128,6 +128,7 @@ def create_worker(cls: type,

cache_config.num_gpu_blocks = num_gpu_blocks
cache_config.num_cpu_blocks = 0

Collaborator:

This file should not be touched

@@ -370,7 +370,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:

seq_data: Dict[int, SequenceData] = {}
block_tables: Dict[int, List[int]] = {}

for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):

Collaborator:

This file should not be touched

@@ -52,6 +52,7 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
embedded_model : bool = False

Collaborator:

Let's use the term embedding_model to match the terminology used by OAI

https://platform.openai.com/docs/api-reference/embeddings

@@ -2,6 +2,8 @@
import time
from typing import Dict, List, Optional, Tuple, Set, Union

import inspect

Collaborator:

remove

if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)

Collaborator:

nit

@@ -559,16 +561,19 @@ def execute_model(
(input_tokens, input_positions, input_metadata, sampling_metadata,
lora_requests,
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)

Collaborator:

nit

@@ -536,18 +536,23 @@ def _process_model_outputs(
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups

# If prefix caching is enabled, mark all blocks in the sequence groups
# as completed so that future requests don't attempt to recompute them
if self.embedded_model :

Collaborator:

All of the logic for making stopping decisions occurs in llm_engine.check_stop

This is the correct logic (in that each sequence only needs to run once), but please move it into the check_stop function (sketched below).
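
For illustration, a self-contained sketch of the intended shape of that change; the names below are stand-ins for vLLM's Sequence/SequenceStatus types, and the real check would sit next to the existing decoding stop conditions.

from enum import Enum, auto

class SequenceStatus(Enum):
    # Stand-in for vLLM's SequenceStatus; only the values used here.
    RUNNING = auto()
    FINISHED_STOPPED = auto()

def check_stop(seq, is_embedding_model: bool) -> None:
    """Sketch: embedding sequences finish after a single forward pass."""
    if is_embedding_model:
        # One prefill pass produces the embedding, so the sequence is done;
        # there is no decode loop to continue.
        seq.status = SequenceStatus.FINISHED_STOPPED
        return
    # ... existing stop checks (max_tokens, EOS token, stop strings) go here ...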

@@ -0,0 +1,288 @@
from dataclasses import dataclass

Collaborator:

See my feedback below on how this file needs to be reworked

@robertgshaw2-neuralmagic (Collaborator)

@jc9123 I made a PR on your branch to:

  • fix the nits listed here
  • get started implementing the XLMRobertaModel

@robertgshaw2-neuralmagic (Collaborator) commented Mar 24, 2024

If you want to chat live, ping me at [ rshaw at neuralmagic dot com ]

Milestone 1 Workplan

A) Accept my PR to address the nits

B) Finish the XLMRobertaModel definition I started

In the PR I did against your branch, I set up the basics of the model, which creates the weights using the config during init and then loads the weights during load_weights() (which is called by vLLM when the model is being loaded).

from transformers import XLMRobertaConfig
from xlm_roberta import XLMRobertaModel

model_id = "BAAI/bge-m3"
config = XLMRobertaConfig.from_pretrained(model_id)
model = XLMRobertaModel(config=config)
model.load_weights(model_id)

You need to finish off this function by implementing the forward method. forward needs to convert the input data passed around by vLLM into the format expected by the Hugging Face implementation of the model (which we are using here). You should handle both the batched and non-batched cases.
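
As a starting point, here is a rough sketch of that conversion, assuming the worker hands the model a flat tensor of prompt token IDs plus per-sequence lengths (the exact vLLM-side names differ, and hf_model stands in for the wrapped Hugging Face module):

import torch

def forward_embeddings(hf_model, input_ids: torch.Tensor,
                       seq_lens: list[int]) -> list[torch.Tensor]:
    """Re-batch vLLM's flattened prompt tokens for an HF encoder (sketch)."""
    # Split the flat token stream back into one tensor per sequence; a single
    # sequence is just the batched case with len(seq_lens) == 1.
    seqs = list(torch.split(input_ids, seq_lens))
    max_len = max(seq_lens)
    pad_id = 0  # illustrative; use the tokenizer's real pad token id
    batch = torch.full((len(seqs), max_len), pad_id, dtype=input_ids.dtype)
    mask = torch.zeros((len(seqs), max_len), dtype=torch.long)
    for i, seq in enumerate(seqs):
        batch[i, : len(seq)] = seq
        mask[i, : len(seq)] = 1
    # last_hidden_state: [batch, max_len, hidden]
    out = hf_model(input_ids=batch, attention_mask=mask)
    # Mean-pool over real tokens as one simple way to get one vector per prompt;
    # BGE-M3's dense embedding actually comes from the [CLS] position.
    hidden = out.last_hidden_state
    summed = (hidden * mask.unsqueeze(-1)).sum(dim=1)
    return list(summed / mask.sum(dim=1, keepdim=True))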

Note: BAAI/bge-m3 uses the XLMRobertaModel architecture. So in vLLM we will implement the XLMRobertaModel architecture, and then any embedding model that uses this architecture can use the definition.

C) Update the LLM front end UX

I would like the UX to look something like this:

from vllm import LLM, EmbeddingParams

# LLM should autodetect this is an embedding model since it uses `XLMRobertaModel`
model = LLM("BAAI/bge-m3")
parameters = EmbeddingParams()
sentences = ["The quick...", "brown fox...", ... "the lazy..."]

output = model.encode(sentences, parameters=parameters)

Then, we can add things like normalize, return_sparse, return_dense, return_colbert_vecs via the EmbeddingParams in the future. [ This does not need to be done for this PR ]
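
For concreteness, a minimal sketch of what such an EmbeddingParams could hold; the field names follow the options listed above and are not a finalized API:

from dataclasses import dataclass

@dataclass
class EmbeddingParams:
    """Per-request options for LLM.encode() (sketch)."""
    normalize: bool = True             # L2-normalize the returned vectors
    return_dense: bool = True          # standard dense sentence embedding
    return_sparse: bool = False        # lexical / sparse weights (BGE-M3 style)
    return_colbert_vecs: bool = False  # per-token multi-vector output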

We can look at the following repos for potential features to support via encode

D) Clean up internals

Update RequestOutput in vllm/vllm/outputs.py
  • Create an EmbeddingOutput class
  • Modify RequestOutput to look like the following
class RequestOutput:
    request_id: str
    prompt: str
    prompt_token_ids: List[int]
    prompt_logprobs: Optional[PromptLogprobs]
    outputs: Union[List[CompletionOutput], List[EmbeddingOutput]]  # << RequestOutput can be either Completion or Embedding
    finished: bool
    metrics: Optional[RequestMetrics] = None
    lora_request: Optional[LoRARequest] = None
Fix terminology
  • embedded_model --> embedding_model
  • embed --> embedding
Move finish sequence logic to check_stop
  • Currently the logic is in llm_engine._process_model_output()
  • The logic should be in llm_engine._check_stop
Automatically detect that the model is an embedding model
  • The user should not have to specify that it is an embedding model
  • Somewhere in the vLLM code, create a registry that selects which models are embedding models and which are decoders (sketched below)
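
A rough sketch of what that registry could look like, keyed on the architecture name from the Hugging Face config (the names below are illustrative):

# Maps HF architecture strings to "this is an embedding model".
_EMBEDDING_MODEL_ARCHITECTURES = {
    "XLMRobertaModel": True,  # e.g. BAAI/bge-m3
}

def is_embedding_model(architectures: list[str]) -> bool:
    """Return True if any architecture in the HF config is a known embedding model."""
    return any(_EMBEDDING_MODEL_ARCHITECTURES.get(arch, False) for arch in architectures)

The engine could then call is_embedding_model(config.architectures) at load time instead of requiring a user-supplied flag.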

^^^ Once this is all done, we can take another check-in

E) Turn Off KV Cache Memory usage

  • note: we still allocate the KV cache even when it's not being used for embedding models

F) Pass EmbeddingParams around instead of SamplingParams

  • note: this is going to require a lot of work

G) Update parameters for max batching

  • We should be able to support many more tokens in a batch than we currently do for LLMs

@jc9123 (Author) commented Mar 24, 2024

Hi Robert, thank you for the detailed feedback! I will begin working on Milestone 1 and let you know if I have any questions.

@robertgshaw2-neuralmagic (Collaborator)

@jc9123 updated comment above

@jc9123 (Author) commented Mar 24, 2024

@jc9123 updated comment above

Thanks! I will check in with you after finishing steps a-d.

vllm/outputs.py Outdated
@@ -75,6 +75,7 @@ def __init__(
finished: bool,
metrics: Optional[RequestMetrics] = None,
lora_request: Optional[LoRARequest] = None,
embed = None

@ywang96 (Member) commented Mar 25, 2024

IMO it's a better idea to create a separate class for output of embedding models. Many attributes of RequestOutput are tied to language models but not relevant to embedding models. We can discuss this in the later stages of this PR series.

Collaborator:

Agree @ywang96

I think we will want to refactor RequestOutput and SamplingParams to be abstract with separate versions for Completion vs Embedding

It will be tricky to do this without breaking changes to the UX

@robertgshaw2-neuralmagic (Collaborator)

@jc9123 any update?

@jc9123 (Author) commented Apr 5, 2024

@jc9123 any update?

Sorry for the delay, but I will need a bit more time to finish Milestone 1. I should be able to finish it by early next week.

@brandonsorensen

I don't want to step on @jc9123's toes, but I'd be happy to help out on this issue. For example, I could already start work on the front end bits.

@jc9123 (Author) commented Apr 10, 2024

@jc9123 any update?

Hi Robert, sorry for the delay, but I've finished checkpoints 1.a->1.d. Could you take a look? Right now, I've mostly copied EmbeddingParams and EmbeddingOutput from their Sampling counterparts to ensure that everything still works, but I was wondering if there is something specific I should add/change.

Edit: Found some errors in the implementation, will fix by EOD.
Edit 2: Fixed!

@CatherineSue (Contributor)

@jc9123 @robertgshaw2-neuralmagic, it seems this PR has a few milestones that overlap with my PR #3734. I have also separated EmbeddingOutput from CompletionOutput. Should we collaborate so we don't do duplicate work?

@jc9123 (Author) commented Apr 15, 2024

@jc9123 @robertgshaw2-neuralmagic, it seems this PR has a few milestones that overlap with my PR #3734. I have also separated EmbeddingOutput from CompletionOutput. Should we collaborate so we don't do duplicate work?

Hi Catherine, we should definitely collaborate to avoid duplicate work. Between our PRs, I think we should share the same implementation for the vLLM front-end UX and also the EmbeddingOutput and EmbeddingParams classes.

@KylinMountain

Any update? Can we use other embedding models now?

@whybeyoung

LGTM, but any update?

@simon-mo (Collaborator)

BART has been added, and cross-attention as well; I believe there should be some small incremental work in hooking these models up with the embedding API.

If @jc9123 is not actively working on this anymore, please feel free to open a new PR based off this.

mergify bot commented Oct 31, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @jc9123 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Oct 31, 2024

Successfully merging this pull request may close these issues.

[New Model]: BAAI/bge-m3
8 participants