Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions src/huggingface_hub/utils/_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ async def async_hf_request_event_hook(request: httpx.Request) -> None:
return hf_request_event_hook(request)


async def async_hf_response_event_hook(response: httpx.Response) -> None:
    """Pre-read small error bodies so they are available when the error is raised.

    Only responses with a status code of 400 or above are considered. The body is
    read from the stream only when the `Content-length` header is present, parses
    as a non-negative integer and is below 1,000,000 bytes. In every other case
    the stream is left untouched to avoid loading an unbounded payload in memory.

    Args:
        response: The response received by the async client.
    """
    if response.status_code < 400:
        # Successful responses are never turned into an exception here: nothing to read.
        return

    raw_length = response.headers.get("Content-length")
    if raw_length is None:
        # Unknown body size: skip reading the content to avoid OOM.
        return

    try:
        length = int(raw_length)
    except ValueError:
        # Malformed header: the body size cannot be trusted, leave the stream alone.
        return

    # A negative length is as meaningless as a malformed one, so it is skipped as well.
    if 0 <= length < 1_000_000:
        await response.aread()


def default_client_factory() -> httpx.Client:
"""
Factory function to create a `httpx.Client` with the default transport.
Expand All @@ -125,7 +139,7 @@ def default_async_client_factory() -> httpx.AsyncClient:
Factory function to create a `httpx.AsyncClient` with the default transport.
"""
return httpx.AsyncClient(
event_hooks={"request": [async_hf_request_event_hook]},
event_hooks={"request": [async_hf_request_event_hook], "response": [async_hf_response_event_hook]},
follow_redirects=True,
timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0),
)
Expand Down Expand Up @@ -626,8 +640,16 @@ def _format(error_type: type[HfHubHTTPError], custom_message: str, response: htt
try:
data = response.json()
except httpx.ResponseNotRead:
response.read() # In case of streaming response, we need to read the response first
data = response.json()
try:
response.read() # In case of streaming response, we need to read the response first
data = response.json()
except RuntimeError:
# In case of async streaming response, we can't read the stream here.
# In practice if user is using the default async client from `get_async_client`, the stream will have
# already been read in the async event hook `async_hf_response_event_hook`.
#
# We skip reading the response here to avoid a RuntimeError. This only happens when an async streaming
# response comes from an `httpx.AsyncClient` that was created directly by the caller.
data = {}

error = data.get("error")
if error is not None:
Expand Down
91 changes: 90 additions & 1 deletion tests/test_utils_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@
import threading
import time
import unittest
from http.server import BaseHTTPRequestHandler, HTTPServer
from multiprocessing import Process, Queue
from typing import Generator, Optional
from unittest.mock import Mock, call, patch
from urllib.parse import urlparse
from uuid import UUID

import httpx
import pytest
from httpx import ConnectTimeout, HTTPError

from huggingface_hub.constants import ENDPOINT
from huggingface_hub.errors import OfflineModeIsEnabled
from huggingface_hub.errors import HfHubHTTPError, OfflineModeIsEnabled
from huggingface_hub.utils._http import (
_adjust_range_header,
default_client_factory,
fix_hf_endpoint_in_url,
get_async_session,
get_session,
hf_raise_for_status,
http_backoff,
set_client_factory,
)
Expand Down Expand Up @@ -378,3 +381,89 @@ async def test_async_client_get_request():
client = get_async_session()
response = await client.get("https://huggingface.co")
assert response.status_code == 200


class FakeServerHandler(BaseHTTPRequestHandler):
    """Fake server handler to test client behavior.

    `GET /health` answers 200 with the body `OK`; every other GET path answers
    500 with the body `This is a 500 error`. Both responses are plain text and
    carry an explicit `Content-Length` header.
    """

    def do_GET(self):
        """Route a GET request to the health check or to the always-failing endpoint."""
        # Only the path component matters: any query string is ignored.
        if urlparse(self.path).path == "/health":
            # Health check endpoint (always succeeds)
            self._send_response(200, b"OK")
        else:
            # Main endpoint (always fails with 500)
            self._send_response(500, b"This is a 500 error")

    def log_message(self, format, *args):  # noqa: A002 - name imposed by the base class
        """Discard the per-request access log that the base class writes to stderr."""

    def _send_response(self, status_code, body):
        """Write a complete plain-text response.

        Args:
            status_code: HTTP status code of the response.
            body: Raw bytes of the response body; its length is sent as `Content-Length`.
        """
        self.send_response(status_code)
        self.send_header("Content-Type", "text/plain")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


@pytest.fixture(scope="module", autouse=True)
def fake_server() -> Generator[str, None, None]:
    """Run a local fake HTTP server for the duration of the module and yield its base URL.

    The server is handled by `FakeServerHandler` and runs in a daemon thread. The
    fixture waits until the `/health` endpoint answers 200 before yielding.

    Yields:
        The base URL of the server, e.g. `http://127.0.0.1:<port>`.

    Raises:
        RuntimeError: If the server does not answer the health check in time.
    """
    host = "127.0.0.1"

    # Binding to port 0 lets the OS pick a free port, which avoids scanning a fixed range
    # and cannot collide with another process grabbing the same port in between.
    server = HTTPServer((host, 0), FakeServerHandler)
    url = f"http://{host}:{server.server_port}"

    # Start server in a separate thread and wait until it's ready
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()

    try:
        # 1000 attempts, each made of a request (timeout 0.01s) and a 0.01s pause.
        for _ in range(1000):
            try:
                if httpx.get(f"{url}/health", timeout=0.01).status_code == 200:
                    break
            except httpx.HTTPError:
                pass
            time.sleep(0.01)
        else:
            raise RuntimeError("Fake server failed to start")

        yield url
    finally:
        # `shutdown()` only stops the serving loop; `server_close()` releases the listening socket.
        server.shutdown()
        server.server_close()
        thread.join(timeout=1)


def _check_raise_status(response: httpx.Response):
    """Shared assertions for the tests hitting the always-failing (500) endpoint.

    `hf_raise_for_status` must raise an `HfHubHTTPError` that still references the
    500 response and whose message contains the text sent by the fake server.
    """
    with pytest.raises(HfHubHTTPError) as raised:
        hf_raise_for_status(response)

    error = raised.value
    assert error.response.status_code == 500
    assert "This is a 500 error" in str(error)


def test_raise_on_status_sync_non_stream(fake_server: str):
    """Sync session, regular (non-streamed) GET against the failing endpoint."""
    session = get_session()
    _check_raise_status(session.get(fake_server))


def test_raise_on_status_sync_stream(fake_server: str):
    """Sync session, streamed GET: the checks run while the stream is still open."""
    session = get_session()
    with session.stream("GET", fake_server) as streamed:
        _check_raise_status(streamed)


@pytest.mark.asyncio
async def test_raise_on_status_async_non_stream(fake_server: str):
    """Async session, regular (non-streamed) GET against the failing endpoint."""
    session = get_async_session()
    _check_raise_status(await session.get(fake_server))


@pytest.mark.asyncio
async def test_raise_on_status_async_stream(fake_server: str):
    """Async session, streamed GET: the checks run while the stream is still open."""
    session = get_async_session()
    async with session.stream("GET", fake_server) as streamed:
        _check_raise_status(streamed)
Loading