Skip to content

Commit

Permalink
fix(logs): redact sensitive headers (#1850)
Browse files Browse the repository at this point in the history
  • Loading branch information
kristapratico authored Nov 6, 2024
1 parent 74522ad commit 466608f
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
pretty = True
show_error_codes = True

# Exclude _files.py because mypy isn't smart enough to apply
# Exclude _files.py and _logs.py because mypy isn't smart enough to apply
# the correct type narrowing and as this is an internal module
# it's fine to just use Pyright.
exclude = ^(src/openai/_files\.py|_dev/.*\.py)$
exclude = ^(src/openai/_files\.py|src/openai/_utils/_logs\.py|_dev/.*\.py)$

strict_equality = True
implicit_reexport = True
Expand Down
3 changes: 2 additions & 1 deletion src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
HttpxRequestFiles,
ModelBuilderProtocol,
)
from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
from ._utils import SensitiveHeadersFilter, is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
from ._compat import model_copy, model_dump
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
from ._response import (
Expand Down Expand Up @@ -90,6 +90,7 @@
from ._legacy_response import LegacyAPIResponse

log: logging.Logger = logging.getLogger(__name__)
log.addFilter(SensitiveHeadersFilter())

# TODO: make base page type vars covariant
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
Expand Down
1 change: 1 addition & 0 deletions src/openai/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._logs import SensitiveHeadersFilter as SensitiveHeadersFilter
from ._sync import asyncify as asyncify
from ._proxy import LazyProxy as LazyProxy
from ._utils import (
Expand Down
17 changes: 17 additions & 0 deletions src/openai/_utils/_logs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import os
import logging
from typing_extensions import override

from ._utils import is_dict

logger: logging.Logger = logging.getLogger("openai")
httpx_logger: logging.Logger = logging.getLogger("httpx")


SENSITIVE_HEADERS = {"api-key", "authorization"}


def _basic_config() -> None:
# e.g. [2023-10-05 14:12:26 - openai._base_client:818 - DEBUG] HTTP Request: POST http://127.0.0.1:4010/foo/bar "200 OK"
logging.basicConfig(
Expand All @@ -23,3 +29,14 @@ def setup_logging() -> None:
_basic_config()
logger.setLevel(logging.INFO)
httpx_logger.setLevel(logging.INFO)


class SensitiveHeadersFilter(logging.Filter):
@override
def filter(self, record: logging.LogRecord) -> bool:
if is_dict(record.args) and "headers" in record.args and is_dict(record.args["headers"]):
headers = record.args["headers"] = {**record.args["headers"]}
for header in headers:
if str(header).lower() in SENSITIVE_HEADERS:
headers[header] = "<redacted>"
return True
101 changes: 101 additions & 0 deletions tests/lib/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
from typing import Union, cast
from typing_extensions import Literal, Protocol

import httpx
import pytest
from respx import MockRouter

from openai._utils import SensitiveHeadersFilter, is_dict
from openai._models import FinalRequestOptions
from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI

Expand Down Expand Up @@ -148,3 +150,102 @@ def token_provider() -> str:

assert calls[0].request.headers.get("Authorization") == "Bearer first"
assert calls[1].request.headers.get("Authorization") == "Bearer second"


class TestAzureLogging:

@pytest.fixture(autouse=True)
def logger_with_filter(self) -> logging.Logger:
logger = logging.getLogger("openai")
logger.setLevel(logging.DEBUG)
logger.addFilter(SensitiveHeadersFilter())
return logger

@pytest.mark.respx()
def test_azure_api_key_redacted(self, respx_mock: MockRouter, caplog: pytest.LogCaptureFixture) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-06-01"
).mock(
return_value=httpx.Response(200, json={"model": "gpt-4"})
)

client = AzureOpenAI(
api_version="2024-06-01",
api_key="example_api_key",
azure_endpoint="https://example-resource.azure.openai.com",
)

with caplog.at_level(logging.DEBUG):
client.chat.completions.create(messages=[], model="gpt-4")

for record in caplog.records:
if is_dict(record.args) and record.args.get("headers") and is_dict(record.args["headers"]):
assert record.args["headers"]["api-key"] == "<redacted>"


@pytest.mark.respx()
def test_azure_bearer_token_redacted(self, respx_mock: MockRouter, caplog: pytest.LogCaptureFixture) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-06-01"
).mock(
return_value=httpx.Response(200, json={"model": "gpt-4"})
)

client = AzureOpenAI(
api_version="2024-06-01",
azure_ad_token="example_token",
azure_endpoint="https://example-resource.azure.openai.com",
)

with caplog.at_level(logging.DEBUG):
client.chat.completions.create(messages=[], model="gpt-4")

for record in caplog.records:
if is_dict(record.args) and record.args.get("headers") and is_dict(record.args["headers"]):
assert record.args["headers"]["Authorization"] == "<redacted>"


@pytest.mark.asyncio
@pytest.mark.respx()
async def test_azure_api_key_redacted_async(self, respx_mock: MockRouter, caplog: pytest.LogCaptureFixture) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-06-01"
).mock(
return_value=httpx.Response(200, json={"model": "gpt-4"})
)

