Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/huggingface_hub/inference/_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
)
from huggingface_hub.utils import logging

from ._common import TaskProviderHelper, _fetch_inference_provider_mapping
from ._common import AutoRouterConversationalTask, TaskProviderHelper, _fetch_inference_provider_mapping
from .black_forest_labs import BlackForestLabsTextToImageTask
from .cerebras import CerebrasConversationalTask
from .cohere import CohereConversationalTask
Expand Down Expand Up @@ -71,6 +71,8 @@

PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]]

CONVERSATIONAL_AUTO_ROUTER = AutoRouterConversationalTask()

PROVIDERS: dict[PROVIDER_T, dict[str, TaskProviderHelper]] = {
"black-forest-labs": {
"text-to-image": BlackForestLabsTextToImageTask(),
Expand Down Expand Up @@ -201,13 +203,19 @@ def get_provider_helper(

if provider is None:
logger.info(
"Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
"No provider specified for task `conversational`. Defaulting to server-side auto routing."
if task == "conversational"
else "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers."
)
provider = "auto"

if provider == "auto":
if model is None:
raise ValueError("Specifying a model is required when provider is 'auto'")
if task == "conversational":
# Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping.
return CONVERSATIONAL_AUTO_ROUTER

provider_mapping = _fetch_inference_provider_mapping(model)
provider = next(iter(provider_mapping)).provider

Expand Down
38 changes: 38 additions & 0 deletions src/huggingface_hub/inference/_providers/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,44 @@ def _prepare_payload_as_dict(
return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id})


class AutoRouterConversationalTask(BaseConversationalTask):
    """Conversational task helper that delegates provider selection to the HF router.

    Instead of resolving a provider client-side (which would cost one extra API
    call to fetch the provider mapping), requests are sent to the Hugging Face
    router, which picks the best provider server-side based on availability and
    the user's preferences.
    """

    def __init__(self):
        super().__init__(provider="auto", base_url="https://router.huggingface.co")

    def _prepare_base_url(self, api_key: str) -> str:
        """Return the router base URL, after checking that a Hugging Face token is used."""
        # Server-side routing goes through the HF proxy, which only accepts HF tokens.
        if api_key.startswith("hf_"):
            return self.base_url  # No `/auto` suffix in the URL
        raise ValueError("Cannot select auto-router when using non-Hugging Face API key.")

    def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
        """Build a placeholder mapping without fetching anything from the Hub.

        Since routing happens server-side, no provider mapping call is made; a
        synthetic entry with provider_id set to the HF model ID is returned instead.
        """
        if model is None:
            raise ValueError("Please provide an HF model ID.")
        return InferenceProviderMapping(
            provider="auto",
            hf_model_id=model,
            providerId=model,
            status="live",
            task="conversational",
        )


class BaseTextGenerationTask(TaskProviderHelper):
"""
Base class for text-generation (completion) tasks.
Expand Down
54 changes: 53 additions & 1 deletion tests/test_inference_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from huggingface_hub.inference._common import RequestParameters
from huggingface_hub.inference._providers import PROVIDERS, get_provider_helper
from huggingface_hub.inference._providers._common import (
AutoRouterConversationalTask,
BaseConversationalTask,
BaseTextGenerationTask,
TaskProviderHelper,
Expand Down Expand Up @@ -193,6 +194,47 @@ def test_prepare_url(self, mocker):
helper._prepare_route.assert_called_once_with("test-model", "sk_test_token")


class TestAutoRouterConversationalTask:
    """Tests for the server-side auto-routing helper used by conversational tasks."""

    def test_properties(self):
        task_helper = AutoRouterConversationalTask()
        assert task_helper.provider == "auto"
        assert task_helper.base_url == "https://router.huggingface.co"
        assert task_helper.task == "conversational"

    def test_prepare_mapping_info_is_fake(self):
        # The mapping is synthesized locally: no Hub call, model ID echoed as provider_id.
        task_helper = AutoRouterConversationalTask()
        mapping = task_helper._prepare_mapping_info("test-model")
        assert mapping.hf_model_id == "test-model"
        assert mapping.provider_id == "test-model"
        assert mapping.task == "conversational"
        assert mapping.status == "live"

    def test_prepare_request(self):
        task_helper = AutoRouterConversationalTask()
        prepared = task_helper.prepare_request(
            inputs=[{"role": "user", "content": "Hello!"}],
            parameters={"model": "test-model", "frequency_penalty": 1.0},
            headers={},
            model="test-model",
            api_key="hf_test_token",
        )

        # Requests target the auto-router endpoint directly.
        assert prepared.url == "https://router.huggingface.co/v1/chat/completions"

        # Everything else matches a standard Chat Completion API request.
        assert prepared.headers["authorization"] == "Bearer hf_test_token"
        assert prepared.json == {
            "messages": [{"role": "user", "content": "Hello!"}],
            "model": "test-model",
            "frequency_penalty": 1.0,
        }
        assert prepared.task == "conversational"
        assert prepared.model == "test-model"
        assert prepared.data is None


class TestBlackForestLabsProvider:
def test_prepare_headers_bfl_key(self):
helper = BlackForestLabsTextToImageTask()
Expand Down Expand Up @@ -1670,7 +1712,7 @@ def test_filter_none(data: dict, expected: dict):
assert filter_none(data) == expected


def test_get_provider_helper_auto(mocker):
def test_get_provider_helper_auto_non_conversational(mocker):
"""Test the 'auto' provider selection logic."""

mock_provider_a_helper = mocker.Mock(spec=TaskProviderHelper)
Expand All @@ -1692,3 +1734,13 @@ def test_get_provider_helper_auto(mocker):

PROVIDERS.pop("provider-a", None)
PROVIDERS.pop("provider-b", None)


def test_get_provider_helper_auto_conversational():
    """'auto' + conversational task resolves to the dedicated server-side router helper.

    No HTTP call is made to the Hub here, since provider selection happens server-side.
    """
    resolved = get_provider_helper(provider="auto", task="conversational", model="test-model")
    assert isinstance(resolved, AutoRouterConversationalTask)
Loading