Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 11 additions & 28 deletions supabase/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,8 @@ def __init__(

self.supabase_url = supabase_url
self.supabase_key = supabase_key
self._auth_token = {
"Authorization": f"Bearer {supabase_key}",
}
options.headers.update(self._get_auth_headers())
self.options = options
options.headers.update(self._get_auth_headers())
self.rest_url = f"{supabase_url}/rest/v1"
self.realtime_url = f"{supabase_url}/realtime/v1".replace("http", "ws")
self.auth_url = f"{supabase_url}/auth/v1"
Expand Down Expand Up @@ -99,9 +96,7 @@ async def create(
supabase_key: str,
options: ClientOptions = ClientOptions(),
):
client = cls(supabase_url, supabase_key, options)
client._auth_token = await client._get_token_header()
return client
return cls(supabase_url, supabase_key, options)

def table(self, table_name: str) -> AsyncRequestBuilder:
"""Perform a table operation.
Expand Down Expand Up @@ -144,7 +139,6 @@ def rpc(
@property
def postgrest(self):
if self._postgrest is None:
self.options.headers.update(self._auth_token)
self._postgrest = self._init_postgrest_client(
rest_url=self.rest_url,
headers=self.options.headers,
Expand All @@ -157,21 +151,19 @@ def postgrest(self):
@property
def storage(self):
if self._storage is None:
headers = self._get_auth_headers()
headers.update(self._auth_token)
self._storage = self._init_storage_client(
storage_url=self.storage_url,
headers=headers,
headers=self.options.headers,
storage_client_timeout=self.options.storage_client_timeout,
)
return self._storage

@property
def functions(self):
if self._functions is None:
headers = self._get_auth_headers()
headers.update(self._auth_token)
self._functions = AsyncFunctionsClient(self.functions_url, headers)
self._functions = AsyncFunctionsClient(
self.functions_url, self.options.headers
)
return self._functions

# async def remove_subscription_helper(resolve):
Expand Down Expand Up @@ -245,26 +237,17 @@ def _init_postgrest_client(
)

def _create_auth_header(self, token: str):
return {
"Authorization": f"Bearer {token}",
}
return f"Bearer {token}"

def _get_auth_headers(self) -> Dict[str, str]:
"""Helper method to get auth headers."""
return {
"apiKey": self.supabase_key,
"Authorization": f"Bearer {self.supabase_key}",
"Authorization": self.options.headers.get(
"Authorization", self._create_auth_header(self.supabase_key)
),
}

async def _get_token_header(self):
try:
session = await self.auth.get_session()
access_token = session.access_token
except Exception as err:
access_token = self.supabase_key

return self._create_auth_header(access_token)

def _listen_to_auth_events(
self, event: AuthChangeEvent, session: Union[Session, None]
):
Expand All @@ -276,7 +259,7 @@ def _listen_to_auth_events(
self._functions = None
access_token = session.access_token if session else self.supabase_key

self._auth_token = self._create_auth_header(access_token)
self.options.headers["Authorization"] = self._create_auth_header(access_token)


async def create_client(
Expand Down
39 changes: 11 additions & 28 deletions supabase/_sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,8 @@ def __init__(

self.supabase_url = supabase_url
self.supabase_key = supabase_key
self._auth_token = {
"Authorization": f"Bearer {supabase_key}",
}
options.headers.update(self._get_auth_headers())
self.options = options
options.headers.update(self._get_auth_headers())
self.rest_url = f"{supabase_url}/rest/v1"
self.realtime_url = f"{supabase_url}/realtime/v1".replace("http", "ws")
self.auth_url = f"{supabase_url}/auth/v1"
Expand Down Expand Up @@ -99,9 +96,7 @@ def create(
supabase_key: str,
options: ClientOptions = ClientOptions(),
):
client = cls(supabase_url, supabase_key, options)
client._auth_token = client._get_token_header()
return client
return cls(supabase_url, supabase_key, options)

def table(self, table_name: str) -> SyncRequestBuilder:
"""Perform a table operation.
Expand Down Expand Up @@ -144,7 +139,6 @@ def rpc(
@property
def postgrest(self):
if self._postgrest is None:
self.options.headers.update(self._auth_token)
self._postgrest = self._init_postgrest_client(
rest_url=self.rest_url,
headers=self.options.headers,
Expand All @@ -157,21 +151,19 @@ def postgrest(self):
@property
def storage(self):
if self._storage is None:
headers = self._get_auth_headers()
headers.update(self._auth_token)
self._storage = self._init_storage_client(
storage_url=self.storage_url,
headers=headers,
headers=self.options.headers,
storage_client_timeout=self.options.storage_client_timeout,
)
return self._storage

@property
def functions(self):
if self._functions is None:
headers = self._get_auth_headers()
headers.update(self._auth_token)
self._functions = SyncFunctionsClient(self.functions_url, headers)
self._functions = SyncFunctionsClient(
self.functions_url, self.options.headers
)
return self._functions

# async def remove_subscription_helper(resolve):
Expand Down Expand Up @@ -245,26 +237,17 @@ def _init_postgrest_client(
)

def _create_auth_header(self, token: str):
return {
"Authorization": f"Bearer {token}",
}
return f"Bearer {token}"

def _get_auth_headers(self) -> Dict[str, str]:
"""Helper method to get auth headers."""
return {
"apiKey": self.supabase_key,
"Authorization": f"Bearer {self.supabase_key}",
"Authorization": self.options.headers.get(
"Authorization", self._create_auth_header(self.supabase_key)
),
}

def _get_token_header(self):
try:
session = self.auth.get_session()
access_token = session.access_token
except Exception as err:
access_token = self.supabase_key

return self._create_auth_header(access_token)

def _listen_to_auth_events(
self, event: AuthChangeEvent, session: Union[Session, None]
):
Expand All @@ -276,7 +259,7 @@ def _listen_to_auth_events(
self._functions = None
access_token = session.access_token if session else self.supabase_key

self._auth_token = self._create_auth_header(access_token)
self.options.headers["Authorization"] = self._create_auth_header(access_token)


def create_client(
Expand Down
78 changes: 76 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

import os
from typing import Any
from unittest.mock import MagicMock

import pytest

from supabase import Client, create_client
from supabase.lib.client_options import ClientOptions


@pytest.mark.xfail(
reason="None of these values should be able to instantiate a client object"
Expand All @@ -12,6 +17,75 @@
@pytest.mark.parametrize("key", ["", None, "valeefgpoqwjgpj", 139, -1, {}, []])
def test_incorrect_values_dont_instantiate_client(url: Any, key: Any) -> None:
"""Ensure we can't instantiate client with invalid values."""
from supabase import Client, create_client

_: Client = create_client(url, key)


def test_uses_key_as_authorization_header_by_default() -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding a test case for when the Authorization header is explicitly set to None.

This would help ensure that the default behavior is correctly handled even when the header is explicitly set to None, rather than just not provided.

Suggested change
def test_uses_key_as_authorization_header_by_default() -> None:
def test_uses_key_as_authorization_header_when_explicitly_set_to_none() -> None:
url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")
options = ClientOptions(headers={"Authorization": None})
client = create_client(url, key, options)
assert client.options.headers.get("apiKey") == key
assert client.options.headers.get("Authorization") is None

url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")

client = create_client(url, key)

assert client.options.headers.get("apiKey") == key
assert client.options.headers.get("Authorization") == f"Bearer {key}"

assert client.postgrest.session.headers.get("apiKey") == key
assert client.postgrest.session.headers.get("Authorization") == f"Bearer {key}"

assert client.auth._headers.get("apiKey") == key
assert client.auth._headers.get("Authorization") == f"Bearer {key}"

assert client.storage.session.headers.get("apiKey") == key
assert client.storage.session.headers.get("Authorization") == f"Bearer {key}"


def test_supports_setting_a_global_authorization_header() -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Please add a test case to verify the behavior when an invalid JWT is provided in the Authorization header.

This test would help ensure that the system behaves as expected when an invalid JWT is used, potentially rejecting the request or handling it gracefully.

Suggested change
def test_supports_setting_a_global_authorization_header() -> None:
def test_rejects_invalid_jwt() -> None:
url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")
invalid_jwt = "invalid_jwt"
client = create_client(url, key, ClientOptions(headers={"Authorization": f"Bearer {invalid_jwt}"}))
response = client.postgrest.get("some_endpoint")
assert response.status_code == 401

url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")

authorization = f"Bearer secretuserjwt"

options = ClientOptions(headers={"Authorization": authorization})

client = create_client(url, key, options)

assert client.options.headers.get("apiKey") == key
assert client.options.headers.get("Authorization") == authorization

assert client.postgrest.session.headers.get("apiKey") == key
assert client.postgrest.session.headers.get("Authorization") == authorization

assert client.auth._headers.get("apiKey") == key
assert client.auth._headers.get("Authorization") == authorization

assert client.storage.session.headers.get("apiKey") == key
assert client.storage.session.headers.get("Authorization") == authorization


def test_updates_the_authorization_header_on_auth_events() -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add a test to verify header updates when the auth event is 'SIGNED_OUT'.

This would ensure that headers are correctly reset or handled when a user signs out.

Suggested change
def test_updates_the_authorization_header_on_auth_events() -> None:
def test_authorization_header_reset_on_sign_out() -> None:
client._listen_to_auth_events("SIGNED_OUT", MagicMock(access_token=None))
assert client.options.headers.get("Authorization") is None
assert client.postgrest.session.headers.get("Authorization") is None
assert client.auth._headers.get("Authorization") is None
assert client.storage.session.headers.get("Authorization") is None

url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")

client = create_client(url, key)

assert client.options.headers.get("apiKey") == key
assert client.options.headers.get("Authorization") == f"Bearer {key}"

mock_session = MagicMock(access_token="secretuserjwt")
client._listen_to_auth_events("SIGNED_IN", mock_session)

updated_authorization = f"Bearer {mock_session.access_token}"

assert client.options.headers.get("apiKey") == key
assert client.options.headers.get("Authorization") == updated_authorization

assert client.postgrest.session.headers.get("apiKey") == key
assert (
client.postgrest.session.headers.get("Authorization") == updated_authorization
)

assert client.auth._headers.get("apiKey") == key
assert client.auth._headers.get("Authorization") == updated_authorization

assert client.storage.session.headers.get("apiKey") == key
assert client.storage.session.headers.get("Authorization") == updated_authorization