Skip to content

[Dataset]: Iterate through benchmark dataset once #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ exclude = ["venv", ".tox"]
# Check: https://mypy.readthedocs.io/en/latest/config_file.html#import-discovery
follow_imports = 'silent'

[[tool.mypy.overrides]]
module = ["datasets.*"]
ignore_missing_imports=true


[tool.ruff]
line-length = 88
Expand All @@ -122,6 +126,8 @@ ignore = [
"ISC001",
"TCH002",
"PLW1514", # allow Path.open without encoding
"RET505", # allow `else` blocks
"RET506" # allow `else` blocks

]
select = [
Expand Down
1 change: 1 addition & 0 deletions src/guidellm/backend/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ async def make_request(
stream=True,
**request_args,
)

token_count = 0
async for chunk in stream:
choice = chunk.choices[0]
Expand Down
25 changes: 16 additions & 9 deletions src/guidellm/executor/profile_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Any, Dict, Literal, Optional, Sequence, Union, get_args
from typing import Any, Dict, List, Literal, Optional, Sequence, Union, get_args

import numpy as np
from loguru import logger
from numpy._typing import NDArray
from pydantic import Field

from guidellm.config import settings
Expand Down Expand Up @@ -190,12 +191,14 @@ def next(self, current_report: TextGenerationBenchmarkReport) -> Optional[Profil
elif self.mode == "sweep":
profile = self.create_sweep_profile(
self.generated_count,
sync_benchmark=current_report.benchmarks[0]
if current_report.benchmarks
else None,
throughput_benchmark=current_report.benchmarks[1]
if len(current_report.benchmarks) > 1
else None,
sync_benchmark=(
current_report.benchmarks[0] if current_report.benchmarks else None
),
throughput_benchmark=(
current_report.benchmarks[1]
if len(current_report.benchmarks) > 1
else None
),
)
else:
err = ValueError(f"Invalid mode: {self.mode}")
Expand Down Expand Up @@ -333,11 +336,15 @@ def create_sweep_profile(

min_rate = sync_benchmark.completed_request_rate
max_rate = throughput_benchmark.completed_request_rate
intermediate_rates = list(
intermediate_rates: List[NDArray] = list(
np.linspace(min_rate, max_rate, settings.num_sweep_profiles + 1)
)[1:]

return Profile(
load_gen_mode="constant",
load_gen_rate=intermediate_rates[index - 2],
load_gen_rate=(
float(load_gen_rate)
if (load_gen_rate := intermediate_rates[index - 2])
else 1.0 # the fallback value
),
)
17 changes: 11 additions & 6 deletions src/guidellm/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Literal, Optional, get_args
from typing import Literal, Optional, Union, get_args

import click
from loguru import logger
Expand All @@ -13,7 +13,7 @@
TransformersDatasetRequestGenerator,
)
from guidellm.request.base import RequestGenerator
from guidellm.utils import BenchmarkReportProgress
from guidellm.utils import BenchmarkReportProgress, cli_params

__all__ = ["generate_benchmark_report"]

Expand Down Expand Up @@ -120,7 +120,7 @@
)
@click.option(
"--max-requests",
type=int,
type=cli_params.MAX_REQUESTS,
default=None,
help=(
"The maximum number of requests for each benchmark run. "
Expand Down Expand Up @@ -161,7 +161,7 @@ def generate_benchmark_report_cli(
rate_type: ProfileGenerationMode,
rate: Optional[float],
max_seconds: Optional[int],
max_requests: Optional[int],
max_requests: Union[Literal["dataset"], int, None],
output_path: str,
enable_continuous_refresh: bool,
):
Expand Down Expand Up @@ -194,7 +194,7 @@ def generate_benchmark_report(
rate_type: ProfileGenerationMode,
rate: Optional[float],
max_seconds: Optional[int],
max_requests: Optional[int],
max_requests: Union[Literal["dataset"], int, None],
output_path: str,
cont_refresh_table: bool,
) -> GuidanceReport:
Expand Down Expand Up @@ -256,13 +256,18 @@ def generate_benchmark_report(
else:
raise ValueError(f"Unknown data type: {data_type}")

if data_type == "emulated" and max_requests == "dataset":
raise ValueError("Cannot use 'dataset' for emulated data")

# Create executor
executor = Executor(
backend=backend_inst,
request_generator=request_generator,
mode=rate_type,
rate=rate if rate_type in ("constant", "poisson") else None,
max_number=max_requests,
max_number=(
len(request_generator) if max_requests == "dataset" else max_requests
),
max_duration=max_seconds,
)

Expand Down
24 changes: 15 additions & 9 deletions src/guidellm/request/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,21 @@ def __iter__(self) -> Iterator[TextGenerationRequest]:
while not self._stop_event.is_set():
yield self.create_item()

@abstractmethod
def __len__(self) -> int:
    """
    Abstract method to get the length of the collection to be generated.

    :return: the number of requests this generator can produce. Implementations
        backed by unbounded data (e.g. emulated generation) may raise
        ``NotImplementedError`` instead of returning a count.
    :rtype: int
    """

@abstractmethod
def create_item(self) -> TextGenerationRequest:
    """
    Abstract method to create a new result request item.

    Concrete generators produce the request from their backing data source
    (file lines, HF dataset rows, or emulated text).

    :return: A new result request.
    :rtype: TextGenerationRequest
    """

@property
def type_(self) -> str:
"""
Expand Down Expand Up @@ -155,15 +170,6 @@ def async_queue_size(self) -> int:
"""
return self._async_queue_size

@abstractmethod
def create_item(self) -> TextGenerationRequest:
"""
Abstract method to create a new result request item.

:return: A new result request.
:rtype: TextGenerationRequest
"""

def stop(self):
"""
Stop the background task that populates the queue.
Expand Down
6 changes: 6 additions & 0 deletions src/guidellm/request/emulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,12 @@ def __init__(
async_queue_size=async_queue_size,
)

def __len__(self) -> int:
    # Emulated data is produced on the fly, so there is no finite dataset
    # to count; `--max-requests dataset` is rejected for this generator.
    raise NotImplementedError(
        "Can't get the length of the emulated dataset. "
        "Check the `--data-type` CLI parameter."
    )

def create_item(self) -> TextGenerationRequest:
"""
Create a new text generation request item from the data.
Expand Down
7 changes: 7 additions & 0 deletions src/guidellm/request/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def __init__(
async_queue_size=async_queue_size,
)

def __len__(self) -> int:
    """
    Return the number of text lines.

    Supports ``--max-requests dataset``, which caps the benchmark at one
    full pass over the file-backed dataset.

    :rtype: int
    """

    return len(self._data)

def create_item(self) -> TextGenerationRequest:
"""
Create a new result request item from the data.
Expand Down
17 changes: 10 additions & 7 deletions src/guidellm/request/transformers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from pathlib import Path
from typing import Optional, Union

from datasets import ( # type: ignore # noqa: PGH003
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
)
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from loguru import logger
from transformers import PreTrainedTokenizer # type: ignore # noqa: PGH003

Expand Down Expand Up @@ -57,7 +52,9 @@ def __init__(
self._column = column
self._kwargs = kwargs

self._hf_dataset = load_transformers_dataset(dataset, split=split, **kwargs)
self._hf_dataset: Union[Dataset, IterableDataset] = load_transformers_dataset(
dataset, split=split, **kwargs
)
self._hf_column = resolve_transformers_dataset_column(
self._hf_dataset, column=column
)
Expand All @@ -73,6 +70,12 @@ def __init__(
async_queue_size=async_queue_size,
)

def __len__(self) -> int:
    """
    Return the dataset size.

    Only map-style ``Dataset`` objects have a length; streaming
    ``IterableDataset`` objects do not, so asking for their size fails.
    """
    if isinstance(self._hf_dataset, Dataset):
        return len(self._hf_dataset)
    raise ValueError("Can't get dataset size for IterableDataset object")

def create_item(self) -> TextGenerationRequest:
"""
Create a new result request item from the dataset.
Expand Down
34 changes: 34 additions & 0 deletions src/guidellm/utils/cli_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
This module includes custom CLI parameters for the `click` package.
"""

from typing import Any, Optional

from click import Context, Parameter, ParamType

__all__ = ["MAX_REQUESTS"]


class MaxRequestsType(ParamType):
    """
    Custom ``click`` parameter type for the ``--max-requests`` CLI option.

    Accepts either an integer (possibly given as a numeric string) or the
    literal string ``"dataset"``, which tells the Scheduler to iterate
    through the benchmark dataset exactly once.
    """

    name = "max_requests"

    def convert(
        self, value: Any, param: Optional[Parameter], ctx: Optional[Context]
    ) -> Any:
        """
        Validate and convert the raw CLI value.

        :return: the integer value, or the literal string ``"dataset"``.
        :raises click.BadParameter: (via ``self.fail``) for any other value.
        """
        if isinstance(value, int):
            return value

        # Check the sentinel before int() so the conversion branch only has
        # to deal with genuinely numeric input.
        if value == "dataset":
            return value

        try:
            return int(value)
        except (ValueError, TypeError):
            # TypeError covers non-int, non-string values (e.g. None, lists),
            # which int() rejects with TypeError rather than ValueError;
            # without catching it the CLI would crash instead of reporting
            # a clean usage error.
            self.fail(f"{value} is not a valid integer or 'dataset'", param, ctx)


MAX_REQUESTS = MaxRequestsType()
9 changes: 6 additions & 3 deletions src/guidellm/utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def update_benchmark(
:type req_per_sec: float
:raises ValueError: If trying to update a completed benchmark.
"""

if self.benchmark_tasks_completed[index]:
err = ValueError(f"Benchmark {index} already completed")
logger.error("Error updating benchmark: {}", err)
Expand All @@ -162,9 +163,11 @@ def update_benchmark(
total=completed_total,
completed=completed_count if not completed else completed_total,
req_per_sec=(f"{req_per_sec:.2f}" if req_per_sec else "#.##"),
start_time_str=datetime.fromtimestamp(start_time).strftime("%H:%M:%S")
if start_time
else "--:--:--",
start_time_str=(
datetime.fromtimestamp(start_time).strftime("%H:%M:%S")
if start_time
else "--:--:--"
),
)
logger.debug(
"Updated benchmark task at index {}: {}% complete",
Expand Down
2 changes: 1 addition & 1 deletion src/guidellm/utils/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def load_text_lines(
format_ = "txt"

# load the data if it's a path or URL
if isinstance(data, Path) or (isinstance(data, str) and data.startswith("http")):
if isinstance(data, (Path, str)):
data = load_text(data, encoding=encoding)
data = clean_text(data)

Expand Down
3 changes: 3 additions & 0 deletions tests/dummy/services/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,6 @@ def __init__(

def create_item(self) -> TextGenerationRequest:
    """Return a request with a fixed, constant prompt for deterministic tests."""
    return TextGenerationRequest(prompt="Test prompt")

def __len__(self) -> int:
    # This dummy generator has no backing dataset, so a length is
    # intentionally unsupported.
    raise NotImplementedError
Empty file added tests/unit/cli/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions tests/unit/cli/test_custom_type_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
from click import BadParameter

from guidellm.utils import cli_params


@pytest.fixture()
def max_requests_param_type():
    """Provide a fresh ``MaxRequestsType`` instance for each test."""
    return cli_params.MaxRequestsType()


def test_valid_integer_input(max_requests_param_type):
    """Integers and integer-like strings convert to their int value."""
    cases = ((10, 10), ("42", 42))
    for raw, expected in cases:
        assert max_requests_param_type.convert(raw, None, None) == expected


def test_valid_dataset_input(max_requests_param_type):
    """The literal sentinel string 'dataset' passes through unchanged."""
    result = max_requests_param_type.convert("dataset", None, None)
    assert result == "dataset"


def test_invalid_string_input(max_requests_param_type):
    """A non-numeric string other than 'dataset' is rejected."""
    with pytest.raises(BadParameter):
        max_requests_param_type.convert("invalid", None, None)


def test_invalid_float_input(max_requests_param_type):
    """A float-formatted string is rejected: int('10.5') is not valid."""
    with pytest.raises(BadParameter):
        max_requests_param_type.convert("10.5", None, None)


def test_invalid_non_numeric_string_input(max_requests_param_type):
    """A purely alphabetic string is rejected."""
    with pytest.raises(BadParameter):
        max_requests_param_type.convert("abc", None, None)


def test_invalid_mixed_string_input(max_requests_param_type):
    """A mixed digits-plus-letters string is rejected."""
    with pytest.raises(BadParameter):
        max_requests_param_type.convert("123abc", None, None)
Loading