Commit 916b990

Address comments
1 parent 5d77678 commit 916b990

3 files changed, +16 -9 lines changed

src/oumi/core/configs/params/model_params.py

+3

@@ -196,3 +196,6 @@ def __validate__(self):
                 "Sharded-model evaluations with LM Harness should be invoked with "
                 "`python`, not `accelerate launch`."
             )
+
+        if self.model_max_length is not None and self.model_max_length <= 0:
+            raise ValueError("model_max_length must be a positive integer or None.")
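For context, the new check rejects non-positive values while still allowing None. A minimal sketch of the behavior (the import path follows the file location above; the model_name value and the direct __validate__() call are illustrative assumptions, not taken from this commit):

from oumi.core.configs.params.model_params import ModelParams

# Hypothetical values; validation may also be triggered when the config is finalized.
params = ModelParams(model_name="gpt2", model_max_length=0)
params.__validate__()
# ValueError: model_max_length must be a positive integer or None.

params = ModelParams(model_name="gpt2", model_max_length=None)
params.__validate__()  # passes: None means "use the model's default"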

src/oumi/core/types/turn.py

+7 -4

@@ -106,10 +106,13 @@ def is_text(self) -> bool:
         """Checks if the message contains text."""
         return self.type == Type.TEXT
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """Returns a string representation of the message."""
-        content = self.content if self.is_text() else "<non-text-content>"
-        return f"{self.role.upper()}: {content}"
+        id_str = ""
+        if self.id:
+            id_str = f"{self.id} - "
+        content = (self.content or "") if self.is_text() else f"<{self.type.upper()}>"
+        return f"{id_str}{self.role.upper()}: {content}"
 
 
 class Conversation(pydantic.BaseModel):

@@ -187,7 +190,7 @@ def filter_messages(self, role: Optional[Role] = None) -> List[Message]:
         messages = self.messages
         return messages
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """Returns a string representation of the conversation."""
         return "\n".join([repr(m) for m in self.messages])
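A rough illustration of what the updated __repr__ methods produce (the keyword-argument construction of Message and Conversation is assumed from the pydantic models in this file; the field values are placeholders):

from oumi.core.types.turn import Conversation, Message, Role

convo = Conversation(
    messages=[
        Message(id="msg-1", role=Role.USER, content="Hello!"),        # has an id
        Message(role=Role.ASSISTANT, content="Hi, how can I help?"),  # no id
    ]
)
print(repr(convo))
# msg-1 - USER: Hello!
# ASSISTANT: Hi, how can I help?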

src/oumi/inference/vllm_inference_engine.py

+6 -5

@@ -41,10 +41,11 @@ def __init__(
                 "vLLM is not installed. "
                 "Please install the GPU dependencies for this package."
             )
-        self.lora_request = None
+        self._lora_request = None
         if model_params.adapter_model:
-            self.lora_request = vllm.lora.request.LoRARequest(
-                lora_name="my_lora_adapter",
+            # ID should be unique for this adapter, but isn't enforced by vLLM.
+            self._lora_request = vllm.lora.request.LoRARequest(
+                lora_name="oumi_lora_adapter",
                 lora_int_id=1,
                 lora_path=model_params.adapter_model,
             )

@@ -61,7 +62,7 @@ def __init__(
             quantization=quantization,
             tensor_parallel_size=tensor_parallel_size,
             enable_prefix_caching=enable_prefix_caching,
-            enable_lora=self.lora_request is not None,
+            enable_lora=self._lora_request is not None,
             max_model_len=model_params.model_max_length,
         )
         # Ensure the tokenizer is set properly

@@ -111,7 +112,7 @@ def _infer(
         chat_response = self._llm.chat(
             vllm_input,
             sampling_params=sampling_params,
-            lora_request=self.lora_request,
+            lora_request=self._lora_request,
         )
         new_messages = [
             Message(content=message.outputs[0].text, role=Role.ASSISTANT)
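For reference, the same LoRA wiring looks roughly like this in standalone vLLM; this is a sketch, and the base model name, adapter path, and prompt are placeholders rather than values from this commit:

import vllm
from vllm.lora.request import LoRARequest

# enable_lora must be set on the engine, mirroring
# enable_lora=self._lora_request is not None above.
llm = vllm.LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)  # placeholder base model
lora_request = LoRARequest(
    lora_name="oumi_lora_adapter",
    lora_int_id=1,  # should be unique per adapter; vLLM does not enforce this
    lora_path="/path/to/adapter",  # placeholder adapter path
)
outputs = llm.generate(["Hello!"], lora_request=lora_request)
print(outputs[0].outputs[0].text)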
