Refactor RequestGenerator to use threading and update test suite #3

Merged · 3 commits · Jul 5, 2024
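
This PR moves RequestGenerator's background queue population from asyncio onto a plain daemon thread: asyncio.Queue and asyncio.Event are replaced with queue.Queue and threading.Event, the separate start() method is removed (the producer thread now starts in the constructor when mode="async"), and stop() joins the thread. A shared STANDARD_SLEEP_INTERVAL constant (0.1 s) replaces the hard-coded sleep, and the unit and integration test suites are updated accordingly.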
66 changes: 31 additions & 35 deletions src/guidellm/request/base.py
@@ -1,11 +1,14 @@
-import asyncio
+import threading
+import time
 from abc import ABC, abstractmethod
+from queue import Empty, Full, Queue
 from typing import Iterator, Optional, Union
 
 from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizer
 
 from guidellm.core.request import TextGenerationRequest
+from guidellm.utils import STANDARD_SLEEP_INTERVAL
 
 __all__ = ["RequestGenerator"]

@@ -31,20 +34,28 @@ def __init__(
     ):
         self._async_queue_size = async_queue_size
         self._mode = mode
-        self._queue = asyncio.Queue(maxsize=async_queue_size)
-        self._stop_event = asyncio.Event()
+        self._queue = Queue(maxsize=async_queue_size)
+        self._stop_event = threading.Event()
 
         if tokenizer is not None:
             self._tokenizer = (
                 AutoTokenizer.from_pretrained(tokenizer)
                 if isinstance(tokenizer, str)
                 else tokenizer
             )
-            logger.info(f"Tokenizer initialized: {self._tokenizer}")
+            logger.info("Tokenizer initialized: {}", self._tokenizer)
         else:
             self._tokenizer = None
             logger.debug("No tokenizer provided")
 
+        if self._mode == "async":
+            self._thread = threading.Thread(target=self._populate_queue, daemon=True)
+            self._thread.start()
+            logger.info(
+                "RequestGenerator started in async mode with queue size: {}",
+                self._async_queue_size,
+            )
+
     def __repr__(self) -> str:
         """
         Return a string representation of the RequestGenerator.
@@ -72,7 +83,7 @@ def __iter__(self) -> Iterator[TextGenerationRequest]:
                     item = self._queue.get_nowait()
                     self._queue.task_done()
                     yield item
-                except asyncio.QueueEmpty:
+                except Empty:
                     continue
         else:
             while not self._stop_event.is_set():
@@ -118,46 +129,31 @@ def create_item(self) -> TextGenerationRequest:
         """
         raise NotImplementedError()
 
-    def start(self):
-        """
-        Start the background task that populates the queue.
-        """
-        if self.mode == "async":
-            try:
-                loop = asyncio.get_running_loop()
-                logger.info("Using existing event loop")
-            except RuntimeError:
-                raise RuntimeError("No running event loop found for async mode")
-
-            loop.call_soon_threadsafe(
-                lambda: asyncio.create_task(self._populate_queue())
-            )
-            logger.info(
-                f"RequestGenerator started in async mode with queue size: "
-                f"{self._async_queue_size}"
-            )
-        else:
-            logger.info("RequestGenerator started in sync mode")
-
     def stop(self):
         """
         Stop the background task that populates the queue.
         """
         logger.info("Stopping RequestGenerator...")
         self._stop_event.set()
+        if self._mode == "async":
+            self._thread.join()
         logger.info("RequestGenerator stopped")
 
-    async def _populate_queue(self):
+    def _populate_queue(self):
         """
         Populate the request queue in the background.
         """
         while not self._stop_event.is_set():
-            if self._queue.qsize() < self._async_queue_size:
-                item = self.create_item()
-                await self._queue.put(item)
-                logger.debug(
-                    f"Item added to queue. Current queue size: {self._queue.qsize()}"
-                )
-            else:
-                await asyncio.sleep(0.1)
+            try:
+                if self._queue.qsize() < self._async_queue_size:
+                    item = self.create_item()
+                    self._queue.put(item, timeout=STANDARD_SLEEP_INTERVAL)
+                    logger.debug(
+                        "Item added to queue. Current queue size: {}",
+                        self._queue.qsize(),
+                    )
+                else:
+                    time.sleep(STANDARD_SLEEP_INTERVAL)
+            except Full:
+                continue
+        logger.info("RequestGenerator stopped populating queue")
8 changes: 6 additions & 2 deletions src/guidellm/utils/__init__.py
@@ -1,3 +1,7 @@
-from .constants import PREFERRED_DATA_COLUMNS, PREFERRED_DATA_SPLITS
+from .constants import (
+    PREFERRED_DATA_COLUMNS,
+    PREFERRED_DATA_SPLITS,
+    STANDARD_SLEEP_INTERVAL,
+)
 
-__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS"]
+__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS", "STANDARD_SLEEP_INTERVAL"]
4 changes: 3 additions & 1 deletion src/guidellm/utils/constants.py
@@ -1,4 +1,4 @@
-__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS"]
+__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS", "STANDARD_SLEEP_INTERVAL"]
 
 
 PREFERRED_DATA_COLUMNS = [
@@ -15,3 +15,5 @@
 ]
 
 PREFERRED_DATA_SPLITS = ["test", "validation", "train"]
+
+STANDARD_SLEEP_INTERVAL = 0.1
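
Note that STANDARD_SLEEP_INTERVAL does double duty in _populate_queue above: it is the timeout for Queue.put() and the idle back-off when the queue is already full, replacing the asyncio.sleep(0.1) that was previously hard-coded.
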
24 changes: 24 additions & 0 deletions tests/integration/request/test_base.py
@@ -0,0 +1,24 @@
+import pytest
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+
+from guidellm.core.request import TextGenerationRequest
+from guidellm.request.base import RequestGenerator
+
+
+class TestRequestGenerator(RequestGenerator):
+    def create_item(self) -> TextGenerationRequest:
+        return TextGenerationRequest(prompt="Test prompt")
+
+
+@pytest.mark.smoke
+def test_request_generator_with_hf_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    generator = TestRequestGenerator(tokenizer=tokenizer)
+    assert generator.tokenizer == tokenizer
+
+
+@pytest.mark.smoke
+def test_request_generator_with_string_tokenizer():
+    generator = TestRequestGenerator(tokenizer="bert-base-uncased")
+    assert isinstance(generator.tokenizer, PreTrainedTokenizerBase)
+    assert generator.tokenizer.name_or_path == "bert-base-uncased"
103 changes: 88 additions & 15 deletions tests/unit/request/test_base.py
@@ -1,3 +1,6 @@
+import time
+from unittest.mock import Mock, patch
+
 import pytest
 
 from guidellm.core.request import TextGenerationRequest
@@ -10,15 +13,28 @@ def create_item(self) -> TextGenerationRequest:
 
 
 @pytest.mark.smoke
-def test_request_generator_sync():
+def test_request_generator_sync_constructor():
     generator = TestRequestGenerator(mode="sync")
     assert generator.mode == "sync"
     assert generator.async_queue_size == 50  # Default value
     assert generator.tokenizer is None
 
+
+@pytest.mark.smoke
+def test_request_generator_async_constructor():
+    generator = TestRequestGenerator(mode="async", async_queue_size=10)
+    assert generator.mode == "async"
+    assert generator.async_queue_size == 10
+    assert generator.tokenizer is None
+    generator.stop()
+
+
+@pytest.mark.smoke
+def test_request_generator_sync_iter():
+    generator = TestRequestGenerator(mode="sync")
     items = []
     for item in generator:
         items.append(item)
 
         if len(items) == 5:
             break

Expand All @@ -27,28 +43,30 @@ def test_request_generator_sync():


@pytest.mark.smoke
@pytest.mark.asyncio
def test_request_generator_async():
generator = TestRequestGenerator(mode="async", async_queue_size=10)
assert generator.mode == "async"
assert generator.async_queue_size == 10
assert generator.tokenizer is None

generator.start()

def test_request_generator_async_iter():
generator = TestRequestGenerator(mode="async")
items = []
for item in generator:
items.append(item)

if len(items) == 5:
break

generator.stop()
assert generator._stop_event.is_set()

assert len(items) == 5
assert items[0].prompt == "Test prompt"
assert items[-1].prompt == "Test prompt"


@pytest.mark.regression
def test_request_generator_with_mock_tokenizer():
mock_tokenizer = Mock()
generator = TestRequestGenerator(tokenizer=mock_tokenizer)
assert generator.tokenizer == mock_tokenizer

with patch("guidellm.request.base.AutoTokenizer") as MockAutoTokenizer:
MockAutoTokenizer.from_pretrained.return_value = mock_tokenizer
generator = TestRequestGenerator(tokenizer="mock-tokenizer")
assert generator.tokenizer == mock_tokenizer
MockAutoTokenizer.from_pretrained.assert_called_with("mock-tokenizer")


@pytest.mark.regression
@@ -57,3 +75,58 @@ def test_request_generator_repr():
     assert repr(generator) == (
         "RequestGenerator(mode=sync, async_queue_size=100, tokenizer=None)"
     )
+
+
+@pytest.mark.regression
+def test_request_generator_create_item_not_implemented():
+    with pytest.raises(TypeError):
+
+        class IncompleteRequestGenerator(RequestGenerator):
+            pass
+
+        IncompleteRequestGenerator()
+
+    class IncompleteCreateItemGenerator(RequestGenerator):
+        def create_item(self):
+            super().create_item()
+
+    generator = IncompleteCreateItemGenerator()
+    with pytest.raises(NotImplementedError):
+        generator.create_item()
+
+
+@pytest.mark.regression
+def test_request_generator_iter_calls_create_item():
+    generator = TestRequestGenerator(mode="sync")
+    generator.create_item = Mock(
+        return_value=TextGenerationRequest(prompt="Mock prompt")
+    )
+
+    items = []
+    for item in generator:
+        items.append(item)
+        if len(items) == 5:
+            break
+
+    assert generator._queue.qsize() == 0
+    generator.create_item.assert_called()
+
+
+@pytest.mark.regression
+def test_request_generator_async_iter_calls_create_item():
+    generator = TestRequestGenerator(mode="async")
+    generator.create_item = Mock(
+        return_value=TextGenerationRequest(prompt="Mock prompt")
+    )
+
+    items = []
+    for item in generator:
+        items.append(item)
+        if len(items) == 5:
+            break
+
+    generator.stop()
+    stop_size = generator._queue.qsize()
+    time.sleep(0.1)
+    assert generator._queue.qsize() == stop_size
+    generator.create_item.assert_called()
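
The closing assertions lean on the new stop() semantics: because stop() joins the producer thread, the queue size captured immediately afterward must stay constant across the subsequent 0.1 s sleep, confirming that nothing is still populating the queue.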