Skip to content

Commit cb355e7

Browse files
committed
add passphrase logic to jwt auth
1 parent a1428bf commit cb355e7

File tree

6 files changed

+144
-6
lines changed

6 files changed

+144
-6
lines changed

airbyte_cdk/sources/declarative/auth/jwt.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from typing import Any, Mapping, Optional, Union
1010

1111
import jwt
12+
from cryptography.hazmat.backends import default_backend
13+
from cryptography.hazmat.primitives import serialization
14+
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
1215

1316
from airbyte_cdk.sources.declarative.auth.declarative_authenticator import DeclarativeAuthenticator
1417
from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
@@ -74,6 +77,7 @@ class JwtAuthenticator(DeclarativeAuthenticator):
7477
aud: Optional[Union[InterpolatedString, str]] = None
7578
additional_jwt_headers: Optional[Mapping[str, Any]] = None
7679
additional_jwt_payload: Optional[Mapping[str, Any]] = None
80+
passphrase: Optional[Union[InterpolatedString, str]] = None
7781

7882
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
7983
self._secret_key = InterpolatedString.create(self.secret_key, parameters=parameters)
@@ -103,6 +107,11 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
103107
self._additional_jwt_payload = InterpolatedMapping(
104108
self.additional_jwt_payload or {}, parameters=parameters
105109
)
110+
self._passphrase = (
111+
InterpolatedString.create(self.passphrase, parameters=parameters)
112+
if self.passphrase
113+
else None
114+
)
106115

107116
def _get_jwt_headers(self) -> dict[str, Any]:
108117
"""
@@ -149,16 +158,24 @@ def _get_jwt_payload(self) -> dict[str, Any]:
149158
payload["nbf"] = nbf
150159
return payload
151160

152-
def _get_secret_key(self) -> str:
161+
def _get_secret_key(self) -> PrivateKeyTypes | str | bytes:
153162
"""
154163
Returns the secret key used to sign the JWT.
155164
"""
156165
secret_key: str = self._secret_key.eval(self.config, json_loads=json.loads)
157-
return (
158-
base64.b64encode(secret_key.encode()).decode()
159-
if self._base64_encode_secret_key
160-
else secret_key
161-
)
166+
167+
if self._passphrase:
168+
return serialization.load_pem_private_key(
169+
secret_key.encode(),
170+
password=self._passphrase.eval(self.config, json_loads=json.loads).encode(),
171+
backend=default_backend(),
172+
)
173+
else:
174+
return (
175+
base64.b64encode(secret_key.encode()).decode()
176+
if self._base64_encode_secret_key
177+
else secret_key
178+
)
162179

163180
def _get_signed_token(self) -> Union[str, Any]:
164181
"""

airbyte_cdk/sources/declarative/declarative_component_schema.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,12 @@ definitions:
12701270
title: Additional JWT Payload Properties
12711271
description: Additional properties to be added to the JWT payload.
12721272
additionalProperties: true
1273+
passphrase:
1274+
title: Passphrase
1275+
description: A passphrase/password used to encrypt the private key. Only provide a passphrase if required by the API for JWT authentication. The API will typically provide the passphrase when generating the public/private key pair.
1276+
type: string
1277+
examples:
1278+
- "{{ config['passphrase'] }}"
12731279
$parameters:
12741280
type: object
12751281
additionalProperties: true

airbyte_cdk/sources/declarative/interpolation/macros.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import datetime
77
import re
88
import typing
9+
import uuid
910
from typing import Optional, Union
1011
from urllib.parse import quote_plus
1112

@@ -207,6 +208,13 @@ def camel_case_to_snake_case(value: str) -> str:
207208
return re.sub(r"(?<!^)(?=[A-Z])", "_", value).lower()
208209

209210

211+
def random_uuid() -> str:
212+
"""
213+
Generates a UUID4
214+
"""
215+
return str(uuid.uuid4())
216+
217+
210218
_macros_list = [
211219
now_utc,
212220
today_utc,

airbyte_cdk/sources/declarative/models/declarative_component_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,12 @@ class JwtAuthenticator(BaseModel):
448448
description="Additional properties to be added to the JWT payload.",
449449
title="Additional JWT Payload Properties",
450450
)
451+
passphrase: Optional[str] = Field(
452+
None,
453+
description="A passphrase/password used to encrypt the private key. Only provide a passphrase if required by the API for JWT authentication. The API will typically provide the passphrase when generating the public/private key pair.",
454+
examples=["{{ config['passphrase'] }}"],
455+
title="Passphrase",
456+
)
451457
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")
452458

