Skip to content

Commit 5b10ee9

Browse files
feat(JwtAuthenticator): add passphrase support to jwt auth (#773)
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent a1428bf commit 5b10ee9

File tree

7 files changed

+179
-2
lines changed

7 files changed

+179
-2
lines changed

airbyte_cdk/sources/declarative/auth/jwt.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,26 @@
66
import json
77
from dataclasses import InitVar, dataclass
88
from datetime import datetime
9-
from typing import Any, Mapping, Optional, Union
9+
from typing import Any, Mapping, Optional, Union, cast
1010

1111
import jwt
12+
from cryptography.hazmat.primitives import serialization
13+
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
14+
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey
15+
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
16+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
17+
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
1218

1319
from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
1420
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
1521
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
1622
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
1723

24+
# Type alias for keys that JWT library accepts
25+
JwtKeyTypes = Union[
26+
RSAPrivateKey, EllipticCurvePrivateKey, Ed25519PrivateKey, Ed448PrivateKey, str, bytes
27+
]
28+
1829

1930
class JwtAlgorithm(str):
2031
"""
@@ -74,6 +85,7 @@ class JwtAuthenticator(DeclarativeAuthenticator):
7485
aud: Optional[Union[InterpolatedString, str]] = None
7586
additional_jwt_headers: Optional[Mapping[str, Any]] = None
7687
additional_jwt_payload: Optional[Mapping[str, Any]] = None
88+
passphrase: Optional[Union[InterpolatedString, str]] = None
7789

7890
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
7991
self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters)
@@ -103,6 +115,11 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
103115
self._additional_jwt_payload = InterpolatedMapping(
104116
self.additional_jwt_payload or {}, parameters=parameters
105117
)
118+
self._passphrase = (
119+
InterpolatedString.create(self.passphrase, parameters=parameters)
120+
if self.passphrase
121+
else None
122+
)
106123

107124
def _get_jwt_headers(self) -> dict[str, Any]:
108125
"""
@@ -149,11 +166,21 @@ def _get_jwt_payload(self) -> dict[str, Any]:
149166
payload["nbf"] = nbf
150167
return payload
151168

152-
def _get_secret_key(self) -> str:
169+
def _get_secret_key(self) -> JwtKeyTypes:
153170
"""
154171
Returns the secret key used to sign the JWT.
155172
"""
156173
secret_key: str = self._secret_key.eval(self.config, json_loads=json.loads)
174+
175+
if self._passphrase:
176+
passphrase_value = self._passphrase.eval(self.config, json_loads=json.loads)
177+
if passphrase_value:
178+
private_key = serialization.load_pem_private_key(
179+
secret_key.encode(),
180+
password=passphrase_value.encode(),
181+
)
182+
return cast(JwtKeyTypes, private_key)
183+
157184
return (
158185
base64.b64encode(secret_key.encode()).decode()
159186
if self._base64_encode_secret_key

airbyte_cdk/sources/declarative/declarative_component_schema.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,12 @@ definitions:
12701270
title: Additional JWT Payload Properties
12711271
description: Additional properties to be added to the JWT payload.
12721272
additionalProperties: true
1273+
passphrase:
1274+
title: Passphrase
1275+
description: A passphrase/password used to encrypt the private key. Only provide a passphrase if required by the API for JWT authentication. The API will typically provide the passphrase when generating the public/private key pair.
1276+
type: string
1277+
examples:
1278+
- "{{ config['passphrase'] }}"
12731279
$parameters:
12741280
type: object
12751281
additionalProperties: true

airbyte_cdk/sources/declarative/interpolation/macros.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import datetime
77
import re
88
import typing
9+
import uuid
910
from typing import Optional, Union
1011
from urllib.parse import quote_plus
1112

@@ -207,6 +208,16 @@ def camel_case_to_snake_case(value: str) -> str:
207208
return re.sub(r"(?<!^)(?=[A-Z])", "_", value).lower()
208209

209210

211+
def generate_uuid() -> str:
212+
"""
213+
Generates a UUID4
214+
215+
Usage:
216+
`"{{ generate_uuid() }}"`
217+
"""
218+
return str(uuid.uuid4())
219+
220+
210221
_macros_list = [
211222
now_utc,
212223
today_utc,
@@ -220,5 +231,6 @@ def camel_case_to_snake_case(value: str) -> str:
220231
str_to_datetime,
221232
sanitize_url,
222233
camel_case_to_snake_case,
234+
generate_uuid,
223235
]
224236
macros = {f.__name__: f for f in _macros_list}

airbyte_cdk/sources/declarative/models/declarative_component_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,12 @@ class JwtAuthenticator(BaseModel):
448448
description="Additional properties to be added to the JWT payload.",
449449
title="Additional JWT Payload Properties",
450450
)
451+
passphrase: Optional[str] = Field(
452+
None,
453+
description="A passphrase/password used to encrypt the private key. Only provide a passphrase if required by the API for JWT authentication. The API will typically provide the passphrase when generating the public/private key pair.",
454+
examples=["{{ config['passphrase'] }}"],
455+
title="Passphrase",
456+
)
451457
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
452458

453459

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2705,6 +2705,7 @@ def create_jwt_authenticator(
27052705
aud=jwt_payload.aud,
27062706
additional_jwt_headers=model.additional_jwt_headers,
27072707
additional_jwt_payload=model.additional_jwt_payload,
2708+
passphrase=model.passphrase,
27082709
)
27092710

27102711
def create_list_partition_router(

unit_tests/sources/declarative/auth/test_jwt.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import freezegun
99
import jwt
1010
import pytest
11+
from cryptography.hazmat.backends import default_backend
12+
from cryptography.hazmat.primitives import serialization
13+
from cryptography.hazmat.primitives.asymmetric import rsa
1114

1215
from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator
1316

@@ -185,3 +188,100 @@ def test_get_header_prefix(self, header_prefix, expected):
185188
header_prefix=header_prefix,
186189
)
187190
assert authenticator._get_header_prefix() == expected
191+
192+
def test_get_secret_key_with_passphrase(self):
193+
"""Test _get_secret_key method with encrypted private key and passphrase."""
194+
# Generate a test RSA private key
195+
private_key = rsa.generate_private_key(
196+
public_exponent=65537, key_size=2048, backend=default_backend()
197+
)
198+
199+
passphrase = b"test_passphrase"
200+
encrypted_pem = private_key.private_bytes(
201+
encoding=serialization.Encoding.PEM,
202+
format=serialization.PrivateFormat.PKCS8,
203+
encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
204+
)
205+
206+
authenticator = JwtAuthenticator(
207+
config={},
208+
parameters={},
209+
secret_key=encrypted_pem.decode(),
210+
algorithm="RS256",
211+
token_duration=1200,
212+
passphrase="test_passphrase",
213+
)
214+
215+
result_key = authenticator._get_secret_key()
216+
217+
assert isinstance(result_key, rsa.RSAPrivateKey)
218+
219+
original_public_key = private_key.public_key()
220+
result_public_key = result_key.public_key()
221+
222+
original_public_numbers = original_public_key.public_numbers()
223+
result_public_numbers = result_public_key.public_numbers()
224+
225+
assert original_public_numbers.n == result_public_numbers.n
226+
assert original_public_numbers.e == result_public_numbers.e
227+
228+
def test_get_secret_key_with_wrong_passphrase_raises_error(self):
229+
"""Test that _get_secret_key raises error with wrong passphrase."""
230+
private_key = rsa.generate_private_key(
231+
public_exponent=65537, key_size=2048, backend=default_backend()
232+
)
233+
234+
passphrase = b"correct_passphrase"
235+
encrypted_pem = private_key.private_bytes(
236+
encoding=serialization.Encoding.PEM,
237+
format=serialization.PrivateFormat.PKCS8,
238+
encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
239+
)
240+
241+
authenticator = JwtAuthenticator(
242+
config={},
243+
parameters={},
244+
secret_key=encrypted_pem.decode(),
245+
algorithm="RS256",
246+
token_duration=1200,
247+
passphrase="wrong_passphrase",
248+
)
249+
250+
with pytest.raises(Exception):
251+
authenticator._get_secret_key()
252+
253+
def test_get_signed_token_with_passphrase_protected_key(self):
254+
"""Test that JWT signing works with passphrase-protected RSA private key."""
255+
private_key = rsa.generate_private_key(
256+
public_exponent=65537, key_size=2048, backend=default_backend()
257+
)
258+
259+
passphrase = b"test_passphrase"
260+
encrypted_pem = private_key.private_bytes(
261+
encoding=serialization.Encoding.PEM,
262+
format=serialization.PrivateFormat.PKCS8,
263+
encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
264+
)
265+
266+
authenticator = JwtAuthenticator(
267+
config={},
268+
parameters={},
269+
secret_key=encrypted_pem.decode(),
270+
algorithm="RS256",
271+
token_duration=1000,
272+
passphrase="test_passphrase",
273+
typ="JWT",
274+
iss="test_issuer",
275+
)
276+
277+
signed_token = authenticator._get_signed_token()
278+
279+
assert isinstance(signed_token, str)
280+
assert len(signed_token.split(".")) == 3
281+
282+
public_key = private_key.public_key()
283+
decoded_payload = jwt.decode(signed_token, public_key, algorithms=["RS256"])
284+
285+
assert decoded_payload["iss"] == "test_issuer"
286+
assert "iat" in decoded_payload
287+
assert "exp" in decoded_payload

unit_tests/sources/declarative/interpolation/test_macros.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44

55
import datetime
6+
import uuid
67

78
import pytest
89

@@ -20,6 +21,7 @@
2021
("test_format_datetime", "format_datetime", True),
2122
("test_duration", "duration", True),
2223
("test_camel_case_to_snake_case", "camel_case_to_snake_case", True),
24+
("test_generate_uuid", "generate_uuid", True),
2325
("test_not_a_macro", "thisisnotavalidmacro", False),
2426
],
2527
)
@@ -275,3 +277,26 @@ def test_sanitize_url(test_name, input_value, expected_output):
275277
)
276278
def test_camel_case_to_snake_case(value, expected_value):
277279
assert macros["camel_case_to_snake_case"](value) == expected_value
280+
281+
282+
def test_generate_uuid():
283+
"""Test uuid macro generates valid UUID4 strings."""
284+
uuid_fn = macros["generate_uuid"]
285+
286+
# Test that uuid function returns a string
287+
result = uuid_fn()
288+
assert isinstance(result, str)
289+
290+
# Test that the result is a valid UUID format
291+
# This will raise ValueError if not a valid UUID
292+
parsed_uuid = uuid.UUID(result)
293+
294+
# Test that it's specifically a UUID4 (version 4)
295+
assert parsed_uuid.version == 4
296+
297+
# Test that multiple calls return different UUIDs
298+
result2 = uuid_fn()
299+
assert result != result2
300+
301+
# Test that both results are valid UUIDs
302+
uuid.UUID(result2) # Will raise ValueError if invalid

0 commit comments

Comments
 (0)