Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ the [`JWT` authentication type](https://trino.io/docs/current/security/jwt.html)

### OAuth2 authentication

Make sure that the OAuth2 support is installed using `pip install trino[oauth]`.

#### Interactive Browser authentication

The `OAuth2Authentication` class can be used to connect to a Trino cluster configured with
the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.html).

Expand Down Expand Up @@ -248,14 +252,127 @@ The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` ins
from trino.auth import OAuth2Authentication

engine = create_engine(
"trino://<username>@<host>:<port>/<catalog>",
"trino://<username>@<host>:<port>/<catalog>",
connect_args={
"auth": OAuth2Authentication(),
"http_scheme": "https",
}
)
```

#### Client Credentials authentication

```python
from trino.dbapi import connect
from trino.auth import ClientCredentials
from trino.oauth2.models import OidcConfig

auth = ClientCredentials(
client_id="<client_id>",
client_secret="<client_secret>",
url_config=OidcConfig(
token_endpoint="<token_endpoint>",
# other endpoints if needed
),
scope="<number of scopes>", # optional
audience="<audience>", # optional
)

conn = connect(
user="<username>",
auth=auth,
http_scheme="https",
...
)
```

#### Device Code authentication

```python
from trino.dbapi import connect
from trino.auth import DeviceCode
from trino.oauth2.models import OidcConfig

auth = DeviceCode(
client_id="<client_id>",
url_config=OidcConfig(
token_endpoint="<token_endpoint>",
device_authorization_endpoint="<device_authorization_endpoint>",
),
scope="<scope>", # optional
audience="<audience>", # optional
)

conn = connect(
user="<username>",
auth=auth,
http_scheme="https",
...
)
```

#### Authorization Code authentication

```python
from trino.dbapi import connect
from trino.auth import AuthorizationCode
from trino.oauth2.models import OidcConfig

auth = AuthorizationCode(
client_id="<client_id>",
client_secret="<client_secret>", # optional
url_config=OidcConfig(
token_endpoint="<token_endpoint>",
authorization_endpoint="<authorization_endpoint>",
),
scope="<scope>", # optional
audience="<audience>", # optional
)

conn = connect(
user="<username>",
auth=auth,
http_scheme="https",
...
)
```

### Reference

