Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/gentrace/lib/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

ATTR_GENTRACE_SAMPLE_KEY = "gentrace.sample"

# Maximum allowed concurrency for eval_dataset
MAX_EVAL_DATASET_CONCURRENCY = 100

__all__ = [
"ANONYMOUS_SPAN_NAME",
"ATTR_GENTRACE_FN_ARGS_EVENT_NAME",
Expand All @@ -22,4 +25,5 @@
"ATTR_GENTRACE_TEST_CASE_ID",
"ATTR_GENTRACE_PIPELINE_ID",
"ATTR_GENTRACE_SAMPLE_KEY",
"MAX_EVAL_DATASET_CONCURRENCY",
]
6 changes: 4 additions & 2 deletions src/gentrace/lib/eval_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
ATTR_GENTRACE_TEST_CASE_ID,
ATTR_GENTRACE_EXPERIMENT_ID,
ATTR_GENTRACE_TEST_CASE_NAME,
MAX_EVAL_DATASET_CONCURRENCY,
ATTR_GENTRACE_FN_ARGS_EVENT_NAME,
ATTR_GENTRACE_FN_OUTPUT_EVENT_NAME,
)
Expand Down Expand Up @@ -355,6 +356,7 @@ async def eval_dataset(
internally to TestCase).
max_concurrency (Optional[int]): Maximum number of test cases to run concurrently.
If None (default), all test cases run concurrently.
Maximum allowed value is 100; larger values raise ValueError.
For async functions, uses asyncio.Semaphore.
For sync functions, uses ThreadPoolExecutor.
show_progress_bar (Optional[bool]): Controls progress display during evaluation.
Expand Down Expand Up @@ -386,11 +388,11 @@ async def eval_dataset(
semaphore: Optional[asyncio.Semaphore] = None
if max_concurrency is not None and max_concurrency > 0:
# Raise ValueError if max_concurrency exceeds the allowed maximum
if max_concurrency > 30:
if max_concurrency > MAX_EVAL_DATASET_CONCURRENCY:
warning = GentraceWarnings.HighConcurrencyError(max_concurrency)
display_gentrace_warning(warning)
raise ValueError(
f"max_concurrency ({max_concurrency}) exceeds maximum allowed value of 30. Please use a value between 1 and 30."
f"max_concurrency ({max_concurrency}) exceeds maximum allowed value of {MAX_EVAL_DATASET_CONCURRENCY}. Please use a value between 1 and {MAX_EVAL_DATASET_CONCURRENCY}."
)

semaphore = asyncio.Semaphore(max_concurrency)
Expand Down
6 changes: 4 additions & 2 deletions src/gentrace/lib/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from rich.panel import Panel
from rich.console import Group, Console

from .constants import MAX_EVAL_DATASET_CONCURRENCY

# Global tracking of displayed warnings
_displayed_warnings: Set[str] = set()

Expand Down Expand Up @@ -252,10 +254,10 @@ def HighConcurrencyError(max_concurrency: int) -> GentraceWarning:
warning_id="GT_HighConcurrencyError",
title="High Concurrency Error",
message=[
f"max_concurrency of {max_concurrency} exceeds the maximum allowed value of 30.",
f"max_concurrency of {max_concurrency} exceeds the maximum allowed value of {MAX_EVAL_DATASET_CONCURRENCY}.",
"",
"Please use a lower value:",
f" eval_dataset(..., max_concurrency=30)",
f" eval_dataset(..., max_concurrency={MAX_EVAL_DATASET_CONCURRENCY})",
"",
"High concurrency can overwhelm your API providers and may lead to rate limiting.",
],
Expand Down
17 changes: 9 additions & 8 deletions tests/lib/test_eval_dataset_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import gentrace.lib.experiment_control as exp_ctrl
from gentrace import TestInput as GentraceTestInput, init, experiment, eval_dataset
from gentrace.types import TestCase as GentraceTestCase
from gentrace.lib.constants import MAX_EVAL_DATASET_CONCURRENCY
from gentrace.types.experiment import Experiment

# Use same pipeline ID as other tests
Expand Down Expand Up @@ -260,37 +261,37 @@ async def async_task(test_case: GentraceTestCase) -> Dict[str, Any]:
@pytest.mark.asyncio
@pytest.mark.filterwarnings("ignore:max_concurrency")
async def test_max_concurrency_exceeds_limit() -> None:
"""Test that max_concurrency > 30 raises ValueError."""
"""Test that max_concurrency > MAX_EVAL_DATASET_CONCURRENCY raises ValueError."""

async def async_task(_: GentraceTestCase) -> Dict[str, Any]:
"""Simple async task."""
return {"result": "ok"}

# Should raise ValueError when max_concurrency > 30
# Should raise ValueError when max_concurrency exceeds the limit
with pytest.raises(ValueError) as exc_info:
await eval_dataset(
data=lambda: create_test_data(5),
interaction=async_task,
max_concurrency=31,
max_concurrency=MAX_EVAL_DATASET_CONCURRENCY + 1,
)

assert "exceeds maximum allowed value of 30" in str(exc_info.value)
assert f"exceeds maximum allowed value of {MAX_EVAL_DATASET_CONCURRENCY}" in str(exc_info.value)

# Test with a much higher value
with pytest.raises(ValueError) as exc_info:
await eval_dataset(
data=lambda: create_test_data(5),
interaction=async_task,
max_concurrency=100,
max_concurrency=MAX_EVAL_DATASET_CONCURRENCY + 50,
)

assert "exceeds maximum allowed value of 30" in str(exc_info.value)
assert f"exceeds maximum allowed value of {MAX_EVAL_DATASET_CONCURRENCY}" in str(exc_info.value)

# Verify that max_concurrency=30 is still allowed (boundary test)
# Verify that max_concurrency=MAX_EVAL_DATASET_CONCURRENCY is still allowed (boundary test)
results = await eval_dataset(
data=lambda: create_test_data(5),
interaction=async_task,
max_concurrency=30,
max_concurrency=MAX_EVAL_DATASET_CONCURRENCY,
)

assert len(results) == 5
Loading