Skip to content

Add support for custom OAuth functions #1925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 100 additions & 32 deletions src/confluent_kafka/schema_registry/schema_registry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import abc
import json
import logging
import random
Expand All @@ -30,7 +30,7 @@
from enum import Enum
from threading import Lock
from typing import List, Dict, Type, TypeVar, \
cast, Optional, Union, Any, Tuple
cast, Optional, Union, Any, Tuple, Callable

from cachetools import TTLCache, LRUCache
from httpx import Response
Expand Down Expand Up @@ -62,18 +62,50 @@ def _urlencode(value: str) -> str:
VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO']


class _OAuthClient:
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str,
max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
class _BearerFieldProvider(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_bearer_fields(self) -> dict:
raise NotImplementedError


class _StaticFieldProvider(_BearerFieldProvider):
    """Bearer field provider that hands back fixed, pre-supplied values."""

    def __init__(self, token: str, logical_cluster: str, identity_pool: str):
        """Store the token, logical cluster and identity pool as given."""
        self.token = token
        self.logical_cluster = logical_cluster
        self.identity_pool = identity_pool

    def get_bearer_fields(self) -> dict:
        """Return the stored values keyed by their bearer config names."""
        fields = {}
        fields['bearer.auth.token'] = self.token
        fields['bearer.auth.logical.cluster'] = self.logical_cluster
        fields['bearer.auth.identity.pool.id'] = self.identity_pool
        return fields


class _CustomOAuthClient(_BearerFieldProvider):
    """Bearer field provider that delegates to a user-supplied callable."""

    def __init__(self, custom_function: Callable[[Dict], Dict], custom_config: dict):
        """Keep the callable and the config dict it will be invoked with."""
        self.custom_function = custom_function
        self.custom_config = custom_config

    def get_bearer_fields(self) -> dict:
        """Call the custom function with the stored config and return its result.

        The callable is invoked on every call; nothing is cached here.
        """
        produce_fields = self.custom_function
        return produce_fields(self.custom_config)


class _OAuthClient(_BearerFieldProvider):
def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str,
identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
self.token = None
self.logical_cluster = logical_cluster
self.identity_pool = identity_pool
self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
self.token_endpoint = token_endpoint
self.max_retries = max_retries
self.retries_wait_ms = retries_wait_ms
self.retries_max_wait_ms = retries_max_wait_ms
self.token_expiry_threshold = 0.8

def token_expired(self):
def get_bearer_fields(self) -> dict:
return {'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster,
'bearer.auth.identity.pool.id': self.identity_pool}

def token_expired(self) -> bool:
expiry_window = self.token['expires_in'] * self.token_expiry_threshold

return self.token['expires_at'] < time.time() + expiry_window
Expand All @@ -84,7 +116,7 @@ def get_access_token(self) -> str:

return self.token['access_token']

def generate_access_token(self):
def generate_access_token(self) -> None:
for i in range(self.max_retries + 1):
try:
self.token = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
Expand Down Expand Up @@ -206,23 +238,27 @@ def __init__(self, conf: dict):
+ str(type(retries_max_wait_ms)))
self.retries_max_wait_ms = retries_max_wait_ms

self.oauth_client = None
self.bearer_field_provider = None
logical_cluster = None
identity_pool = None
self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None)
if self.bearer_auth_credentials_source is not None:
self.auth = None
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
missing_headers = [header for header in headers if header not in conf_copy]
if missing_headers:
raise ValueError("Missing required bearer configuration properties: {}"
.format(", ".join(missing_headers)))

self.logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
if not isinstance(self.logical_cluster, str):
raise TypeError("logical cluster must be a str, not " + str(type(self.logical_cluster)))
if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}:
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
missing_headers = [header for header in headers if header not in conf_copy]
if missing_headers:
raise ValueError("Missing required bearer configuration properties: {}"
.format(", ".join(missing_headers)))

