Supporting embedding models #3187
base: main
Changes from 10 commits
@@ -13,5 +13,6 @@ pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
datasets >= 2.0.0

Review comment: This should not be a dependency. If you need it for testing, add it to …

outlines == 0.0.34
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
@@ -0,0 +1,34 @@
from vllm import LLM, SamplingParams
import numpy as np
# Sample prompts.
sentences_1 = ["What is BGE M3?", "Defination of BM25"]
sentences_2 = ["BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
               "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="BAAI/bge-m3", enforce_eager = True, embedded_model = True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs1 = llm.generate(sentences_1, sampling_params)

lst1 = []
for output1 in outputs1:
    generated_text = output1.embed.cpu()
    lst1.append(np.array(generated_text))
lst1 = np.array(lst1)
outputs2 = llm.generate(sentences_2, sampling_params)

lst2 = []
for output2 in outputs2:
    prompt = output2.prompt
    generated_text = output2.embed.cpu()
    lst2.append(np.array(generated_text))
lst2 = np.array(lst2)
result = lst1 @ lst2.T
expected_result = np.array([[0.6265, 0.3477], [0.3499, 0.678 ]])

assert(np.isclose(result, expected_result, atol=1e-2).all())
print("Passed!")
@@ -128,6 +128,7 @@ def create_worker(cls: type,

    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0

Review comment: This file should not be touched.

    worker.init_cache_engine(cache_config)
    worker.warm_up_model()
@@ -370,7 +370,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:

    seq_data: Dict[int, SequenceData] = {}
    block_tables: Dict[int, List[int]] = {}

    for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):

Review comment: this file should not be touched.

        seq_id = seq.seq_id
        seq_data[seq_id] = seq.data
@@ -52,6 +52,7 @@ class EngineArgs:

    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    embedded_model : bool = False

Review comment: let's use the term …

    def __post_init__(self):
        if self.tokenizer is None:
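For context (none of this is in the diff): EngineArgs fields are typically mirrored as CLI flags, so the new field would eventually need matching argparse plumbing. A minimal argparse-only sketch, keeping the flag name from this diff even though the review asks for a rename:

```python
import argparse

# Sketch of the CLI side only; the real wiring would presumably live in
# EngineArgs.add_cli_args rather than a standalone parser like this.
parser = argparse.ArgumentParser()
parser.add_argument("--embedded-model", action="store_true",
                    help="Treat the loaded model as an embedding model.")
args = parser.parse_args(["--embedded-model"])
print(args.embedded_model)  # True
```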
@@ -86,7 +86,6 @@ def __init__(

        f"device_config={device_config.device}, "
        f"seed={model_config.seed})")
    # TODO(woosuk): Print more configs in debug mode.

Review comment: nit - no need to change this

    self.model_config = model_config
    self.cache_config = cache_config
    self.lora_config = lora_config

@@ -375,6 +374,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

        parent_seq.seq_id: []
        for parent_seq in parent_seqs
    }

Review comment: nit - no need to change this

    for sample in samples:
        parent_child_dict[sample.parent_seq_id].append(sample)
    # List of (child, parent)

@@ -536,18 +536,23 @@ def _process_model_outputs(

    # Update the scheduled sequence groups with the model outputs.
    scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups

    # If prefix caching is enabled, mark all blocks in the sequence groups
    # as completed so that future requests don't attempt to recompute them
    if self.embedded_model :

Review comment: All of the logic for making stopping decisions occurs in … This is correct logic (in that each sequence only needs to run once), but please move it into the …

        for i, seq_group in enumerate(scheduled_seq_groups):
            for seq in seq_group.get_seqs():
                seq.status = SequenceStatus.FINISHED_STOPPED
            seq_group.embed = output[i]
    else:
        for seq_group, outputs in zip(scheduled_seq_groups, output):
            self._process_sequence_group_outputs(seq_group, outputs)

    if self.cache_config.enable_prefix_caching:
        for seq_group in scheduled_seq_groups:
            self.scheduler.mark_blocks_as_computed(seq_group)

    for seq_group, outputs in zip(scheduled_seq_groups, output):
        self._process_sequence_group_outputs(seq_group, outputs)

    # Free the finished sequence groups.
    self.scheduler.free_finished_seq_groups()

    request_outputs: List[RequestOutput] = []

Review comment: nit - duplicate

    # Create the outputs.
    request_outputs: List[RequestOutput] = []
    for seq_group in scheduled_seq_groups:

@@ -561,7 +566,7 @@ def _process_model_outputs(

    # Log stats.
    if self.log_stats:
        self.stat_logger.log(self._get_stats(scheduler_outputs))

Review comment: nit - do not change this

    return request_outputs

def step(self) -> List[RequestOutput]:

@@ -624,7 +629,7 @@ def step(self) -> List[RequestOutput]:

            scheduler_outputs.blocks_to_copy)
    else:
        output = []

Review comment: nit

    return self._process_model_outputs(output, scheduler_outputs)

def do_log_stats(self) -> None:
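A rough, self-contained sketch of the direction the reviewer asks for above: the "an embedding sequence is finished after its single pass" rule sits alongside the other stopping checks rather than in `_process_model_outputs`. All names below are simplified stand-ins, not the real vLLM classes:

```python
from enum import Enum, auto

class SequenceStatus(Enum):
    RUNNING = auto()
    FINISHED_STOPPED = auto()

class Sequence:
    def __init__(self) -> None:
        self.status = SequenceStatus.RUNNING

def check_stop(seq: Sequence, is_embedding_model: bool) -> None:
    # Embedding models produce their vector in a single forward pass, so
    # the sequence can be marked finished right here, where the other
    # stopping decisions are made.
    if is_embedding_model:
        seq.status = SequenceStatus.FINISHED_STOPPED
        return
    # ... the usual stop checks (EOS token, max length, stop strings) ...

seq = Sequence()
check_stop(seq, is_embedding_model=True)
assert seq.status is SequenceStatus.FINISHED_STOPPED
```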
@@ -84,6 +84,7 @@ def __init__(

        enforce_eager: bool = False,
        max_context_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        embedded_model: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:

@@ -104,9 +105,11 @@

            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            embedded_model = embedded_model,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(engine_args)
        self.llm_engine.embedded_model = embedded_model

Review comment: We should not directly set the … Instead, update the …

        self.request_counter = Counter()

    def get_tokenizer(
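A minimal sketch of what the reviewer's suggestion might look like, with hypothetical names: the flag lives on the model configuration and the engine consults it there, instead of `LLM` assigning an attribute on the engine after construction:

```python
from dataclasses import dataclass

# Hypothetical, simplified stand-ins; field and class names are assumptions,
# not the real vLLM ModelConfig/LLMEngine API.
@dataclass
class ModelConfigSketch:
    model: str
    embedding_mode: bool = False  # config-level flag instead of an engine attribute

class EngineSketch:
    def __init__(self, model_config: ModelConfigSketch) -> None:
        self.model_config = model_config

    @property
    def is_embedding_run(self) -> bool:
        # The engine branches on configuration, not on an ad-hoc attribute
        # set from the outside after construction.
        return self.model_config.embedding_mode

engine = EngineSketch(ModelConfigSketch(model="BAAI/bge-m3", embedding_mode=True))
assert engine.is_embedding_run
```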
@@ -112,6 +112,18 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

    return JSONResponse(content=generator.model_dump())


@app.post("/v1/embeddings")

Review comment: I assume this does not work yet?
Review comment: Remove the OAI server edits from this PR. We will handle this in a future step.

async def create_embeddings(request: EmbeddingRequest):

    ## need to implement
    generator = await openai_serving_completion.create_completion()
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


if __name__ == "__main__":
    args = parse_args()
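For when the OpenAI-server step is picked up later: an OpenAI-compatible /v1/embeddings endpoint would need to speak the public OpenAI embeddings schema. A rough illustration of the request and response shapes, with placeholder values; nothing here is implemented by this PR:

```python
# Request and response shapes follow the public OpenAI embeddings API;
# the vector and token counts below are placeholders.
request_body = {
    "model": "BAAI/bge-m3",
    "input": ["What is BGE M3?", "Defination of BM25"],
}
response_body = {
    "object": "list",
    "data": [
        {"object": "embedding", "index": 0, "embedding": [0.01, -0.02]},  # truncated vector
        {"object": "embedding", "index": 1, "embedding": [0.03, 0.04]},
    ],
    "model": "BAAI/bge-m3",
    "usage": {"prompt_tokens": 12, "total_tokens": 12},
}
print(len(response_body["data"]))  # one embedding per input string
```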
@@ -1,5 +0,0 @@

from vllm.model_executor.layers.attention.attention import Attention

Review comment: This file should not be touched.

__all__ = [
    "Attention",
]
@@ -41,7 +41,6 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,

             **kwargs) -> nn.Module:
    lora_config = kwargs.get("lora_config", None)
    model_class = _get_model_architecture(model_config)

Review comment: This file should not be touched.

    # Get the (maybe quantized) linear method.
    linear_method = None
    if model_config.quantization is not None:
@@ -46,6 +46,8 @@

    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    # embedded model
    "XLMRobertaModel": ("bgem3", "BGEM3FlagForCausalLM"),

Review comment: This should be updated to: "XLMRobertaModel": ("xlm_roberta", "XLMRobertaModel") once we have migrated the model definition.

}

# Models not supported by ROCm.
Review comment: This file should not be touched.
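For completeness, the registry mapping suggested in the comment on the architecture table above would look roughly like this once the model definition is migrated (surrounding entries trimmed; module name taken from the reviewer's comment, not from merged code):

```python
# Trimmed stand-in for the architecture registry; only the suggested entry is shown.
_MODELS = {
    # ... existing decoder-only entries ...
    "XLMRobertaModel": ("xlm_roberta", "XLMRobertaModel"),
}
print(_MODELS["XLMRobertaModel"])  # ('xlm_roberta', 'XLMRobertaModel')
```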