Supporting embedding models #3187

Open
wants to merge 26 commits into base: main
Changes from 10 commits
Commits (26):
31c2cdd  added OpenAI embedded model API call (jc9123, Feb 9, 2024)
b81344b  added openai embedded api call (jc9123, Feb 9, 2024)
9f74015  added openai embedded api call (jc9123, Feb 9, 2024)
be10d8c  save (jc9123, Feb 29, 2024)
21633a4  added embedded model (jc9123, Mar 4, 2024)
d84df2e  Merge branch 'main' into supporting_embedded_model (jc9123, Mar 4, 2024)
cf308fd  fixed tokenizing and cleaned up code for bgem3 (jc9123, Mar 19, 2024)
bdfd0ba  removed prompt dict (jc9123, Mar 19, 2024)
efaa551  Merge remote-tracking branch 'upstream/main' into add_embedded_model (jc9123, Mar 19, 2024)
f5810d5  fixed merge conflict (jc9123, Mar 19, 2024)
5fb0a61  added xlm roberta scaffold (robertgshaw2-neuralmagic, Mar 24, 2024)
a3cc834  fixed nits (robertgshaw2-neuralmagic, Mar 24, 2024)
6d41610  fixed nit again (robertgshaw2-neuralmagic, Mar 24, 2024)
f64eddb  fixed nits again again (robertgshaw2-neuralmagic, Mar 24, 2024)
8064631  fixed nits again again again (robertgshaw2-neuralmagic, Mar 24, 2024)
7296a4c  newline on Attention (robertgshaw2-neuralmagic, Mar 24, 2024)
a12170c  Merge pull request #1 from neuralmagic/rs/embedding-step-1 (jc9123, Mar 24, 2024)
e9c4d77  finished checkpoint 1.a (jc9123, Apr 10, 2024)
f71981c  cleaned up checkpoint 1.a (jc9123, Apr 10, 2024)
9f37dc0  added embedded_model registry (jc9123, Apr 10, 2024)
b47f3a2  code cleanup & bugfix (jc9123, Apr 12, 2024)
619167b  more cleanup (jc9123, Apr 12, 2024)
b8e2adc  more cleanup (jc9123, Apr 12, 2024)
4c532fd  save (jc9123, Apr 13, 2024)
0f194af  fixed xlm_roberta (jc9123, Apr 13, 2024)
5923b4d  code cleanup (jc9123, Apr 15, 2024)

2 changes: 1 addition & 1 deletion examples/offline_inference.py
@@ -19,4 +19,4 @@
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
-   print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+   print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
Collaborator:
This file should not be touched

1 change: 1 addition & 0 deletions requirements.txt
@@ -13,5 +13,6 @@ pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
datasets >= 2.0.0
Collaborator:
This should not be a dependency. If you need it for testing, add it to requirements-dev

outlines == 0.0.34
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
34 changes: 34 additions & 0 deletions tests/models/test_embedded.py
@@ -0,0 +1,34 @@
from vllm import LLM, SamplingParams
import numpy as np

# Sample prompts.
sentences_1 = ["What is BGE M3?", "Defination of BM25"]
sentences_2 = [
    "BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
    "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"
]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="BAAI/bge-m3", enforce_eager=True, embedded_model=True)

# Generate embeddings from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, the embedding, and other information.
outputs1 = llm.generate(sentences_1, sampling_params)

lst1 = []
for output1 in outputs1:
    generated_text = output1.embed.cpu()
    lst1.append(np.array(generated_text))
lst1 = np.array(lst1)

outputs2 = llm.generate(sentences_2, sampling_params)

lst2 = []
for output2 in outputs2:
    prompt = output2.prompt
    generated_text = output2.embed.cpu()
    lst2.append(np.array(generated_text))
lst2 = np.array(lst2)

result = lst1 @ lst2.T
expected_result = np.array([[0.6265, 0.3477], [0.3499, 0.678]])

assert np.isclose(result, expected_result, atol=1e-2).all()
print("Passed!")
1 change: 1 addition & 0 deletions tests/spec_decode/utils.py
@@ -128,6 +128,7 @@ def create_worker(cls: type,

cache_config.num_gpu_blocks = num_gpu_blocks
cache_config.num_cpu_blocks = 0

Collaborator:
This file should not be touched

worker.init_cache_engine(cache_config)
worker.warm_up_model()

1 change: 0 additions & 1 deletion vllm/core/scheduler.py
@@ -370,7 +370,6 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:

seq_data: Dict[int, SequenceData] = {}
block_tables: Dict[int, List[int]] = {}

for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
Collaborator:
this file should not be touched

seq_id = seq.seq_id
seq_data[seq_id] = seq.data
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -52,6 +52,7 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
embedded_model : bool = False
Collaborator:
let's use the term embedding_model to match the terminology used by OAI

https://platform.openai.com/docs/api-reference/embeddings


def __post_init__(self):
if self.tokenizer is None:
23 changes: 14 additions & 9 deletions vllm/engine/llm_engine.py
@@ -86,7 +86,6 @@
f"device_config={device_config.device}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.

Collaborator:
nit - no need to change this

self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
@@ -375,6 +374,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
parent_seq.seq_id: []
for parent_seq in parent_seqs
}

Collaborator:
nit - no need to change this

for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
@@ -536,18 +536,23 @@ def _process_model_outputs(
        # Update the scheduled sequence groups with the model outputs.
        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups

        # If prefix caching is enabled, mark all blocks in the sequence groups
        # as completed so that future requests don't attempt to recompute them
        if self.embedded_model:
Collaborator:
All of the logic for making stopping decisions occurs in llm_engine.check_stop.

This is correct logic (in that each sequence only needs to run once), but please move it into the check_stop function (see the sketch after this file's diff).

            for i, seq_group in enumerate(scheduled_seq_groups):
                for seq in seq_group.get_seqs():
                    seq.status = SequenceStatus.FINISHED_STOPPED
                seq_group.embed = output[i]
        else:
            for seq_group, outputs in zip(scheduled_seq_groups, output):
                self._process_sequence_group_outputs(seq_group, outputs)

        if self.cache_config.enable_prefix_caching:
            for seq_group in scheduled_seq_groups:
                self.scheduler.mark_blocks_as_computed(seq_group)

        for seq_group, outputs in zip(scheduled_seq_groups, output):
            self._process_sequence_group_outputs(seq_group, outputs)


        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        request_outputs: List[RequestOutput] = []
Collaborator:
nit - duplicate

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for seq_group in scheduled_seq_groups:
@@ -561,7 +566,7 @@
        # Log stats.
        if self.log_stats:
            self.stat_logger.log(self._get_stats(scheduler_outputs))

Collaborator:
nit - do not change this

return request_outputs

def step(self) -> List[RequestOutput]:
@@ -624,7 +629,7 @@ def step(self) -> List[RequestOutput]:
scheduler_outputs.blocks_to_copy)
else:
output = []

Collaborator:
nit

return self._process_model_outputs(output, scheduler_outputs)

def do_log_stats(self) -> None:
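A minimal sketch of what the reviewer's check_stop suggestion could look like. The method and attribute names (_check_stop, embedding_model) are assumptions for illustration, not code from this PR:

from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus

class LLMEngine:  # illustrative subset only
    embedding_model: bool = False

    def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None:
        # An embedding model produces its output in a single forward pass, so
        # the sequence is finished as soon as it has been scheduled once.
        if self.embedding_model:
            seq.status = SequenceStatus.FINISHED_STOPPED
            return
        # ... the existing stop checks (stop strings, EOS token, max_tokens)
        #     would follow here unchanged ...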
3 changes: 3 additions & 0 deletions vllm/entrypoints/llm.py
@@ -84,6 +84,7 @@
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
embedded_model: bool = False,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
@@ -104,9 +105,11 @@
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
embedded_model = embedded_model,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
self.llm_engine.embedded_model = embedded_model
Collaborator:
We should not directly set the embedding_model member here.

Instead, update the EngineArgs to have the embedding_model member. Then, LLMEngine.from_engine_args should set the member (see the sketch after this file's diff).

self.request_counter = Counter()

def get_tokenizer(
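A minimal sketch of the wiring the reviewer asks for, assuming the flag is renamed to embedding_model and lives on EngineArgs; the classes below are simplified stand-ins, not the PR's code:

from dataclasses import dataclass

@dataclass
class EngineArgs:  # illustrative subset of the real EngineArgs
    model: str
    embedding_model: bool = False  # renamed from embedded_model per the review

class LLMEngine:  # illustrative subset of the real LLMEngine
    def __init__(self, model: str, embedding_model: bool) -> None:
        self.embedding_model = embedding_model

    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        # The engine reads the flag from the args itself, so callers never
        # mutate llm_engine.embedding_model after construction.
        return cls(model=engine_args.model,
                   embedding_model=engine_args.embedding_model)

# LLM.__init__ would then only build EngineArgs and hand them over:
engine = LLMEngine.from_engine_args(
    EngineArgs(model="BAAI/bge-m3", embedding_model=True))
assert engine.embedding_model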
12 changes: 12 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -112,6 +112,18 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())


@app.post("/v1/embeddings")
Collaborator:
I assume this does not work yet?

Collaborator:
Remove the OAI server edits from this PR.

We will handle this in a future step

async def create_embeddings(request: EmbeddingRequest):

    ## need to implement
    generator = await openai_serving_completion.create_completion()
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    else:
        return JSONResponse(content=generator.model_dump())


if __name__ == "__main__":
args = parse_args()

5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -19,6 +19,11 @@ class ErrorResponse(BaseModel):
    code: int


class EmbeddingRequest(BaseModel):
    input: str
    model: str


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
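The PR only defines the request schema; below is a rough sketch of the response schema an OpenAI-compatible /v1/embeddings endpoint would eventually need. The class and field names are assumptions modelled on the OpenAI embeddings API, not part of this diff:

from typing import List

from pydantic import BaseModel

class EmbeddingData(BaseModel):
    # One embedding vector per input item, following the OpenAI response shape.
    object: str = "embedding"
    embedding: List[float]
    index: int

class EmbeddingUsage(BaseModel):
    prompt_tokens: int
    total_tokens: int

class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: EmbeddingUsage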
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/serving_completion.py
@@ -377,3 +377,7 @@ def request_output_to_completion_response(
            choices=choices,
            usage=usage,
        )


    async def create_embeddings(self, input: str, model: str):
        raise NotImplementedError
5 changes: 0 additions & 5 deletions vllm/model_executor/layers/attention/__init__.py
@@ -1,5 +0,0 @@
from vllm.model_executor.layers.attention.attention import Attention
Collaborator:
This file should not be touched


__all__ = [
    "Attention",
]
1 change: 0 additions & 1 deletion vllm/model_executor/model_loader.py
@@ -41,7 +41,6 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None)
model_class = _get_model_architecture(model_config)

Collaborator:
This file should not be touched

# Get the (maybe quantized) linear method.
linear_method = None
if model_config.quantization is not None:
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -46,6 +46,8 @@
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
# embedded model
"XLMRobertaModel": ("bgem3", "BGEM3FlagForCausalLM"),
Collaborator:
This should be updated to:

"XLMRobertaModel": ("xlm_roberta", "XLMRobertaModel")

Once we have migrated the model definition

}

# Models not supported by ROCm.