self.identity_pool_id = conf_copy.pop('bearer.auth.identity.pool.id')
if not isinstance(self.identity_pool_id, str):
raise TypeError("identity pool id must be a str, not " + str(type(self.identity_pool_id)))
logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
if not isinstance(logical_cluster, str):
raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster)))

identity_pool = conf_copy.pop('bearer.auth.identity.pool.id')
if not isinstance(identity_pool, str):
raise TypeError("identity pool id must be a str, not " + str(type(identity_pool)))

if self.bearer_auth_credentials_source == 'OAUTHBEARER':
properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope',
Expand All @@ -249,15 +285,38 @@ def __init__(self, conf: dict):
raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not "
+ str(type(self.token_endpoint)))

self.oauth_client = _OAuthClient(self.client_id, self.client_secret, self.scope, self.token_endpoint,
self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms)

self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope,
self.token_endpoint, logical_cluster, identity_pool,
self.max_retries, self.retries_wait_ms,
self.retries_max_wait_ms)
elif self.bearer_auth_credentials_source == 'STATIC_TOKEN':
if 'bearer.auth.token' not in conf_copy:
raise ValueError("Missing bearer.auth.token")
self.bearer_token = conf_copy.pop('bearer.auth.token')
if not isinstance(self.bearer_token, string_type):
raise TypeError("bearer.auth.token must be a str, not " + str(type(self.bearer_token)))
static_token = conf_copy.pop('bearer.auth.token')
self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool)
if not isinstance(static_token, string_type):
raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token)))
elif self.bearer_auth_credentials_source == 'CUSTOM':
custom_bearer_properties = ['bearer.auth.custom.provider.function',
'bearer.auth.custom.provider.config']
missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy]
if missing_custom_properties:
raise ValueError("Missing required custom OAuth configuration properties: {}".
format(", ".join(missing_custom_properties)))

custom_function = conf_copy.pop('bearer.auth.custom.provider.function')
if not callable(custom_function):
raise TypeError("bearer.auth.custom.provider.function must be a callable, not "
+ str(type(custom_function)))

custom_config = conf_copy.pop('bearer.auth.custom.provider.config')
if not isinstance(custom_config, dict):
raise TypeError("bearer.auth.custom.provider.config must be a dict, not "
+ str(type(custom_config)))

self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config)
else:
raise ValueError('Unrecognized bearer.auth.credentials.source')

# Any leftover keys are unknown to _RestClient
if len(conf_copy) > 0:
Expand Down Expand Up @@ -298,13 +357,22 @@ def __init__(self, conf: dict):
timeout=self.timeout
)

def handle_bearer_auth(self, headers: dict):
token = self.bearer_token
if self.oauth_client:
token = self.oauth_client.get_access_token()
headers["Authorization"] = "Bearer {}".format(token)
headers['Confluent-Identity-Pool-Id'] = self.identity_pool_id
headers['target-sr-cluster'] = self.logical_cluster
def handle_bearer_auth(self, headers: dict) -> None:
    """Add the bearer authentication headers to ``headers`` in place.

    The fields are obtained from ``self.bearer_field_provider`` on every
    call, and all three required fields must be present in the result.

    Args:
        headers: Request headers dict; it gains the ``Authorization``,
            ``Confluent-Identity-Pool-Id`` and ``target-sr-cluster`` entries.

    Raises:
        ValueError: If the provider's result lacks any required field.
            ``headers`` is left unmodified in that case.
    """
    bearer_fields = self.bearer_field_provider.get_bearer_fields()
    required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster']

    # Collect every absent field so the error names all of them at once.
    missing_fields = [field for field in required_fields if field not in bearer_fields]
    if missing_fields:
        raise ValueError("Missing required bearer auth fields, needs to be set in config or custom function: {}"
                         .format(", ".join(missing_fields)))

    headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token'])
    headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id']
    headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster']

