|
1 |
| -import queue |
2 |
| -import threading |
3 | 1 | from abc import ABC, abstractmethod
|
4 | 2 | from pathlib import Path
|
5 | 3 | from typing import List, Optional
|
|
14 | 12 | class BaseInferenceEngine(ABC):
|
15 | 13 | """Base class for running model inference."""
|
16 | 14 |
|
17 |
| - def __init__(self): |
18 |
| - """Initializes the BaseInferenceEngine. |
19 |
| -
|
20 |
| - Sets up a queue and a background thread for writing conversations to files. |
21 |
| - """ |
22 |
| - self._write_queue = queue.Queue() |
23 |
| - |
24 |
| - def _write_conversation_thread(): |
25 |
| - while True: |
26 |
| - conversation, output_filepath = self._write_queue.get() |
27 |
| - # Make the directory if it doesn't exist. |
28 |
| - Path(output_filepath).parent.mkdir(parents=True, exist_ok=True) |
29 |
| - with jsonlines.open(output_filepath, mode="a") as writer: |
30 |
| - json_obj = conversation.model_dump() |
31 |
| - writer.write(json_obj) |
32 |
| - self._write_queue.task_done() |
33 |
| - |
34 |
| - threading.Thread(target=_write_conversation_thread, daemon=True).start() |
35 |
| - |
36 |
| - def __del__(self): |
37 |
| - """Closes the write queue before being deleted.""" |
38 |
| - self._write_queue.join() |
39 |
| - |
40 | 15 | def infer(
|
41 | 16 | self,
|
42 | 17 | input: Optional[List[Conversation]] = None,
|
@@ -102,11 +77,21 @@ def _save_conversation(
|
102 | 77 | conversation: A single conversation to save.
|
103 | 78 | output_filepath: The filepath to where the conversation should be saved.
|
104 | 79 | """
|
105 |
| - self._write_queue.put((conversation, output_filepath)) |
| 80 | + Path(output_filepath).parent.mkdir(parents=True, exist_ok=True) |
| 81 | + with jsonlines.open(output_filepath, mode="a") as writer: |
| 82 | + json_obj = conversation.model_dump() |
| 83 | + writer.write(json_obj) |
106 | 84 |
|
107 |
| - def _finish_writing(self): |
108 |
| - """Blocks until all conversations are written to file.""" |
109 |
| - self._write_queue.join() |
| 85 | + async def _save_conversation_async( |
| 86 | + self, conversation: Conversation, output_filepath: str |
| 87 | + ) -> None: |
| 88 | + """Asynchronously saves single conversation to a file in Oumi chat format. |
| 89 | +
|
| 90 | + Args: |
| 91 | + conversation: A single conversation to save. |
| 92 | + output_filepath: The filepath to where the conversation should be saved. |
| 93 | + """ |
| 94 | + return self._save_conversation(conversation, output_filepath) |
110 | 95 |
|
111 | 96 | @abstractmethod
|
112 | 97 | def infer_online(
|
|
0 commit comments