Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion airflow/providers/trino/hooks/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ def get_conn(self) -> Connection:
elif db.password:
auth = trino.auth.BasicAuthentication(db.login, db.password) # type: ignore[attr-defined]
elif extra.get("auth") == "jwt":
auth = trino.auth.JWTAuthentication(token=extra.get("jwt__token"))
if "jwt__file" in extra:
with open(extra.get("jwt__file")) as jwt_file:
token = jwt_file.read()
else:
token = extra.get("jwt__token")
auth = trino.auth.JWTAuthentication(token=token)
elif extra.get("auth") == "certs":
auth = trino.auth.CertificateAuthentication(
extra.get("certs__client_cert_path"),
Expand Down
3 changes: 3 additions & 0 deletions docs/apache-airflow-providers-trino/connections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ Extra (optional, connection parameters)
The following extra parameters can be used to configure authentication:

* ``jwt__token`` - If jwt authentication should be used, the value of token is given via this parameter.
* ``jwt__file`` - If jwt authentication should be used, the location on disk for the file containing the jwt token.
* ``certs__client_cert_path``, ``certs__client_key_path``- If certificate authentication should be used, the path to the client certificate and key is given via these parameters.
* ``kerberos__service_name``, ``kerberos__config``, ``kerberos__mutual_authentication``, ``kerberos__force_preemptive``, ``kerberos__hostname_override``, ``kerberos__sanitize_mutual_error_response``, ``kerberos__principal``,``kerberos__delegate``, ``kerberos__ca_bundle`` - These parameters can be set when enabling ``kerberos`` authentication.
* ``session_properties`` - JSON dictionary which allows to set session_properties. Example: ``{'session_properties':{'scale_writers':true,'task_writer_count:1'}}``
* ``client_tags`` - List of comma separated tags. Example ``{'client_tags':['sales','cluster1']}```

Note: If ``jwt__file`` and ``jwt__token`` are both given, ``jwt__file`` will take precedent.
30 changes: 30 additions & 0 deletions tests/providers/trino/hooks/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from __future__ import annotations

import json
import os
import re
from tempfile import TemporaryDirectory
from unittest import mock
from unittest.mock import patch

Expand All @@ -37,6 +39,19 @@
CERT_AUTHENTICATION = "airflow.providers.trino.hooks.trino.trino.auth.CertificateAuthentication"


@pytest.fixture()
def jwt_token_file():
# Couldn't get this working with TemporaryFile, using TemporaryDirectory instead
# Save a phony jwt to a temporary file for the trino hook to read from
with TemporaryDirectory() as tmp_dir:
tmp_jwt_file = os.path.join(tmp_dir, "jwt.json")

with open(tmp_jwt_file, "w") as tmp_file:
tmp_file.write('{"phony":"jwt"}')

yield tmp_jwt_file


class TestTrinoHookConn:
@patch(BASIC_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
Expand Down Expand Up @@ -110,6 +125,21 @@ def test_get_conn_jwt_auth(self, mock_get_connection, mock_connect, mock_jwt_aut
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)

@patch(JWT_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
@patch(HOOK_GET_CONNECTION)
def test_get_conn_jwt_file(self, mock_get_connection, mock_connect, mock_jwt_auth, jwt_token_file):
extras = {
"auth": "jwt",
"jwt__file": jwt_token_file,
}
self.set_get_connection_return_value(
mock_get_connection,
extra=json.dumps(extras),
)
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, auth=mock_jwt_auth)

@patch(CERT_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
@patch(HOOK_GET_CONNECTION)
Expand Down