def get(self, url: str, query: Optional[dict] = None) -> Any:
return self.send_request(url, method='GET', query=query)
Expand Down
142 changes: 142 additions & 0 deletions tests/schema_registry/test_bearer_field_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2025 Confluent Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
import time
from unittest.mock import Mock, patch

from confluent_kafka.schema_registry.schema_registry_client import (_OAuthClient, _StaticFieldProvider,
_CustomOAuthClient, SchemaRegistryClient)
from confluent_kafka.schema_registry.error import OAuthTokenError

"""
Tests to ensure OAuth client is set up correctly.

"""


def custom_oauth_function(config: dict) -> dict:
    """Test stub for a custom provider: return the given config unchanged."""
    return config


# Shared fixture values used by the tests in this module.
TEST_TOKEN = 'token123'
TEST_CLUSTER = 'lsrc-cluster'
TEST_POOL = 'pool-id'
# Custom provider callable; it returns whatever config dict it is given.
TEST_FUNCTION = custom_oauth_function
# The three bearer fields, keyed by their bearer config property names.
TEST_CONFIG = {'bearer.auth.token': TEST_TOKEN, 'bearer.auth.logical.cluster': TEST_CLUSTER,
'bearer.auth.identity.pool.id': TEST_POOL}
# URL placed in the client configuration by the tests that build a client.
TEST_URL = 'http://SchemaRegistry:65534'


def test_expiry():
    """A token counts as expired once it enters the expiry window."""
    client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000)

    # Token that expires two seconds from now with a one second lifetime.
    token = {'expires_at': time.time() + 2, 'expires_in': 1}
    client.token = token
    assert not client.token_expired()

    # Wait long enough for the token to fall inside the expiry window.
    time.sleep(1.5)
    assert client.token_expired()


def test_get_token():
    """Check how often get_access_token triggers token generation."""
    client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 2, 1000, 20000)

    def install_stale_token():
        client.token = {'expires_at': 0, 'expires_in': 1, 'access_token': '123'}

    def install_fresh_token():
        client.token = {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'}

    # No token is set yet, so the first lookup must generate one.
    stale_generator = Mock(side_effect=install_stale_token)
    client.generate_access_token = stale_generator
    client.get_access_token()
    assert stale_generator.call_count == 1
    assert client.token['access_token'] == '123'

    # The installed token has expires_at == 0, so the next lookup regenerates.
    # A brand-new Mock is assigned, hence its call count starts from zero.
    fresh_generator = Mock(side_effect=install_fresh_token)
    client.generate_access_token = fresh_generator
    client.get_access_token()
    assert fresh_generator.call_count == 1
    assert client.token['access_token'] == '1234'

    # The token is still valid now: no further generation is expected.
    client.get_access_token()
    assert fresh_generator.call_count == 1


def test_generate_token_retry_logic():
    """Token generation gives up with OAuthTokenError after all retries.

    The client is built with max_retries=5, and the test expects exactly
    five sleeps and five jitter computations before the error is raised.
    """
    oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 5, 1000, 20000)

    # A backslash continuation is used rather than a parenthesized ``with``
    # group, which older Python parsers reject as a syntax error.
    with patch("confluent_kafka.schema_registry.schema_registry_client.time.sleep") as mock_sleep, \
            patch("confluent_kafka.schema_registry.schema_registry_client.full_jitter") as mock_jitter:

        with pytest.raises(OAuthTokenError):
            oauth_client.generate_access_token()

        assert mock_sleep.call_count == 5
        assert mock_jitter.call_count == 5


def test_static_field_provider():
    """The static provider returns exactly the values it was built with."""
    provider = _StaticFieldProvider(TEST_TOKEN, TEST_CLUSTER, TEST_POOL)

    assert provider.get_bearer_fields() == TEST_CONFIG


def test_custom_oauth_client():
    """The custom provider returns what the custom function produces.

    TEST_FUNCTION echoes its config, so the bearer fields must equal
    TEST_CONFIG.
    """
    custom_oauth_client = _CustomOAuthClient(TEST_FUNCTION, TEST_CONFIG)

    # Compare against the expected value: comparing the result with itself
    # would pass regardless of what the provider returned.
    assert custom_oauth_client.get_bearer_fields() == TEST_CONFIG


