Skip to content

Commit 31b1dd5

Browse files
committed
Make the remote inference engine runnable in jupyter notebooks.
1 parent e692b0d commit 31b1dd5

File tree

3 files changed

+60
-2
lines changed

3 files changed

+60
-2
lines changed

src/oumi/core/async_utils.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import asyncio
2+
from multiprocessing.pool import ThreadPool
3+
from typing import Any, Awaitable
4+
5+
6+
def safe_asyncio_run(main: Awaitable[Any]) -> Any:
    """Run an awaitable to completion in a worker thread and return its result.

    `asyncio.run` cannot be called from a thread whose event loop is already
    running (Jupyter notebooks, for example). Handing the awaitable to
    `asyncio.run` inside a dedicated single-thread pool sidesteps that
    restriction. The call blocks until the worker thread has finished.

    Prefer using `safe_asyncio_run` over `asyncio.run` to allow upstream callers to
    ignore our dependency on asyncio.

    Args:
        main: The awaitable to resolve. It is passed unchanged to `asyncio.run`.

    Returns:
        The result of the awaitable.

    Raises:
        Exception: Whatever `asyncio.run(main)` raises in the worker thread is
            re-raised in the calling thread by `ThreadPool.apply`.
    """
    # Use the pool as a context manager so its worker and handler threads are
    # torn down deterministically on exit instead of being left to the garbage
    # collector, which would otherwise leak threads on every call.
    with ThreadPool(processes=1) as pool:
        return pool.apply(asyncio.run, (main,))

src/oumi/inference/remote_inference_engine.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import aiohttp
77

8+
from oumi.core.async_utils import safe_asyncio_run
89
from oumi.core.configs import GenerationConfig, ModelParams, RemoteParams
910
from oumi.core.inference import BaseInferenceEngine
1011
from oumi.core.types.turn import Conversation, Message, Role, Type
@@ -244,7 +245,7 @@ def infer_online(
244245
"""
245246
if not generation_config.remote_params:
246247
raise ValueError("Remote params must be provided in generation_config.")
247-
conversations = asyncio.run(
248+
conversations = safe_asyncio_run(
248249
self._infer(input, generation_config, generation_config.remote_params)
249250
)
250251
if generation_config.output_filepath:
@@ -271,7 +272,7 @@ def infer_from_file(
271272
if not generation_config.remote_params:
272273
raise ValueError("Remote params must be provided in generation_config.")
273274
input = self._read_conversations(input_filepath)
274-
conversations = asyncio.run(
275+
conversations = safe_asyncio_run(
275276
self._infer(input, generation_config, generation_config.remote_params)
276277
)
277278
if generation_config.output_filepath:

tests/core/test_async_utils.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import asyncio
2+
import re
3+
4+
import pytest
5+
6+
from oumi.core.async_utils import safe_asyncio_run
7+
8+
9+
def test_safe_asyncio_run_nested():
    """Check that `safe_asyncio_run` works where a nested `asyncio.run` fails.

    Calling `asyncio.run` from code that is already executing inside an event
    loop must raise a RuntimeError, while the same nesting done through
    `safe_asyncio_run` must return the coroutine's result.
    """

    async def return_one():
        return 1

    def call_via_asyncio_run():
        return asyncio.run(return_one())

    def call_via_safe_asyncio_run():
        return safe_asyncio_run(return_one())

    async def unsafe_entry_point():
        return call_via_asyncio_run()

    async def safe_entry_point():
        return call_via_safe_asyncio_run()

    expected_message = re.escape(
        "asyncio.run() cannot be called from a running event loop"
    )
    with pytest.raises(RuntimeError, match=expected_message):
        # The inner asyncio.run() is invoked while the outer loop is running,
        # which is exactly the situation that is expected to fail.
        asyncio.run(unsafe_entry_point())

    outcome = safe_asyncio_run(safe_entry_point())
    assert outcome == 1

0 commit comments

Comments
 (0)