Skip to content

Commit 4a3694b

Browse files
Add basic OAuth support to Schema Registry (#1919)
* Add basic OAuth support to Schema Registry
* Add testing for OAuth SR
* Fix linting
* Build and fix errors
* fix aligning
* Increase code coverage
* Run flake8
* Update expiry threshold and field names
* Update scope config name
1 parent e119676 commit 4a3694b

File tree

5 files changed

+287
-2
lines changed

5 files changed

+287
-2
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
attrs
22
cachetools
33
httpx>=0.26
4+
authlib

src/confluent_kafka/schema_registry/error.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
except ImportError:
2121
pass
2222

23-
__all__ = ['SchemaRegistryError', 'SchemaParseException', 'UnknownType']
23+
__all__ = ['SchemaRegistryError', 'OAuthTokenError', 'SchemaParseException', 'UnknownType']
2424

2525

2626
class SchemaRegistryError(Exception):
@@ -53,3 +53,12 @@ def __str__(self):
5353
return "{} (HTTP status code {}, SR code {})".format(self.error_message,
5454
self.http_status_code,
5555
self.error_code)
56+
57+
58+
class OAuthTokenError(Exception):
    """Raised when an OAuth token cannot be retrieved.

    Args:
        message: Human-readable description of the failure.
        status_code: HTTP status code of the token response, when one is known.
        response_text: Body of the token response, when one is known.

    The three arguments are kept as attributes of the same names.
    """

    def __init__(self, message, status_code=None, response_text=None):
        self.message = message
        self.status_code = status_code
        self.response_text = response_text

        # Only mention the HTTP details that were actually supplied, so an
        # error raised without a response does not read "(HTTP None): None".
        detail = message
        if status_code is not None:
            detail = f"{detail} (HTTP {status_code})"
        if response_text is not None:
            detail = f"{detail}: {response_text}"
        super().__init__(detail)

src/confluent_kafka/schema_registry/schema_registry_client.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
from cachetools import TTLCache, LRUCache
3636
from httpx import Response
3737

38-
from .error import SchemaRegistryError
38+
from authlib.integrations.httpx_client import OAuth2Client
39+
40+
from .error import SchemaRegistryError, OAuthTokenError
3941

4042
# TODO: consider adding `six` dependency or employing a compat file
4143
# Python 2.7 is officially EOL so compatibility issues will become more the norm.
@@ -60,6 +62,40 @@ def _urlencode(value: str) -> str:
6062
VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO']
6163

6264

65+
class _OAuthClient:
    """Fetches and caches an OAuth access token via the client-credentials grant.

    Args:
        client_id: Client id handed to ``OAuth2Client``.
        client_secret: Client secret handed to ``OAuth2Client``.
        scope: Scope handed to ``OAuth2Client``.
        token_endpoint: URL the token is fetched from.
        max_retries: Number of retries after the first failed fetch, so up to
            ``max_retries + 1`` fetches are attempted in total.
        retries_wait_ms: Passed to ``full_jitter`` when computing the delay
            between attempts.
        retries_max_wait_ms: Passed to ``full_jitter`` when computing the delay
            between attempts.
    """

    def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str,
                 max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
        self.token = None
        self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
        self.token_endpoint = token_endpoint
        self.max_retries = max_retries
        self.retries_wait_ms = retries_wait_ms
        self.retries_max_wait_ms = retries_max_wait_ms
        # Fraction of 'expires_in' used as the refresh window in token_expired().
        self.token_expiry_threshold = 0.8

    def token_expired(self):
        """Return True when the cached token should be replaced.

        The token counts as expired once its remaining lifetime
        (``expires_at - now``) is shorter than
        ``expires_in * token_expiry_threshold``.
        """
        # NOTE(review): with a threshold of 0.8 and expires_at == issue time +
        # expires_in, this triggers after roughly 20% of the lifetime has
        # elapsed - confirm that this is the intended refresh point.
        expiry_window = self.token['expires_in'] * self.token_expiry_threshold

        return self.token['expires_at'] < time.time() + expiry_window

    def get_access_token(self) -> str:
        """Return the access token, fetching a new one if none is cached or it expired."""
        if not self.token or self.token_expired():
            self.generate_access_token()

        return self.token['access_token']

    def generate_access_token(self):
        """Fetch a token from the token endpoint and store it in ``self.token``.

        Failed fetches are retried up to ``max_retries`` times, sleeping for
        ``full_jitter(...)`` milliseconds between attempts.

        Raises:
            OAuthTokenError: when every attempt failed; the last underlying
                error is attached as ``__cause__``.
        """
        attempts = self.max_retries + 1
        for i in range(attempts):
            try:
                self.token = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
                return
            except Exception as e:
                if i >= self.max_retries:
                    # Report the number of fetches actually made and keep the
                    # original traceback reachable through exception chaining.
                    raise OAuthTokenError(f"Failed to retrieve token after {attempts} "
                                          f"attempts due to error: {str(e)}") from e
                time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000)
