Skip to content

Commit 341d687

Browse files
committed
Removed multithreading
1 parent 64cac43 commit 341d687

5 files changed

+19
-35
lines changed

src/oumi/core/inference/base_inference_engine.py

+14-29
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import queue
2-
import threading
31
from abc import ABC, abstractmethod
42
from pathlib import Path
53
from typing import List, Optional
@@ -14,29 +12,6 @@
1412
class BaseInferenceEngine(ABC):
1513
"""Base class for running model inference."""
1614

17-
def __init__(self):
18-
"""Initializes the BaseInferenceEngine.
19-
20-
Sets up a queue and a background thread for writing conversations to files.
21-
"""
22-
self._write_queue = queue.Queue()
23-
24-
def _write_conversation_thread():
25-
while True:
26-
conversation, output_filepath = self._write_queue.get()
27-
# Make the directory if it doesn't exist.
28-
Path(output_filepath).parent.mkdir(parents=True, exist_ok=True)
29-
with jsonlines.open(output_filepath, mode="a") as writer:
30-
json_obj = conversation.model_dump()
31-
writer.write(json_obj)
32-
self._write_queue.task_done()
33-
34-
threading.Thread(target=_write_conversation_thread, daemon=True).start()
35-
36-
def __del__(self):
37-
"""Closes the write queue before being deleted."""
38-
self._write_queue.join()
39-
4015
def infer(
4116
self,
4217
input: Optional[List[Conversation]] = None,
@@ -102,11 +77,21 @@ def _save_conversation(
10277
conversation: A single conversation to save.
10378
output_filepath: The filepath to where the conversation should be saved.
10479
"""
105-
self._write_queue.put((conversation, output_filepath))
80+
Path(output_filepath).parent.mkdir(parents=True, exist_ok=True)
81+
with jsonlines.open(output_filepath, mode="a") as writer:
82+
json_obj = conversation.model_dump()
83+
writer.write(json_obj)
10684

107-
def _finish_writing(self):
108-
"""Blocks until all conversations are written to file."""
109-
self._write_queue.join()
85+
async def _save_conversation_async(
86+
self, conversation: Conversation, output_filepath: str
87+
) -> None:
88+
"""Asynchronously saves single conversation to a file in Oumi chat format.
89+
90+
Args:
91+
conversation: A single conversation to save.
92+
output_filepath: The filepath to where the conversation should be saved.
93+
"""
94+
return self._save_conversation(conversation, output_filepath)
11095

11196
@abstractmethod
11297
def infer_online(

src/oumi/inference/llama_cpp_inference_engine.py

-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def _infer(
187187
)
188188
output_conversations.append(new_conversation)
189189

190-
self._finish_writing()
191190
return output_conversations
192191

193192
def infer_online(

src/oumi/inference/native_text_inference_engine.py

-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def _infer(
125125
)
126126
output_conversations.append(new_conversation)
127127

128-
self._finish_writing()
129128
return output_conversations
130129

131130
def infer_online(

src/oumi/inference/remote_inference_engine.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ async def _query_api(
185185
response_json, conversation
186186
)
187187
if generation_config.output_filepath:
188-
self._save_conversation(
188+
await self._save_conversation_async(
189189
result,
190190
generation_config.output_filepath,
191191
)
@@ -217,6 +217,7 @@ async def _infer(
217217
"""
218218
# Limit number of HTTP connections to the number of workers.
219219
connector = aiohttp.TCPConnector(limit=remote_params.num_workers)
220+
self._save_tasks = []
220221
# Control the number of concurrent tasks via a semaphore.
221222
semaphore = asyncio.BoundedSemaphore(remote_params.num_workers)
222223
async with aiohttp.ClientSession(connector=connector) as session:
@@ -232,8 +233,8 @@ async def _infer(
232233
for conversation in input
233234
]
234235
)
235-
self._finish_writing()
236-
return conversations
236+
237+
return conversations
237238

238239
def infer_online(
239240
self,

src/oumi/inference/vllm_inference_engine.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _infer(
119119
new_conversation, generation_config.output_filepath
120120
)
121121
output_conversations.append(new_conversation)
122-
self._finish_writing()
122+
123123
return output_conversations
124124

125125
def infer_online(

0 commit comments

Comments
 (0)