Skip to content

Commit

Permalink
Add async and sync header tests for Secure class.
Browse files Browse the repository at this point in the history
  • Loading branch information
cak committed Oct 17, 2024
1 parent 487f751 commit 89046f6
Showing 1 changed file with 50 additions and 37 deletions.
87 changes: 50 additions & 37 deletions tests/secure/test_secure.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import unittest

from secure import (
Expand Down Expand Up @@ -44,6 +45,20 @@ class MockResponseNoHeaders:


class TestSecure(unittest.TestCase):
def setUp(self):
# Initialize Secure with some test headers
self.secure = Secure(
custom=[
CustomHeader("X-Test-Header-1", "Value1"),
CustomHeader("X-Test-Header-2", "Value2"),
]
)
# Precompute headers dictionary
self.secure.headers = {
header.header_name: header.header_value
for header in self.secure.headers_list
}

def test_with_default_headers(self):
"""Test that default headers are correctly applied."""
secure_headers = Secure.with_default_headers()
Expand Down Expand Up @@ -210,8 +225,6 @@ def test_async_set_headers(self):
async def mock_set_headers():
await secure_headers.set_headers_async(response)

import asyncio

asyncio.run(mock_set_headers())

# Verify that headers are set asynchronously
Expand All @@ -235,43 +248,43 @@ async def mock_set_headers():

def test_set_headers_with_set_header_method(self):
"""Test setting headers on a response object with set_header method."""
secure_headers = Secure.with_default_headers()
response = MockResponseWithSetHeader()

# Apply the headers to the response object
secure_headers.set_headers(response)
self.secure.set_headers(response)

# Verify that headers are set using set_header method
self.assertIn("Strict-Transport-Security", response.header_storage)
self.assertEqual(
response.header_storage["Strict-Transport-Security"],
"max-age=31536000",
)
self.assertEqual(response.header_storage, self.secure.headers)
# Ensure set_header was called correct number of times
self.assertEqual(len(response.header_storage), len(self.secure.headers))

self.assertIn("X-Content-Type-Options", response.header_storage)
self.assertEqual(response.header_storage["X-Content-Type-Options"], "nosniff")
def test_set_headers_with_headers_dict(self):
"""Test set_headers with a response object that has a headers dictionary."""
response = MockResponse()
self.secure.set_headers(response)

def test_async_set_headers_with_async_set_header_method(self):
"""Test async setting headers on a response object with async set_header method."""
secure_headers = Secure.with_default_headers()
response = MockResponseAsyncSetHeader()
# Verify that headers are set
self.assertEqual(response.headers, self.secure.headers)

async def mock_set_headers():
await secure_headers.set_headers_async(response)
def test_set_headers_async_with_async_set_header(self):
"""Test set_headers_async with a response object that has an asynchronous set_header method."""
response = MockResponseAsyncSetHeader()

import asyncio
async def test_async():
await self.secure.set_headers_async(response)

asyncio.run(mock_set_headers())
asyncio.run(test_async())

# Verify that headers are set using async set_header method
self.assertIn("Strict-Transport-Security", response.header_storage)
self.assertEqual(
response.header_storage["Strict-Transport-Security"],
"max-age=31536000",
)
self.assertEqual(response.header_storage, self.secure.headers)
# Ensure set_header was called correct number of times
self.assertEqual(len(response.header_storage), len(self.secure.headers))

def test_set_headers_async_with_headers_dict(self):
"""Test set_headers_async with a response object that has a headers dictionary."""
response = MockResponse()
asyncio.run(self.secure.set_headers_async(response))

self.assertIn("X-Content-Type-Options", response.header_storage)
self.assertEqual(response.header_storage["X-Content-Type-Options"], "nosniff")
# Verify that headers are set
self.assertEqual(response.headers, self.secure.headers)

def test_set_headers_missing_interface(self):
"""Test that an error is raised when response object lacks required methods."""
Expand All @@ -286,6 +299,12 @@ def test_set_headers_missing_interface(self):
str(context.exception),
)

def test_set_headers_with_async_set_header_in_sync_context(self):
"""Test set_headers raises RuntimeError when encountering async set_header in sync context."""
response = MockResponseAsyncSetHeader()
with self.assertRaises(RuntimeError):
self.secure.set_headers(response)

def test_set_headers_overwrites_existing_headers(self):
"""Test that existing headers are overwritten by Secure."""
secure_headers = Secure.with_default_headers()
Expand Down Expand Up @@ -347,10 +366,10 @@ def test_invalid_preset(self):

def test_empty_secure_instance(self):
"""Test that an empty Secure instance does not set any headers."""
secure_headers = Secure()
self.secure = Secure()
response = MockResponse()

secure_headers.set_headers(response)
self.secure.set_headers(response)
self.assertEqual(len(response.headers), 0)

def test_multiple_custom_headers(self):
Expand Down Expand Up @@ -430,16 +449,10 @@ def test_set_headers_async_with_sync_set_header(self):
async def mock_set_headers():
await secure_headers.set_headers_async(response)

import asyncio

asyncio.run(mock_set_headers())

# Verify that headers are set using set_header method
self.assertIn("Strict-Transport-Security", response.header_storage)
self.assertEqual(
response.header_storage["Strict-Transport-Security"],
"max-age=31536000",
)
self.assertEqual(response.header_storage, secure_headers.headers)

def test_set_headers_with_no_headers_or_set_header(self):
"""Test that an error is raised when response lacks both headers and set_header."""
Expand Down

0 comments on commit 89046f6

Please sign in to comment.