97+
98+
6399
class _BaseRestClient(object):
64100

65101
def __init__(self, conf: dict):
@@ -170,6 +206,59 @@ def __init__(self, conf: dict):
170206
+ str(type(retries_max_wait_ms)))
171207
self.retries_max_wait_ms = retries_max_wait_ms
172208

209+
self.oauth_client = None
210+
self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None)
211+
if self.bearer_auth_credentials_source is not None:
212+
self.auth = None
213+
headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id']
214+
missing_headers = [header for header in headers if header not in conf_copy]
215+
if missing_headers:
216+
raise ValueError("Missing required bearer configuration properties: {}"
217+
.format(", ".join(missing_headers)))
218+
219+
self.logical_cluster = conf_copy.pop('bearer.auth.logical.cluster')
220+
if not isinstance(self.logical_cluster, str):
221+
raise TypeError("logical cluster must be a str, not " + str(type(self.logical_cluster)))
222+
223+
self.identity_pool_id = conf_copy.pop('bearer.auth.identity.pool.id')
224+
if not isinstance(self.identity_pool_id, str):
225+
raise TypeError("identity pool id must be a str, not " + str(type(self.identity_pool_id)))
226+
227+
if self.bearer_auth_credentials_source == 'OAUTHBEARER':
228+
properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope',
229+
'bearer.auth.issuer.endpoint.url']
230+
missing_properties = [prop for prop in properties_list if prop not in conf_copy]
231+
if missing_properties:
232+
raise ValueError("Missing required OAuth configuration properties: {}".
233+
format(", ".join(missing_properties)))
234+
235+
self.client_id = conf_copy.pop('bearer.auth.client.id')
236+
if not isinstance(self.client_id, string_type):
237+
raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id)))
238+
239+
self.client_secret = conf_copy.pop('bearer.auth.client.secret')
240+
if not isinstance(self.client_secret, string_type):
241+
raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret)))
242+
243+
self.scope = conf_copy.pop('bearer.auth.scope')
244+
if not isinstance(self.scope, string_type):
245+
raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope)))
246+
247+
self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url')
248+
if not isinstance(self.token_endpoint, string_type):
249+
raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not "
250+
+ str(type(self.token_endpoint)))
251+
252+
self.oauth_client = _OAuthClient(self.client_id, self.client_secret, self.scope, self.token_endpoint,
253+
self.max_retries, self.retries_wait_ms, self.retries_max_wait_ms)
254+
255+
elif self.bearer_auth_credentials_source == 'STATIC_TOKEN':
256+
if 'bearer.auth.token' not in conf_copy:
257+
raise ValueError("Missing bearer.auth.token")
258+
self.bearer_token = conf_copy.pop('bearer.auth.token')
259+
if not isinstance(self.bearer_token, string_type):
260+
raise TypeError("bearer.auth.token must be a str, not " + str(type(self.bearer_token)))
261+
173262
# Any leftover keys are unknown to _RestClient
174263
if len(conf_copy) > 0:
175264
raise ValueError("Unrecognized properties: {}"
@@ -209,6 +298,14 @@ def __init__(self, conf: dict):
209298
timeout=self.timeout
210299
)
211300

301+
def handle_bearer_auth(self, headers: dict):
    """Add the bearer-authentication headers to ``headers`` in place.

    Sets ``Authorization`` (``Bearer <token>``), ``Confluent-Identity-Pool-Id``
    and ``target-sr-cluster``. The token comes from the OAuth client when one
    is configured, otherwise from the static ``bearer_token``.

    Args:
        headers: Request header dict to update.
    """
    if self.oauth_client:
        token = self.oauth_client.get_access_token()
    else:
        # In the configuration code visible here, bearer_token is only assigned
        # for the STATIC_TOKEN source; reading it unconditionally raised
        # AttributeError for OAuth-configured clients.
        token = self.bearer_token
    headers["Authorization"] = "Bearer {}".format(token)
    headers['Confluent-Identity-Pool-Id'] = self.identity_pool_id
    headers['target-sr-cluster'] = self.logical_cluster
308+
212309
def get(self, url: str, query: Optional[dict] = None) -> Any:
    """Send a GET request for ``url`` via ``send_request`` and return its result.

    Args:
        url: Request URL forwarded to ``send_request``.
        query: Optional query parameters forwarded to ``send_request``.
    """
    result = self.send_request(url, query=query, method='GET')
    return result
