Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/inference_endpoint/config/runtime_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def _from_config_default(
"metric_target": metrics.Throughput(effective_qps),
"reported_metrics": [metrics.Throughput(effective_qps)],
"min_duration_ms": runtime_cfg.min_duration_ms,
"max_duration_ms": runtime_cfg.max_duration_ms,
"max_duration_ms": None
if runtime_cfg.max_duration_ms == 0
else runtime_cfg.max_duration_ms,
"n_samples_from_dataset": dataloader_num_samples,
"n_samples_to_issue": runtime_cfg.n_samples_to_issue, # From config (CLI --num-samples or YAML)
"min_sample_count": 1,
Expand Down
6 changes: 4 additions & 2 deletions src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ class RuntimeConfig(BaseModel):
),
] = Field(600000, ge=0)
max_duration_ms: int = Field(
1800000, ge=0, description="Maximum test duration in ms"
0,
ge=0,
description="Maximum test duration in ms (0 for no limit)",
)

@field_validator("min_duration_ms", "max_duration_ms", mode="before")
Expand All @@ -338,7 +340,7 @@ def _parse_duration_suffix(cls, v: object) -> object:

@model_validator(mode="after")
def _validate_durations(self) -> Self:
if self.max_duration_ms < self.min_duration_ms:
if self.max_duration_ms != 0 and self.max_duration_ms < self.min_duration_ms:
raise ValueError(
f"max_duration_ms ({self.max_duration_ms}) must be >= "
f"min_duration_ms ({self.min_duration_ms})"
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/config/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,17 @@ def test_max_lt_min_duration_rejected(self):
},
)

@pytest.mark.unit
def test_max_duration_below_zero_rejected(self):
    """A negative max_duration_ms must be rejected by the ge=0 field constraint."""
    runtime_settings = {"runtime": {"max_duration_ms": -1}}
    with pytest.raises(ValueError, match="greater than or equal to 0"):
        BenchmarkConfig(
            type=TestType.OFFLINE,
            model_params={"name": "M"},
            endpoint_config={"endpoints": ["http://x"]},
            datasets=[{"path": "D"}],
            settings=runtime_settings,
        )

@pytest.mark.unit
def test_submission_bad_benchmark_mode(self):
with pytest.raises(ValueError, match="benchmark_mode"):
Expand Down Expand Up @@ -343,6 +354,20 @@ def test_to_yaml_file(self, tmp_path):
loaded = BenchmarkConfig.from_yaml_file(out)
assert loaded.model_params.name == "M"

@pytest.mark.unit
def test_max_duration_zero_converts_to_none_in_runtime_settings(self):
    """max_duration_ms == 0 means "no limit" and must map to None in RuntimeSettings."""
    from inference_endpoint.config.runtime_settings import RuntimeSettings

    cfg_kwargs = {
        "type": TestType.OFFLINE,
        "model_params": {"name": "M"},
        "endpoint_config": {"endpoints": ["http://x"]},
        "datasets": [{"path": "D"}],
        "settings": {"runtime": {"max_duration_ms": 0}},
    }
    runtime = RuntimeSettings.from_config(
        BenchmarkConfig(**cfg_kwargs), dataloader_num_samples=100
    )
    assert runtime.max_duration_ms is None

@pytest.mark.unit
def test_from_yaml_file_not_found(self):
from pathlib import Path
Expand Down
Loading