Skip to content

Add basic OAuth support to Schema Registry #1919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/requirements-schemaregistry.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
attrs
cachetools
httpx>=0.26
authlib
11 changes: 10 additions & 1 deletion src/confluent_kafka/schema_registry/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
except ImportError:
pass

__all__ = ['SchemaRegistryError', 'SchemaParseException', 'UnknownType']
__all__ = ['SchemaRegistryError', 'OAuthTokenError', 'SchemaParseException', 'UnknownType']


class SchemaRegistryError(Exception):
Expand Down Expand Up @@ -53,3 +53,12 @@ def __str__(self):
return "{} (HTTP status code {}, SR code {})".format(self.error_message,
self.http_status_code,
self.error_code)


class OAuthTokenError(Exception):
"""Raised when an OAuth token cannot be retrieved."""
def __init__(self, message, status_code=None, response_text=None):
self.message = message
self.status_code = status_code
self.response_text = response_text
super().__init__(f"{message} (HTTP {status_code}): {response_text}")
102 changes: 101 additions & 1 deletion src/confluent_kafka/schema_registry/schema_registry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from cachetools import TTLCache, LRUCache
from httpx import Response

from .error import SchemaRegistryError
from authlib.integrations.httpx_client import OAuth2Client

from .error import SchemaRegistryError, OAuthTokenError

# TODO: consider adding `six` dependency or employing a compat file
# Python 2.7 is officially EOL so compatibility issue will be come more the norm.
Expand All @@ -60,6 +62,40 @@ def _urlencode(value: str) -> str:
VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO']


class _OAuthClient:
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str,
max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
self.token = None
self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
self.token_endpoint = token_endpoint
self.max_retries = max_retries
self.retries_wait_ms = retries_wait_ms
self.retries_max_wait_ms = retries_max_wait_ms
self.token_expiry_threshold = 0.8

def token_expired(self):
expiry_window = self.token['expires_in'] * self.token_expiry_threshold

return self.token['expires_at'] < time.time() + expiry_window

def get_access_token(self) -> str:
if not self.token or self.token_expired():
self.generate_access_token()

return self.token['access_token']

def generate_access_token(self):
for i in range(self.max_retries + 1):
try:
self.token = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
return
except Exception as e:
if i >= self.max_retries:
raise OAuthTokenError(f"Failed to retrieve token after {self.max_retries} "
f"attempts due to error: {str(e)}")
time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000)


class _BaseRestClient(object):

def __init__(self, conf: dict):
Expand Down Expand Up @@ -170,6 +206,59 @@ def __init__(self, conf: dict):
+ str(type(retries_max_wait_ms)))
self.retries_max_wait_ms = retries_max_wait_ms

self.oauth_client = None
self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None)
if self.bearer_auth_credentials_source is not None:
self.auth = None
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
missing_headers = [header for header in headers if header not in conf_copy]
if missing_headers:
raise ValueError("Missing required bearer configuration properties: {}"
.format(", ".join(missing_headers)))

self.logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
if not isinstance(self.logical_cluster, str):
raise TypeError("logical cluster must be a str, not " + str(type(self.logical_cluster)))

self.identity_pool_id = conf_copy.pop('bearer.auth.identity.pool.id')
if not isinstance(self.identity_pool_id, str):
raise TypeError("identity pool id must be a str, not " + str(type(self.identity_pool_id)))

if self.bearer_auth_credentials_source == 'OAUTHBEARER':
properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope',
'bearer.auth.issuer.endpoint.url']
missing_properties = [prop for prop in properties_list if prop not in conf_copy]
if missing_properties:
raise ValueError("Missing required OAuth configuration properties: {}".
format(", ".join(missing_properties)))

self.client_id = conf_copy.pop('bearer.auth.client.id')
if not isinstance(self.client_id, string_type):
raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id)))

self.client_secret = conf_copy.pop('bearer.auth.client.secret')
if not isinstance(self.client_secret, string_type):
raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret)))

self.scope = conf_copy.pop('bearer.auth.scope')
if not isinstance(self.scope, string_type):
raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope)))

self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url')
if not isinstance(self.token_endpoint, string_type):
raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not "
+ str(type(self.token_endpoint)))

self.oauth_client = _OAuthClient(self.client_id, self.client_secret, self.scope, self.token_endpoint,
self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms)