client = AsyncAzureOpenAI(
api_version="2024-06-01",
api_key="example_api_key",
azure_endpoint="https://example-resource.azure.openai.com",
)

with caplog.at_level(logging.DEBUG):
await client.chat.completions.create(messages=[], model="gpt-4")

for record in caplog.records:
if is_dict(record.args) and record.args.get("headers") and is_dict(record.args["headers"]):
assert record.args["headers"]["api-key"] == "<redacted>"


@pytest.mark.asyncio
@pytest.mark.respx()
async def test_azure_bearer_token_redacted_async(self, respx_mock: MockRouter, caplog: pytest.LogCaptureFixture) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-06-01"
).mock(
return_value=httpx.Response(200, json={"model": "gpt-4"})
)

client = AsyncAzureOpenAI(
api_version="2024-06-01",
azure_ad_token="example_token",
azure_endpoint="https://example-resource.azure.openai.com",
)

with caplog.at_level(logging.DEBUG):
await client.chat.completions.create(messages=[], model="gpt-4")

for record in caplog.records:
if is_dict(record.args) and record.args.get("headers") and is_dict(record.args["headers"]):
assert record.args["headers"]["Authorization"] == "<redacted>"
100 changes: 100 additions & 0 deletions tests/test_utils/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
from typing import Any, Dict, cast

import pytest

from openai._utils import SensitiveHeadersFilter


@pytest.fixture
def logger_with_filter() -> logging.Logger:
logger = logging.getLogger("test_logger")
logger.setLevel(logging.DEBUG)
logger.addFilter(SensitiveHeadersFilter())
return logger


def test_keys_redacted(logger_with_filter: logging.Logger, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.DEBUG):
logger_with_filter.debug(
"Request options: %s",
{
"method": "post",
"url": "chat/completions",
"headers": {"api-key": "12345", "Authorization": "Bearer token"},
},
)

log_record = cast(Dict[str, Any], caplog.records[0].args)
assert log_record["method"] == "post"
assert log_record["url"] == "chat/completions"
assert log_record["headers"]["api-key"] == "<redacted>"
assert log_record["headers"]["Authorization"] == "<redacted>"
assert (
caplog.messages[0]
== "Request options: {'method': 'post', 'url': 'chat/completions', 'headers': {'api-key': '<redacted>', 'Authorization': '<redacted>'}}"
)


def test_keys_redacted_case_insensitive(logger_with_filter: logging.Logger, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.DEBUG):
logger_with_filter.debug(
"Request options: %s",
{
"method": "post",
"url": "chat/completions",
"headers": {"Api-key": "12345", "authorization": "Bearer token"},
},
)

log_record = cast(Dict[str, Any], caplog.records[0].args)
assert log_record["method"] == "post"
assert log_record["url"] == "chat/completions"
assert log_record["headers"]["Api-key"] == "<redacted>"
assert log_record["headers"]["authorization"] == "<redacted>"
assert (
caplog.messages[0]
== "Request options: {'method': 'post', 'url': 'chat/completions', 'headers': {'Api-key': '<redacted>', 'authorization': '<redacted>'}}"
)


def test_no_headers(logger_with_filter: logging.Logger, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.DEBUG):
logger_with_filter.debug(
"Request options: %s",
{"method": "post", "url": "chat/completions"},
)

log_record = cast(Dict[str, Any], caplog.records[0].args)
assert log_record["method"] == "post"
assert log_record["url"] == "chat/completions"
assert "api-key" not in log_record
assert "Authorization" not in log_record
assert caplog.messages[0] == "Request options: {'method': 'post', 'url': 'chat/completions'}"


def test_headers_without_sensitive_info(logger_with_filter: logging.Logger, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.DEBUG):
logger_with_filter.debug(
"Request options: %s",
{
"method": "post",
"url": "chat/completions",
"headers": {"custom": "value"},
},
)

log_record = cast(Dict[str, Any], caplog.records[0].args)
assert log_record["method"] == "post"
assert log_record["url"] == "chat/completions"
assert log_record["headers"] == {"custom": "value"}
assert (
caplog.messages[0]
== "Request options: {'method': 'post', 'url': 'chat/completions', 'headers': {'custom': 'value'}}"
)


def test_standard_debug_msg(logger_with_filter: logging.Logger, caplog: pytest.LogCaptureFixture) -> None:
with caplog.at_level(logging.DEBUG):
logger_with_filter.debug("Sending HTTP Request: %s %s", "POST", "chat/completions")
assert caplog.messages[0] == "Sending HTTP Request: POST chat/completions"

0 comments on commit 466608f

Please sign in to comment.