|
24 | 24 | from collections import defaultdict |
25 | 25 | from typing import Literal |
26 | 26 |
|
| 27 | +import requests |
27 | 28 | from huggingface_hub import HfApi |
28 | 29 | from inspect_ai import Epochs, Task, task |
29 | 30 | from inspect_ai import eval_set as inspect_ai_eval_set |
@@ -182,13 +183,31 @@ def _format_metric_cell(data: dict, col: str, metric: str, stderr_metric: str) - |
182 | 183 | return "-" |
183 | 184 |
|
184 | 185 |
|
| 186 | +def _get_huggingface_providers(model_id: str): |
| 187 | + model_id = model_id.replace("hf-inference-providers/", "").replace(":all", "") |
| 188 | + url = f"https://huggingface.co/api/models/{model_id}" |
| 189 | + params = {"expand[]": "inferenceProviderMapping"} |
| 190 | + response = requests.get(url, params=params) |
| 191 | + response.raise_for_status() # raise exception for HTTP errors |
| 192 | + data = response.json() |
| 193 | + # Extract provider mapping if available |
| 194 | + providers = data.get("inferenceProviderMapping", {}) |
| 195 | + |
| 196 | + live_providers = [] |
| 197 | + for provider, info in providers.items(): |
| 198 | + if info.get("status") == "live": |
| 199 | + live_providers.append(provider) |
| 200 | + |
| 201 | + return live_providers |
| 202 | + |
| 203 | + |
185 | 204 | HELP_PANEL_NAME_1 = "Modeling Parameters" |
186 | 205 | HELP_PANEL_NAME_2 = "Task Parameters" |
187 | 206 | HELP_PANEL_NAME_3 = "Connection and parallelization parameters" |
188 | 207 | HELP_PANEL_NAME_4 = "Logging parameters" |
189 | 208 |
|
190 | 209 |
|
191 | | -def eval( |
| 210 | +def eval( # noqa C901 |
192 | 211 | models: Annotated[list[str], Argument(help="Models to evaluate")], |
193 | 212 | tasks: Annotated[str, Argument(help="Tasks to evaluate")], |
194 | 213 | # model arguments |
@@ -404,6 +423,11 @@ def eval( |
404 | 423 | else: |
405 | 424 | model_args = {} |
406 | 425 |
|
| 426 | + for model in models: |
| 427 | + if model.split("/")[0] == "hf-inference-providers" and model.split(":")[-1] == "all": |
| 428 | + providers = _get_huggingface_providers(model) |
| 429 | + models = [f"{model.replace(':all', '')}:{provider}" for provider in providers] |
| 430 | + |
407 | 431 | success, logs = inspect_ai_eval_set( |
408 | 432 | inspect_ai_tasks, |
409 | 433 | model=models, |
|
0 commit comments