For further details, please consult [Trino documentation](https://trino.io/docs/current).

### Secure Token Storage

By default all ClientCredentials, DeviceCode, AuthorizationCode JWT tokens are securely storaged
using the keyrings.cryptfile feature of [keyring library](https://pypi.org/project/keyring/).

Tokens are stored encrypted at ~/.local/share/python_keyring/cryptfile_pass.cfg

You can optionally use different keyring backends by supplying the `PYTHON_KEYRING_BACKEND` environment variable.

To use an encrypted file backend for credentials:

```bash
export KEYRING_CRYPTFILE_PASSWORD=your_secure_password
```

Or you can pass the password directly (less secure):

```python
conn = connect(
host="trino.example.com",
port=443,
auth=DeviceCode(
client_id="<CLIENT_ID>",
client_secret="<CLIENT_SECRET>",
url_config=OidcConfig(oidc_discovery_url="https://sso.example.com/.well-known/openid-configuration"),
token_storage_password="your_secure_password" # less secure
),
http_scheme="https"
)
```


### Certificate authentication

`CertificateAuthentication` class can be used to connect to Trino cluster configured with [certificate based authentication](https://trino.io/docs/current/security/certificate.html). `CertificateAuthentication` requires paths to a valid client certificate and private key.
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@
"krb5 == 0.5.1"]
sqlalchemy_require = ["sqlalchemy >= 1.3"]
external_authentication_token_cache_require = ["keyring"]
oauth_require = ["trino.oauth2 @ git+https://github.com/dprophet/trino-python-oauth2"]

# We don't add localstorage_require to all_require as users must explicitly opt in to use keyring.
all_require = kerberos_require + sqlalchemy_require
all_require = kerberos_require + sqlalchemy_require + oauth_require

tests_require = all_require + gssapi_require + [
# httpretty >= 1.1 duplicates requests in `httpretty.latest_requests`
Expand Down Expand Up @@ -96,6 +97,7 @@
"all": all_require,
"kerberos": kerberos_require,
"gssapi": gssapi_require,
"oauth": oauth_require,
"sqlalchemy": sqlalchemy_require,
"tests": tests_require,
"external-authentication-token-cache": external_authentication_token_cache_require,
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,18 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import keyring
import pytest

from tests.unit.oauth_test_utils import MockKeyring


@pytest.fixture(autouse=True, scope="session")
def setup_test_keyring():
mk = MockKeyring()
keyring.set_keyring(mk)
yield mk


@pytest.fixture(scope="session")
def sample_post_response_data():
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/oauth_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import namedtuple

import httpretty
import keyring.backend

from trino import constants

Expand Down Expand Up @@ -150,3 +151,50 @@ def get_token_callback(self, request, uri, response_headers):
if challenge.attempts == 0:
return [200, response_headers, f'{{"token": "{challenge.token}"}}']
return [200, response_headers, f'{{"nextUri": "{uri}"}}']


class MockKeyring(keyring.backend.KeyringBackend):
priority = 1

def __init__(self):
self.file_location = self._generate_test_root_dir()

@staticmethod
def _generate_test_root_dir():
import tempfile

return tempfile.mkdtemp(prefix="trino-python-client-unit-test-")

def _get_file_path(self, servicename, username):
from os.path import join

file_location = self.file_location
file_name = f"{servicename}_{username}.txt"
return join(file_location, file_name)

def set_password(self, servicename, username, password):
file_path = self._get_file_path(servicename, username)

with open(file_path, "w") as file:
file.write(password)

def get_password(self, servicename, username):
import os

file_path = self._get_file_path(servicename, username)
if not os.path.exists(file_path):
return None

with open(file_path, "r") as file:
password = file.read()

return password

def delete_password(self, servicename, username):
import os

file_path = self._get_file_path(servicename, username)
if not os.path.exists(file_path):
return None

os.remove(file_path)
46 changes: 1 addition & 45 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from tests.unit.oauth_test_utils import RedirectHandlerWithException
from tests.unit.oauth_test_utils import SERVER_ADDRESS
from tests.unit.oauth_test_utils import TOKEN_RESOURCE
from tests.unit.oauth_test_utils import MockKeyring
from trino import __version__
from trino import constants
from trino.auth import _OAuth2KeyRingTokenCache
Expand Down Expand Up @@ -1405,48 +1406,3 @@ def test_store_long_password(self):

retrieved_password = cache.get_token_from_cache(host)
self.assertEqual(long_password, retrieved_password)


class MockKeyring(keyring.backend.KeyringBackend):
def __init__(self):
self.file_location = self._generate_test_root_dir()

@staticmethod
def _generate_test_root_dir():
import tempfile

return tempfile.mkdtemp(prefix="trino-python-client-unit-test-")

def file_path(self, servicename, username):
from os.path import join

file_location = self.file_location
file_name = f"{servicename}_{username}.txt"
return join(file_location, file_name)

def set_password(self, servicename, username, password):
file_path = self.file_path(servicename, username)

with open(file_path, "w") as file:
file.write(password)

def get_password(self, servicename, username):
import os

file_path = self.file_path(servicename, username)
if not os.path.exists(file_path):
return None

with open(file_path, "r") as file:
password = file.read()

return password

def delete_password(self, servicename, username):
import os

file_path = self.file_path(servicename, username)
if not os.path.exists(file_path):
return None

os.remove(file_path)
15 changes: 12 additions & 3 deletions tests/unit/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from unittest.mock import patch

import httpretty
import keyring
import pytest
from httpretty import httprettified
from requests import Session

from tests.unit.oauth_test_utils import _get_token_requests
from tests.unit.oauth_test_utils import _post_statement_requests
from tests.unit.oauth_test_utils import GetTokenCallback
from tests.unit.oauth_test_utils import MockKeyring
from tests.unit.oauth_test_utils import PostStatementCallback
from tests.unit.oauth_test_utils import REDIRECT_RESOURCE
from tests.unit.oauth_test_utils import RedirectHandler
Expand Down Expand Up @@ -58,8 +60,15 @@ def test_http_session_is_defaulted_when_not_specified(mock_client):
assert mock_client.TrinoRequest.http.Session.return_value in request_args


@pytest.fixture
def mock_keyring():
mk = MockKeyring()
keyring.set_keyring(mk)
return mk


@httprettified
def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sample_get_response_data):
def test_token_retrieved_once_per_auth_instance(mock_keyring, sample_post_response_data, sample_get_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())

Expand Down Expand Up @@ -123,7 +132,7 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl


@httprettified
def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data,
def test_token_retrieved_once_when_authentication_instance_is_shared(mock_keyring, sample_post_response_data,
sample_get_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())
Expand Down Expand Up @@ -189,7 +198,7 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post


@httprettified
def test_token_retrieved_once_when_multithreaded(sample_post_response_data, sample_get_response_data):
def test_token_retrieved_once_when_multithreaded(mock_keyring, sample_post_response_data, sample_get_response_data):
token = str(uuid.uuid4())
challenge_id = str(uuid.uuid4())

Expand Down
2 changes: 2 additions & 0 deletions trino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__path__ = __import__("pkgutil").extend_path(__path__, __name__)

from . import auth
from . import client
from . import constants
Expand Down
Loading