|
14 | 14 | from collections.abc import AsyncGenerator, Awaitable, Callable |
15 | 15 | from dataclasses import dataclass, field |
16 | 16 | from typing import Protocol |
17 | | -from urllib.parse import urlencode, urljoin, urlparse |
| 17 | +from urllib.parse import parse_qs, urlencode, urljoin, urlparse |
18 | 18 |
|
19 | 19 | import anyio |
20 | 20 | import httpx |
@@ -427,14 +427,43 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req |
427 | 427 | "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} |
428 | 428 | ) |
429 | 429 |
|
| 430 | + def _parse_content_type(self, content_type: str) -> tuple[str, dict[str, str]]: |
| 431 | + """Parse Content-Type header into media type and parameters.""" |
| 432 | + parts = content_type.split(";") |
| 433 | + media_type = parts[0].strip() |
| 434 | + |
| 435 | + params: dict[str, str] = {} |
| 436 | + for part in parts[1:]: |
| 437 | + if "=" in part: |
| 438 | + key, value = part.split("=", 1) |
| 439 | + params[key.strip()] = value.strip() |
| 440 | + |
| 441 | + return media_type, params |
| 442 | + |
430 | 443 | async def _handle_token_response(self, response: httpx.Response) -> None: |
431 | 444 | """Handle token exchange response.""" |
432 | 445 | if response.status_code != 200: |
433 | 446 | raise OAuthTokenError(f"Token exchange failed: {response.status_code}") |
434 | 447 |
|
| 448 | + content_type = response.headers.get("Content-Type") |
| 449 | + if content_type is None: |
| 450 | + raise OAuthTokenError(f"Token exchange failed: Missing 'Content-Type' response header") |
| 451 | + |
| 452 | + media_type, params = self._parse_content_type(content_type) |
| 453 | + if media_type not in ("application/json", "application/x-www-form-urlencoded"): |
| 454 | + raise OAuthTokenError(f"Token exchange failed: Unexpected token response content type {media_type}") |
| 455 | + |
435 | 456 | try: |
436 | 457 | content = await response.aread() |
437 | | - token_response = OAuthToken.model_validate_json(content) |
| 458 | + if media_type == "application/json": |
| 459 | + token_response = OAuthToken.model_validate_json(content) |
| 460 | + else: |
| 461 | + charset = params.get("charset", "utf-8") |
| 462 | + parsed = parse_qs(content.decode(charset)) |
| 463 | + token_data = {key: value[0] if value else None for key, value in parsed.items()} |
| 464 | + if scope := token_data.get("scope"): |
| 465 | + token_data["scope"] = scope.replace(",", " ") |
| 466 | + token_response = OAuthToken.model_validate(token_data) |
438 | 467 |
|
439 | 468 | # Validate scopes |
440 | 469 | if token_response.scope and self.context.client_metadata.scope: |
|
0 commit comments