def test_bearer_field_headers_missing():
    """handle_bearer_auth rejects a custom function that returns no fields."""

    def no_fields(config):
        return {}

    conf = {
        'url': TEST_URL,
        'bearer.auth.credentials.source': 'CUSTOM',
        'bearer.auth.custom.provider.function': no_fields,
        'bearer.auth.custom.provider.config': TEST_CONFIG,
    }

    headers = {'Accept': "application/vnd.schemaregistry.v1+json,"
                         " application/vnd.schemaregistry+json,"
                         " application/json"}

    client = SchemaRegistryClient(conf)
    rest_client = client._rest_client

    expected_message = (r"Missing required bearer auth fields, "
                        r"needs to be set in config or custom function: (.*)")
    with pytest.raises(ValueError, match=expected_message):
        rest_client.handle_bearer_auth(headers)


def test_bearer_field_headers_valid():
    """A complete custom config is turned into the three bearer headers."""
    conf = {
        'url': TEST_URL,
        'bearer.auth.credentials.source': 'CUSTOM',
        'bearer.auth.custom.provider.function': TEST_FUNCTION,
        'bearer.auth.custom.provider.config': TEST_CONFIG,
    }

    client = SchemaRegistryClient(conf)

    headers = {'Accept': "application/vnd.schemaregistry.v1+json,"
                         " application/vnd.schemaregistry+json,"
                         " application/json"}

    client._rest_client.handle_bearer_auth(headers)

    # Each header must be present before its value is compared.
    for header_name in ('Authorization', 'Confluent-Identity-Pool-Id', 'target-sr-cluster'):
        assert header_name in headers

    expected_authorization = "Bearer {}".format(TEST_CONFIG['bearer.auth.token'])
    assert headers['Authorization'] == expected_authorization
    assert headers['Confluent-Identity-Pool-Id'] == TEST_CONFIG['bearer.auth.identity.pool.id']
    assert headers['target-sr-cluster'] == TEST_CONFIG['bearer.auth.logical.cluster']
27 changes: 25 additions & 2 deletions tests/schema_registry/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,6 @@ def test_oauth_bearer_config_valid():

client = SchemaRegistryClient(conf)

assert client._rest_client.logical_cluster == TEST_CLUSTER
assert client._rest_client.identity_pool_id == TEST_POOL
assert client._rest_client.client_id == TEST_USERNAME
assert client._rest_client.client_secret == TEST_USER_PASSWORD
assert client._rest_client.scope == TEST_SCOPE
Expand All @@ -230,6 +228,31 @@ def test_static_bearer_config():
SchemaRegistryClient(conf)


def test_custom_bearer_config():
    """A CUSTOM credentials source without its properties is rejected."""
    conf = {
        'url': TEST_URL,
        'bearer.auth.credentials.source': 'CUSTOM',
    }

    expected_message = 'Missing required custom OAuth configuration properties:'
    with pytest.raises(ValueError, match=expected_message):
        SchemaRegistryClient(conf)


def test_custom_bearer_config_valid():
    """A valid CUSTOM configuration is stored on the bearer field provider."""

    def custom_function(config: dict):
        return {}

    custom_config = {}

    conf = {
        'url': TEST_URL,
        'bearer.auth.credentials.source': 'CUSTOM',
        'bearer.auth.custom.provider.function': custom_function,
        'bearer.auth.custom.provider.config': custom_config,
    }

    client = SchemaRegistryClient(conf)
    provider = client._rest_client.bearer_field_provider

    assert provider.custom_function == custom_function
    assert provider.custom_config == custom_config


def test_config_unknown_prop():
conf = {'url': TEST_URL,
'basic.auth.credentials.source': 'SASL_INHERIT',
Expand Down
Loading