214311

@@ -256,6 +353,9 @@ def send_request(
256353
headers = {'Content-Length': str(len(body)),
257354
'Content-Type': "application/vnd.schemaregistry.v1+json"}
258355

356+
if self.bearer_auth_credentials_source:
357+
self.handle_bearer_auth(headers)
358+
259359
response = None
260360
for i, base_url in enumerate(self.base_urls):
261361
try:

tests/schema_registry/test_config.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
TEST_URL = 'http://SchemaRegistry:65534'
2727
TEST_USERNAME = 'sr_user'
2828
TEST_USER_PASSWORD = 'sr_user_secret'
29+
TEST_POOL = 'sr_pool'
30+
TEST_CLUSTER = 'lsrc-1234'
31+
TEST_SCOPE = 'sr_scope'
32+
TEST_ENDPOINT = 'http://oauth_endpoint'
2933

3034
"""
3135
Tests to ensure all configurations are handled correctly.
@@ -112,6 +116,120 @@ def test_config_auth_userinfo_invalid():
112116
SchemaRegistryClient(conf)
113117

114118

119+
def test_bearer_config():
    """A bearer credentials source without cluster / identity pool settings is rejected."""
    conf = {
        'url': TEST_URL,
        'bearer.auth.credentials.source': "OAUTHBEARER",
    }
    expected_message = r"Missing required bearer configuration properties: (.*)"

    with pytest.raises(ValueError, match=expected_message):
        SchemaRegistryClient(conf)
125+
126+
127+
def test_oauth_bearer_config_missing():
    """OAUTHBEARER with cluster and pool but no OAuth client settings is rejected."""
    conf = {
        'url': TEST_URL,
        'bearer.auth.credentials.source': "OAUTHBEARER",
        'bearer.auth.logical.cluster': TEST_CLUSTER,
        'bearer.auth.identity.pool.id': TEST_POOL,
    }
    expected_message = r"Missing required OAuth configuration properties: (.*)"

    with pytest.raises(ValueError, match=expected_message):
        SchemaRegistryClient(conf)
135+
136+
137+
def test_oauth_bearer_config_invalid():
    """Each bearer / OAuth setting must be a str; a non-str value raises TypeError."""
    # Configuration carrying only the two bearer headers.
    bearer_conf = {
        'url': TEST_URL,
        'bearer.auth.credentials.source': "OAUTHBEARER",
        'bearer.auth.logical.cluster': TEST_CLUSTER,
        'bearer.auth.identity.pool.id': TEST_POOL,
    }
    # Complete OAuth configuration.
    oauth_conf = {
        **bearer_conf,
        'bearer.auth.client.id': TEST_USERNAME,
        'bearer.auth.client.secret': TEST_USER_PASSWORD,
        'bearer.auth.scope': TEST_SCOPE,
        'bearer.auth.issuer.endpoint.url': TEST_ENDPOINT,
    }

    # (valid base configuration, key overwritten with the int 1, expected message)
    cases = [
        (bearer_conf, 'bearer.auth.identity.pool.id',
         r"identity pool id must be a str, not (.*)"),
        (bearer_conf, 'bearer.auth.logical.cluster',
         r"logical cluster must be a str, not (.*)"),
        (oauth_conf, 'bearer.auth.client.id',
         r"bearer.auth.client.id must be a str, not (.*)"),
        (oauth_conf, 'bearer.auth.client.secret',
         r"bearer.auth.client.secret must be a str, not (.*)"),
        (oauth_conf, 'bearer.auth.scope',
         r"bearer.auth.scope must be a str, not (.*)"),
        (oauth_conf, 'bearer.auth.issuer.endpoint.url',
         r"bearer.auth.issuer.endpoint.url must be a str, not (.*)"),
    ]

    for base_conf, invalid_key, expected_message in cases:
        conf = {**base_conf, invalid_key: 1}
        with pytest.raises(TypeError, match=expected_message):
            SchemaRegistryClient(conf)
201+
202+
203+
def test_oauth_bearer_config_valid():
    """A complete OAUTHBEARER configuration is stored on the REST client."""
    conf = {
        'url': TEST_URL,
        'bearer.auth.credentials.source': "OAUTHBEARER",
        'bearer.auth.logical.cluster': TEST_CLUSTER,
        'bearer.auth.identity.pool.id': TEST_POOL,
        'bearer.auth.client.id': TEST_USERNAME,
        'bearer.auth.client.secret': TEST_USER_PASSWORD,
        'bearer.auth.scope': TEST_SCOPE,
        'bearer.auth.issuer.endpoint.url': TEST_ENDPOINT,
    }

    client = SchemaRegistryClient(conf)
    rest_client = client._rest_client

    assert rest_client.logical_cluster == TEST_CLUSTER
    assert rest_client.identity_pool_id == TEST_POOL
    assert rest_client.client_id == TEST_USERNAME
    assert rest_client.client_secret == TEST_USER_PASSWORD
    assert rest_client.scope == TEST_SCOPE
    assert rest_client.token_endpoint == TEST_ENDPOINT
221+
222+
223+
def test_static_bearer_config():
    """STATIC_TOKEN without 'bearer.auth.token' is rejected."""
    conf = {
        'url': TEST_URL,
        'bearer.auth.credentials.source': 'STATIC_TOKEN',
        'bearer.auth.logical.cluster': 'lsrc',
        'bearer.auth.identity.pool.id': 'pool_id',
    }
    expected_message = 'Missing bearer.auth.token'

    with pytest.raises(ValueError, match=expected_message):
        SchemaRegistryClient(conf)
231+
232+
115233
def test_config_unknown_prop():
116234
conf = {'url': TEST_URL,
117235
'basic.auth.credentials.source': 'SASL_INHERIT',
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
import time
3+
from unittest.mock import Mock, patch
4+
5+
from confluent_kafka.schema_registry.schema_registry_client import _OAuthClient
6+
from confluent_kafka.schema_registry.error import OAuthTokenError
7+
8+
"""
9+
Tests to ensure OAuth client is set up correctly.
10+
11+
"""
12+
13+
14+
def test_expiry():
    """token_expired() flips to True once the remaining lifetime drops below the window.

    The clock read by the client is patched rather than slept through, so the
    test is instantaneous and does not depend on scheduler timing.
    """
    oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 2, 1000, 20000)
    now = time.time()
    oauth_client.token = {'expires_at': now + 2, 'expires_in': 1}

    with patch("confluent_kafka.schema_registry.schema_registry_client.time.time") as mock_time:
        # 2 s of lifetime left: not expired yet.
        mock_time.return_value = now
        assert not oauth_client.token_expired()

        # 1.5 s later only 0.5 s is left: expired.
        mock_time.return_value = now + 1.5
        assert oauth_client.token_expired()
20+
21+
22+
def test_get_token():
    """get_access_token() regenerates only when no token is cached or it has expired."""
    oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 2, 1000, 20000)
    assert not oauth_client.token

    def make_generator(build_token):
        # Mock standing in for generate_access_token: each call stores a
        # freshly built token on the client.
        def store_token():
            oauth_client.token = build_token()

        return Mock(side_effect=store_token)

    # No token cached yet: one generation, which yields an already-expired token.
    oauth_client.generate_access_token = make_generator(
        lambda: {'expires_at': 0, 'expires_in': 1, 'access_token': '123'})
    oauth_client.get_access_token()
    assert oauth_client.generate_access_token.call_count == 1
    assert oauth_client.token['access_token'] == '123'

    # Cached token is expired: generated again, this time with a valid token.
    oauth_client.generate_access_token = make_generator(
        lambda: {'expires_at': time.time() + 2, 'expires_in': 1, 'access_token': '1234'})
    oauth_client.get_access_token()
    # Call count resets to 1 after reassigning generate_access_token
    assert oauth_client.generate_access_token.call_count == 1
    assert oauth_client.token['access_token'] == '1234'

    # Cached token is still valid: no further generation.
    oauth_client.get_access_token()
    assert oauth_client.generate_access_token.call_count == 1
45+
46+
47+
def test_generate_token_retry_logic():
    """generate_access_token() retries max_retries times, then raises OAuthTokenError."""
    oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', 5, 1000, 20000)
    module = "confluent_kafka.schema_registry.schema_registry_client"

    # Comma-separated context managers parse on every Python 3 version; the
    # parenthesized form only became official grammar in Python 3.10.
    # fetch_token is stubbed to fail so the test does not depend on how a real
    # request against the placeholder endpoint happens to error out.
    with patch(module + ".time.sleep") as mock_sleep, \
            patch(module + ".full_jitter") as mock_jitter, \
            patch.object(oauth_client.client, "fetch_token",
                         side_effect=RuntimeError("token endpoint unavailable")) as mock_fetch:

        with pytest.raises(OAuthTokenError):
            oauth_client.generate_access_token()

    # One initial attempt plus five retries, with a sleep before each retry.
    assert mock_fetch.call_count == 6
    assert mock_sleep.call_count == 5
    assert mock_jitter.call_count == 5

0 commit comments

Comments
 (0)