Add support for LoRA adapters in vLLM inference engine #562
Conversation
OPE-395: Add support for SFT/LoRA models in the vLLM inference engine.
Like this: oumi/scripts/polaris/jobs/vllm_worker.sh (line 90 at fd3ebac), but in Python: https://github.com/oumi-ai/oumi/blob/main/src/oumi/inference/vllm_inference_engine.py
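For reference, a minimal sketch of what LoRA loading looks like with the vLLM Python API; the model name, adapter name, ID, and path here are placeholders, not values from this PR:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Enable LoRA support when constructing the engine.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

# Positional args: adapter name, unique integer ID, local adapter path.
lora_request = LoRARequest("my_adapter", 1, "/path/to/lora_adapter")

# Pass the adapter per-request at generation time.
outputs = llm.generate(
    ["Hello, how are you?"],
    SamplingParams(max_tokens=64),
    lora_request=lora_request,
)
```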
src/oumi/core/types/turn.py
Outdated
@@ -106,6 +106,11 @@ def is_text(self) -> bool:
         """Checks if the message contains text."""
         return self.type == Type.TEXT

+    def __repr__(self):
+        """Returns a string representation of the message."""
+        content = self.content if self.is_text() else "<non-text-content>"
Define a constant for "<non-text-content>"?
Replaced with the type of the message instead.
src/oumi/core/types/turn.py
Outdated
    def __repr__(self):
        """Returns a string representation of the message."""
        content = self.content if self.is_text() else "<non-text-content>"
        return f"{self.role.upper()}: {content}"
Let's also include `type`, `id`? Any other small fields?
The comment above modified the message to mention the type if it's not text. Added ID.
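A sketch of what the revised `__repr__` could look like after these two comments; the exact field names (`id` in particular) are guessed from the diff and replies, not taken from the merged code:

```python
def __repr__(self):
    """Returns a string representation of the message."""
    # Prepend the optional message ID if present.
    id_str = f"{self.id} - " if self.id else ""
    # For non-text messages, show the type instead of the raw content.
    content = self.content if self.is_text() else str(self.type)
    return f"{id_str}{self.role.upper()}: {content}"
```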
src/oumi/core/types/turn.py
Outdated
    def __repr__(self):
        """Returns a string representation of the message."""
        content = self.content if self.is_text() else "<non-text-content>"
        return f"{self.role.upper()}: {content}"
Why `.upper()` for role?
Makes it more readable IMO:

ASSISTANT: How are you?
USER: Good

vs.

assistant: How are you?
user: Good
src/oumi/core/types/turn.py
Outdated
    def __repr__(self):
        """Returns a string representation of the message."""
        content = self.content if self.is_text() else "<non-text-content>"
        return f"{self.role.upper()}: {content}"
Instead of building the string manually, should we use a library to do the formatting for us? For example: create a temp dict with the fields of interest, then use `json`, `pprint.pformat`, or some such to convert it to a string?
IMO the logic here is light enough that manually creating the string works fine.
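For comparison, the library-based alternative the reviewer suggested might look like the following sketch (not the approach taken; field names are assumed from the surrounding diff):

```python
import json

def __repr__(self):
    """Returns a JSON string representation of the message."""
    # Collect the fields of interest into a temp dict, then
    # let json handle quoting and escaping.
    fields = {
        "id": self.id,
        "role": str(self.role),
        "type": str(self.type),
        "content": self.content if self.is_text() else None,
    }
    return json.dumps(fields)
```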
@@ -53,6 +61,8 @@ def __init__(
             quantization=quantization,
             tensor_parallel_size=tensor_parallel_size,
             enable_prefix_caching=enable_prefix_caching,
+            enable_lora=self.lora_request is not None,
+            max_model_len=model_params.model_max_length,
Should we sanitize this value before passing it to vLLM? Something like:

max_model_len=(model_params.model_max_length if "... is not None and ... > 0" else None)
IMO all validation should happen during config initialization, and downstream code like vLLM should be able to consume the configs as-is without additional validation. Added a validation check that model_max_length is a positive int if specified.
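A minimal sketch of that config-time check; the actual `ModelParams` definition in oumi may differ, so treat the class below as hypothetical:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelParams:
    model_max_length: Optional[int] = None

    def __post_init__(self):
        # Validate at config-initialization time so downstream consumers
        # (e.g. vLLM) can use the value as-is, without re-checking it.
        if self.model_max_length is not None and self.model_max_length <= 0:
            raise ValueError(
                "model_max_length must be a positive integer if specified, "
                f"got {self.model_max_length}."
            )
```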
Please add a test in tests/inference/test_vllm_inference_engine.py covering this case.
Done, and tested on GCP. Also added a test for turn.py.
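A sketch of the kind of test this could be, written against the hypothetical `ModelParams` above rather than oumi's actual test helpers:

```python
import pytest

def test_model_max_length_must_be_positive():
    # Valid: unset or positive values pass config initialization.
    assert ModelParams().model_max_length is None
    assert ModelParams(model_max_length=2048).model_max_length == 2048

    # Invalid: zero or negative values should fail at config init.
    with pytest.raises(ValueError):
        ModelParams(model_max_length=0)
    with pytest.raises(ValueError):
        ModelParams(model_max_length=-1)
```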
Fixes OPE-395.
Message and Conversation