453459

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2705,6 +2705,7 @@ def create_jwt_authenticator(
27052705
aud=jwt_payload.aud,
27062706
additional_jwt_headers=model.additional_jwt_headers,
27072707
additional_jwt_payload=model.additional_jwt_payload,
2708+
passphrase=model.passphrase,
27082709
)
27092710

27102711
def create_list_partition_router(

unit_tests/sources/declarative/auth/test_jwt.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import freezegun
99
import jwt
1010
import pytest
11+
from cryptography.hazmat.backends import default_backend
12+
from cryptography.hazmat.primitives import serialization
13+
from cryptography.hazmat.primitives.asymmetric import rsa
1114

1215
from airbyte_cdk.sources.declarative.auth.jwt import JwtAuthenticator
1316

@@ -185,3 +188,100 @@ def test_get_header_prefix(self, header_prefix, expected):
185188
header_prefix=header_prefix,
186189
)
187190
assert authenticator._get_header_prefix() == expected
191+
192+
def test_get_secret_key_with_passphrase(self):
193+
"""Test _get_secret_key method with encrypted private key and passphrase."""
194+
# Generate a test RSA private key
195+
private_key = rsa.generate_private_key(
196+
public_exponent=65537, key_size=2048, backend=default_backend()
197+
)
198+
199+
passphrase = b"test_passphrase"
200+
encrypted_pem = private_key.private_bytes(
201+
encoding=serialization.Encoding.PEM,
202+
format=serialization.PrivateFormat.PKCS8,
203+
encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
204+
)
205+
206+
authenticator = JwtAuthenticator(
207+
config={},
208+
parameters={},
209+
secret_key=encrypted_pem.decode(),
210+
algorithm="RS256",
211+
token_duration=1200,
212+
passphrase="test_passphrase",
213+
)
214+
215+
result_key = authenticator._get_secret_key()
216+
217+
assert isinstance(result_key, rsa.RSAPrivateKey)
218+
219+
original_public_key = private_key.public_key()
220+
result_public_key = result_key.public_key()
221+
222+
original_public_numbers = original_public_key.public_numbers()
223+
result_public_numbers = result_public_key.public_numbers()
224+
225+
assert original_public_numbers.n == result_public_numbers.n
226+
assert original_public_numbers.e == result_public_numbers.e
227+
228+
def test_get_secret_key_with_wrong_passphrase_raises_error(self):
229+
"""Test that _get_secret_key raises error with wrong passphrase."""
230+
private_key = rsa.generate_private_key(
231+
public_exponent=65537, key_size=2048, backend=default_backend()
232+
)
233+
234+
passphrase = b"correct_passphrase"
235+
encrypted_pem = private_key.private_bytes(
236+
encoding=serialization.Encoding.PEM,
237+
format=serialization.PrivateFormat.PKCS8,
238+
encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
239+
)
240+
241+
authenticator = JwtAuthenticator(
242+
config={},
243+
parameters={},
244+
secret_key=encrypted_pem.decode(),
245+
algorithm="RS256",
246+
token_duration=1200,
247+
passphrase="wrong_passphrase",
248+
)
249+
250+
with pytest.raises(Exception):
251+
authenticator._get_secret_key()
252+
253+
def test_get_signed_token_with_passphrase_protected_key(self):
254+
"""Test that JWT signing works with passphrase-protected RSA private key."""
255+
private_key = rsa.generate_private_key(
256+
public_exponent=65537, key_size=2048, backend=default_backend()
257+
)
258+
259+
passphrase = b"test_passphrase"
260+
encrypted_pem = private_key.private_bytes(
261+
encoding=serialization.Encoding.PEM,
262+
format=serialization.PrivateFormat.PKCS8,
263+
encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
264+
)
265+
266+
authenticator = JwtAuthenticator(
267+
config={},
268+
parameters={},
269+
secret_key=encrypted_pem.decode(),
270+
algorithm="RS256",
271+
token_duration=1000,
272+
passphrase="test_passphrase",
273+
typ="JWT",
274+
iss="test_issuer",
275+
)
276+
277+
signed_token = authenticator._get_signed_token()
278+
279+
assert isinstance(signed_token, str)
280+
assert len(signed_token.split(".")) == 3
281+
282+
public_key = private_key.public_key()
283+
decoded_payload = jwt.decode(signed_token, public_key, algorithms=["RS256"])
284+
285+
assert decoded_payload["iss"] == "test_issuer"
286+
assert "iat" in decoded_payload
287+
assert "exp" in decoded_payload

0 commit comments

Comments
 (0)