
Commit ff3f695

Merge pull request #3 from neuralmagic/request-gen-fixes

Refactor RequestGenerator to use threading and update test suite

2 parents: 8afa579 + 407079a

File tree: 6 files changed (+152 −53 lines)
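The substance of the change, before the per-file diffs: the old implementation filled its request queue from an `asyncio` task, so async mode only worked inside a running event loop and needed an explicit `start()` call. The refactor swaps in `queue.Queue`, `threading.Event`, and a daemon producer thread started from the constructor. Below is a minimal standalone sketch of that producer-thread pattern (illustrative names, not guidellm code), assuming the same bounded queue and stop-event shutdown the diff uses:

```python
import threading
import time
from queue import Empty, Full, Queue

SLEEP_INTERVAL = 0.1  # stands in for guidellm's STANDARD_SLEEP_INTERVAL


class ItemProducer:
    """Fills a bounded queue from a daemon thread until stop() is called."""

    def __init__(self, queue_size: int = 10):
        self._queue = Queue(maxsize=queue_size)
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._populate_queue, daemon=True)
        self._thread.start()

    def _populate_queue(self):
        count = 0
        while not self._stop_event.is_set():
            try:
                # put() with a timeout keeps shutdown responsive: if the
                # queue stays full, loop back and re-check the stop event.
                self._queue.put(f"item-{count}", timeout=SLEEP_INTERVAL)
                count += 1
            except Full:
                continue

    def __iter__(self):
        while not self._stop_event.is_set():
            try:
                yield self._queue.get_nowait()
            except Empty:
                # The producer may simply be behind; back off briefly.
                time.sleep(SLEEP_INTERVAL)

    def stop(self):
        self._stop_event.set()
        self._thread.join()


producer = ItemProducer(queue_size=4)
first_five = [item for _, item in zip(range(5), producer)]
producer.stop()
print(first_five)  # ['item-0', 'item-1', 'item-2', 'item-3', 'item-4']
```

The `put(..., timeout=...)` plus `except Full: continue` pairing is what lets `stop()` join the producer thread without deadlocking against a full queue.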

src/guidellm/request/base.py

Lines changed: 31 additions & 35 deletions
```diff
@@ -1,11 +1,14 @@
-import asyncio
+import threading
+import time
 from abc import ABC, abstractmethod
+from queue import Empty, Full, Queue
 from typing import Iterator, Optional, Union

 from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizer

 from guidellm.core.request import TextGenerationRequest
+from guidellm.utils import STANDARD_SLEEP_INTERVAL

 __all__ = ["RequestGenerator"]

@@ -31,20 +34,28 @@ def __init__(
     ):
         self._async_queue_size = async_queue_size
         self._mode = mode
-        self._queue = asyncio.Queue(maxsize=async_queue_size)
-        self._stop_event = asyncio.Event()
+        self._queue = Queue(maxsize=async_queue_size)
+        self._stop_event = threading.Event()

         if tokenizer is not None:
             self._tokenizer = (
                 AutoTokenizer.from_pretrained(tokenizer)
                 if isinstance(tokenizer, str)
                 else tokenizer
             )
-            logger.info(f"Tokenizer initialized: {self._tokenizer}")
+            logger.info("Tokenizer initialized: {}", self._tokenizer)
         else:
             self._tokenizer = None
             logger.debug("No tokenizer provided")

+        if self._mode == "async":
+            self._thread = threading.Thread(target=self._populate_queue, daemon=True)
+            self._thread.start()
+            logger.info(
+                "RequestGenerator started in async mode with queue size: {}",
+                self._async_queue_size,
+            )
+
     def __repr__(self) -> str:
         """
         Return a string representation of the RequestGenerator.
@@ -72,7 +83,7 @@ def __iter__(self) -> Iterator[TextGenerationRequest]:
                     item = self._queue.get_nowait()
                     self._queue.task_done()
                     yield item
-                except asyncio.QueueEmpty:
+                except Empty:
                     continue
         else:
             while not self._stop_event.is_set():
@@ -118,46 +129,31 @@ def create_item(self) -> TextGenerationRequest:
         """
         raise NotImplementedError()

-    def start(self):
-        """
-        Start the background task that populates the queue.
-        """
-        if self.mode == "async":
-            try:
-                loop = asyncio.get_running_loop()
-                logger.info("Using existing event loop")
-            except RuntimeError:
-                raise RuntimeError("No running event loop found for async mode")
-
-            loop.call_soon_threadsafe(
-                lambda: asyncio.create_task(self._populate_queue())
-            )
-            logger.info(
-                f"RequestGenerator started in async mode with queue size: "
-                f"{self._async_queue_size}"
-            )
-        else:
-            logger.info("RequestGenerator started in sync mode")
-
     def stop(self):
         """
         Stop the background task that populates the queue.
         """
         logger.info("Stopping RequestGenerator...")
         self._stop_event.set()
+        if self._mode == "async":
+            self._thread.join()
         logger.info("RequestGenerator stopped")

-    async def _populate_queue(self):
+    def _populate_queue(self):
        """
        Populate the request queue in the background.
        """
        while not self._stop_event.is_set():
-            if self._queue.qsize() < self._async_queue_size:
-                item = self.create_item()
-                await self._queue.put(item)
-                logger.debug(
-                    f"Item added to queue. Current queue size: {self._queue.qsize()}"
-                )
-            else:
-                await asyncio.sleep(0.1)
+            try:
+                if self._queue.qsize() < self._async_queue_size:
+                    item = self.create_item()
+                    self._queue.put(item, timeout=STANDARD_SLEEP_INTERVAL)
+                    logger.debug(
+                        "Item added to queue. Current queue size: {}",
+                        self._queue.qsize(),
+                    )
+                else:
+                    time.sleep(STANDARD_SLEEP_INTERVAL)
+            except Full:
+                continue
         logger.info("RequestGenerator stopped populating queue")
```

src/guidellm/utils/__init__.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -1,3 +1,7 @@
-from .constants import PREFERRED_DATA_COLUMNS, PREFERRED_DATA_SPLITS
+from .constants import (
+    PREFERRED_DATA_COLUMNS,
+    PREFERRED_DATA_SPLITS,
+    STANDARD_SLEEP_INTERVAL,
+)

-__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS"]
+__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS", "STANDARD_SLEEP_INTERVAL"]
```

src/guidellm/utils/constants.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS"]
+__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS", "STANDARD_SLEEP_INTERVAL"]


 PREFERRED_DATA_COLUMNS = [
@@ -15,3 +15,5 @@
 ]

 PREFERRED_DATA_SPLITS = ["test", "validation", "train"]
+
+STANDARD_SLEEP_INTERVAL = 0.1
```
(new test file; path not shown in this capture)

Lines changed: 24 additions & 0 deletions

```diff
@@ -0,0 +1,24 @@
+import pytest
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+
+from guidellm.core.request import TextGenerationRequest
+from guidellm.request.base import RequestGenerator
+
+
+class TestRequestGenerator(RequestGenerator):
+    def create_item(self) -> TextGenerationRequest:
+        return TextGenerationRequest(prompt="Test prompt")
+
+
+@pytest.mark.smoke
+def test_request_generator_with_hf_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    generator = TestRequestGenerator(tokenizer=tokenizer)
+    assert generator.tokenizer == tokenizer
+
+
+@pytest.mark.smoke
+def test_request_generator_with_string_tokenizer():
+    generator = TestRequestGenerator(tokenizer="bert-base-uncased")
+    assert isinstance(generator.tokenizer, PreTrainedTokenizerBase)
+    assert generator.tokenizer.name_or_path == "bert-base-uncased"
```

tests/unit/request/test_base.py

Lines changed: 88 additions & 15 deletions
```diff
@@ -1,3 +1,6 @@
+import time
+from unittest.mock import Mock, patch
+
 import pytest

 from guidellm.core.request import TextGenerationRequest
@@ -10,15 +13,28 @@ def create_item(self) -> TextGenerationRequest:


 @pytest.mark.smoke
-def test_request_generator_sync():
+def test_request_generator_sync_constructor():
     generator = TestRequestGenerator(mode="sync")
     assert generator.mode == "sync"
+    assert generator.async_queue_size == 50  # Default value
     assert generator.tokenizer is None

+
+@pytest.mark.smoke
+def test_request_generator_async_constructor():
+    generator = TestRequestGenerator(mode="async", async_queue_size=10)
+    assert generator.mode == "async"
+    assert generator.async_queue_size == 10
+    assert generator.tokenizer is None
+    generator.stop()
+
+
+@pytest.mark.smoke
+def test_request_generator_sync_iter():
+    generator = TestRequestGenerator(mode="sync")
     items = []
     for item in generator:
         items.append(item)
-
         if len(items) == 5:
             break

@@ -27,28 +43,30 @@ def test_request_generator_sync():


 @pytest.mark.smoke
-@pytest.mark.asyncio
-def test_request_generator_async():
-    generator = TestRequestGenerator(mode="async", async_queue_size=10)
-    assert generator.mode == "async"
-    assert generator.async_queue_size == 10
-    assert generator.tokenizer is None
-
-    generator.start()
-
+def test_request_generator_async_iter():
+    generator = TestRequestGenerator(mode="async")
     items = []
     for item in generator:
         items.append(item)
-
         if len(items) == 5:
             break

     generator.stop()
-    assert generator._stop_event.is_set()
-
     assert len(items) == 5
     assert items[0].prompt == "Test prompt"
-    assert items[-1].prompt == "Test prompt"
+
+
+@pytest.mark.regression
+def test_request_generator_with_mock_tokenizer():
+    mock_tokenizer = Mock()
+    generator = TestRequestGenerator(tokenizer=mock_tokenizer)
+    assert generator.tokenizer == mock_tokenizer
+
+    with patch("guidellm.request.base.AutoTokenizer") as MockAutoTokenizer:
+        MockAutoTokenizer.from_pretrained.return_value = mock_tokenizer
+        generator = TestRequestGenerator(tokenizer="mock-tokenizer")
+        assert generator.tokenizer == mock_tokenizer
+        MockAutoTokenizer.from_pretrained.assert_called_with("mock-tokenizer")


 @pytest.mark.regression
@@ -57,3 +75,58 @@ def test_request_generator_repr():
     assert repr(generator) == (
         "RequestGenerator(mode=sync, async_queue_size=100, tokenizer=None)"
     )
+
+
+@pytest.mark.regression
+def test_request_generator_create_item_not_implemented():
+    with pytest.raises(TypeError):
+
+        class IncompleteRequestGenerator(RequestGenerator):
+            pass
+
+        IncompleteRequestGenerator()
+
+    class IncompleteCreateItemGenerator(RequestGenerator):
+        def create_item(self):
+            super().create_item()
+
+    generator = IncompleteCreateItemGenerator()
+    with pytest.raises(NotImplementedError):
+        generator.create_item()
+
+
+@pytest.mark.regression
+def test_request_generator_iter_calls_create_item():
+    generator = TestRequestGenerator(mode="sync")
+    generator.create_item = Mock(
+        return_value=TextGenerationRequest(prompt="Mock prompt")
+    )
+
+    items = []
+    for item in generator:
+        items.append(item)
+        if len(items) == 5:
+            break
+
+    assert generator._queue.qsize() == 0
+    generator.create_item.assert_called()
+
+
+@pytest.mark.regression
+def test_request_generator_async_iter_calls_create_item():
+    generator = TestRequestGenerator(mode="sync")
+    generator.create_item = Mock(
+        return_value=TextGenerationRequest(prompt="Mock prompt")
+    )
+
+    items = []
+    for item in generator:
+        items.append(item)
+        if len(items) == 5:
+            break
+
+    generator.stop()
+    stop_size = generator._queue.qsize()
+    time.sleep(0.1)
+    assert generator._queue.qsize() == stop_size
+    generator.create_item.assert_called()
```
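The suite leans on two custom pytest marks, `smoke` and `regression`, so subsets can be run with `pytest -m smoke` or `pytest -m regression`. Registering those marks is assumed to happen outside this diff; a hypothetical `conftest.py` sketch:

```python
# Hypothetical conftest.py -- the repository's real pytest configuration is
# not part of this commit. Unregistered custom marks trigger
# PytestUnknownMarkWarning on every marked test.
def pytest_configure(config):
    config.addinivalue_line("markers", "smoke: quick sanity checks")
    config.addinivalue_line("markers", "regression: deeper behavioral checks")
```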