elif self.bearer_auth_credentials_source == 'STATIC_TOKEN':
if 'bearer.auth.token' not in conf_copy:
raise ValueError("Missing bearer.auth.token")
self.bearer_token = conf_copy.pop('bearer.auth.token')
if not isinstance(self.bearer_token, string_type):
raise TypeError("bearer.auth.token must be a str, not " + str(type(self.bearer_token)))

# Any leftover keys are unknown to _RestClient
if len(conf_copy) > 0:
raise ValueError("Unrecognized properties: {}"
Expand Down Expand Up @@ -209,6 +298,14 @@ def __init__(self, conf: dict):
timeout=self.timeout
)

def handle_bearer_auth(self, headers: dict):
token = self.bearer_token
if self.oauth_client:
token = self.oauth_client.get_access_token()
headers["Authorization"] = "Bearer {}".format(token)
headers['Confluent-Identity-Pool-Id'] = self.identity_pool_id
headers['target-sr-cluster'] = self.logical_cluster

def get(self, url: str, query: Optional[dict] = None) -> Any:
return self.send_request(url, method='GET', query=query)

Expand Down Expand Up @@ -256,6 +353,9 @@ def send_request(
headers = {'Content-Length': str(len(body)),
'Content-Type': "application/vnd.schemaregistry.v1+json"}

if self.bearer_auth_credentials_source:
self.handle_bearer_auth(headers)

response = None
for i, base_url in enumerate(self.base_urls):
try:
Expand Down
118 changes: 118 additions & 0 deletions tests/schema_registry/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
TEST_URL = 'http://SchemaRegistry:65534'
TEST_USERNAME = 'sr_user'
TEST_USER_PASSWORD = 'sr_user_secret'
TEST_POOL = 'sr_pool'
TEST_CLUSTER = 'lsrc-1234'
TEST_SCOPE = 'sr_scope'
TEST_ENDPOINT = 'http://oauth_endpoint'

"""
Tests to ensure all configurations are handled correctly.
Expand Down Expand Up @@ -112,6 +116,120 @@ def test_config_auth_userinfo_invalid():
SchemaRegistryClient(conf)


def test_bearer_config():
conf = {'url': TEST_URL,
'bearer.auth.credentials.source': "OAUTHBEARER"}

with pytest.raises(ValueError, match=r"Missing required bearer configuration properties: (.*)"):
SchemaRegistryClient(conf)


def test_oauth_bearer_config_missing():
conf = {'url': TEST_URL,
'bearer.auth.credentials.source': "OAUTHBEARER",
'bearer.auth.logical.cluster': TEST_CLUSTER,
'bearer.auth.identity.pool.id': TEST_POOL}

with pytest.raises(ValueError, match=r"Missing required OAuth configuration properties: (.*)"):
SchemaRegistryClient(conf)


def test_oauth_bearer_config_invalid():
conf = {'url': TEST_URL,
'bearer.auth.credentials.source': "OAUTHBEARER",
'bearer.auth.logical.cluster': TEST_CLUSTER,
'bearer.auth.identity.pool.id': 1}

with pytest.raises(TypeError, match=r"identity pool id must be a str, not (.*)"):
SchemaRegistryClient(conf)

conf = {'url': TEST_URL,
'bearer.auth.credentials.source': "OAUTHBEARER",
'bearer.auth.logical.cluster': 1,
'bearer.auth.identity.pool.id': TEST_POOL}

with pytest.raises(TypeError, match=r"logical cluster must be a str, not (.*)"):
SchemaRegistryClient(conf)

conf = {'url': TEST_URL,
'bearer.auth.credentials.source': "OAUTHBEARER",
'bearer.auth.logical.cluster': TEST_CLUSTER,
'bearer.auth.identity.pool.id': TEST_POOL,
'bearer.auth.client.id': 1,
'bearer.auth.client.secret': TEST_USER_PASSWORD,
'bearer.auth.scope': TEST_SCOPE,
'bearer.auth.issuer.endpoint.url': TEST_ENDPOINT}

with pytest.raises(TypeError, match=r"bearer.auth.client.id must be a str, not (.*)"):
SchemaRegistryClient(conf)

conf = {'url': TEST_URL,
'bearer.auth.credentials.source': "OAUTHBEARER",
'bearer.auth.logical.cluster': TEST_CLUSTER,
'bearer.auth.identity.pool.id': TEST_POOL,
'bearer.auth.client.id': TEST_USERNAME,
'bearer.auth.client.secret': 1,
'bearer.auth.scope': TEST_SCOPE,
'bearer.auth.issuer.endpoint.url': TEST_ENDPOINT}

with pytest.raises(TypeError, match=r"bearer.auth.client.secret must be a str, not (.*)"):
SchemaRegistryClient(conf)

conf = {'url': TEST_URL,
'bearer.auth.credentials.source': "OAUTHBEARER",
'bearer.auth.logical.cluster': TEST_CLUSTER,
'bearer.auth.identity.pool.id': TEST_POOL,
'bearer.auth.client.id': TEST_USERNAME,
'bearer.auth.client.secret': TEST_USER_PASSWORD,
'bearer.auth.scope': 1,
'bearer.auth.issuer.endpoint.url': TEST_ENDPOINT}

with pytest.raises(TypeError, match=r"bearer.auth.scope must be a str, not (.*)"):
SchemaRegistryClient(conf)

conf = {'url': TEST_URL,
'bearer.auth.credentials.source': "OAUTHBEARER",
'bearer.auth.logical.cluster': TEST_CLUSTER,
'bearer.auth.identity.pool.id': TEST_POOL,
'bearer.auth.client.id': TEST_USERNAME,
'bearer.auth.client.secret': TEST_USER_PASSWORD,
'bearer.auth.scope': TEST_SCOPE,
'bearer.auth.issuer.endpoint.url': 1}

with pytest.raises(TypeError, match=r"bearer.auth.issuer.endpoint.url must be a str, not (.*)"):
SchemaRegistryClient(conf)


def test_oauth_bearer_config_valid():
conf = {'url': TEST_URL,
'bearer.auth.credentials.source': "OAUTHBEARER",
'bearer.auth.logical.cluster': TEST_CLUSTER,
'bearer.auth.identity.pool.id': TEST_POOL,
'bearer.auth.client.id': TEST_USERNAME,
'bearer.auth.client.secret': TEST_USER_PASSWORD,
'bearer.auth.scope': TEST_SCOPE,
'bearer.auth.issuer.endpoint.url': TEST_ENDPOINT}

client = SchemaRegistryClient(conf)

assert client._rest_client.logical_cluster == TEST_CLUSTER
assert client._rest_client.identity_pool_id == TEST_POOL
assert client._rest_client.client_id == TEST_USERNAME
assert client._rest_client.client_secret == TEST_USER_PASSWORD
assert client._rest_client.scope == TEST_SCOPE
assert client._rest_client.token_endpoint == TEST_ENDPOINT


def test_static_bearer_config():
conf = {'url': TEST_URL,
'bearer.auth.credentials.source': 'STATIC_TOKEN',
'bearer.auth.logical.cluster': 'lsrc',
'bearer.auth.identity.pool.id': 'pool_id'}

with pytest.raises(ValueError, match='Missing bearer.auth.token'):
SchemaRegistryClient(conf)


def test_config_unknown_prop():
conf = {'url': TEST_URL,
'basic.auth.credentials.source': 'SASL_INHERIT',
Expand Down
57 changes: 57 additions & 0 deletions tests/schema_registry/test_oauth_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
import time
from unittest.mock import Mock, patch

from confluent_kafka.schema_registry.schema_registry_client import _OAuthClient
from confluent_kafka.schema_registry.error import OAuthTokenError

"""
Tests to ensure OAuth client is set up correctly.

"""


def test_expiry():
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 2, 1000, 20000)
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1}
assert not oauth_client.token_expired()
time.sleep(1.5)
assert oauth_client.token_expired()


def test_get_token():
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 2, 1000, 20000)
assert not oauth_client.token

def update_token1():
oauth_client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'}

def update_token2():
oauth_client.token = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'}

oauth_client.generate_access_token = Mock(side_effect=update_token1)
oauth_client.get_access_token()
assert oauth_client.generate_access_token.call_count == 1
assert oauth_client.token['access_token'] == '123'

oauth_client.generate_access_token = Mock(side_effect=update_token2)
oauth_client.get_access_token()
# Call count resets to 1 after reassigning generate_access_token
assert oauth_client.generate_access_token.call_count == 1
assert oauth_client.token['access_token'] == '1234'

oauth_client.get_access_token()
assert oauth_client.generate_access_token.call_count == 1


def test_generate_token_retry_logic():
oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 5, 1000, 20000)

with (patch("confluent_kafka.schema_registry.schema_registry_client.time.sleep") as mock_sleep,
patch("confluent_kafka.schema_registry.schema_registry_client.full_jitter") as mock_jitter):

with pytest.raises(OAuthTokenError):
oauth_client.generate_access_token()

assert mock_sleep.call_count == 5
assert mock_jitter.call_count == 5