Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
593 changes: 0 additions & 593 deletions .github/workflows/build_test.yml

This file was deleted.

Binary file not shown.
71 changes: 71 additions & 0 deletions .github/workflows/run_single_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
name: Run custom pytest

on:
push:

jobs:
run-pytest:
strategy:
matrix:
cloud-provider: [aws, azure, gcp]
os:
- image: ubuntu-latest
id: lububuntu
- image: windows-latest
id: windows
- image: macos-latest
id: macos
python-version: ["3.10"]
name: Custom pytest on ${{ matrix.os.id }}-py${{ matrix.python-version }}-${{ matrix.cloud-provider }}
runs-on: ${{ matrix.os.image }}
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Set up Java
uses: actions/setup-java@v4 # for wiremock
with:
java-version: 11
distribution: 'temurin'
java-package: 'jre'

- name: Fetch Wiremock
shell: bash
run: curl https://repo1.maven.org/maven2/org/wiremock/wiremock-standalone/3.11.0/wiremock-standalone-3.11.0.jar --output .wiremock/wiremock-standalone.jar

- name: Setup parameters file
shell: bash
env:
PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
run: |
gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \
.github/workflows/parameters/public/parameters_${{ matrix.cloud-provider }}.py.gpg > test/parameters.py

- name: Setup private key file (old)
shell: bash
env:
PYTHON_PRIVATE_KEY_SECRET: ${{ secrets.PYTHON_PRIVATE_KEY_SECRET }}
run: |
gpg --quiet --batch --yes --decrypt --passphrase="$PYTHON_PRIVATE_KEY_SECRET" \
.github/workflows/parameters/public/rsa_keys/rsa_key_python_${{ matrix.cloud-provider }}.p8.gpg > test/rsa_key_python_${{ matrix.cloud-provider }}.p8
# - name: Setup private key file (main)
# shell: bash
# env:
# PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }}
# run: |
# gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \
# .github/workflows/parameters/public/rsa_keys/rsa_key_python_${{ matrix.cloud-provider }}.p8.gpg > test/rsa_key_python_${{ matrix.cloud-provider }}.p8

- name: Install dependencies
run: |
python -m pip install uv
python -m uv pip install ".[development,aio,secure-local-storage,pandas]"

- name: Run pytest
run: |
pytest -n auto -vv test/unit/aio/test_connection_async_unit.py
10 changes: 2 additions & 8 deletions src/snowflake/connector/aio/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,14 @@ async def __open_connection(self):
backoff_generator=self._backoff_generator,
)
elif self._authenticator == OAUTH_AUTHORIZATION_CODE:
self._check_oauth_parameters()
if self._role and (self._oauth_scope == ""):
# if role is known then let's inject it into scope
self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
self.auth_class = AuthByOauthCode(
application=self.application,
client_id=self._oauth_client_id,
client_secret=self._oauth_client_secret,
host=self.host,
authentication_url=self._oauth_authorization_url.format(
host=self.host, port=self.port
),
Expand All @@ -337,7 +337,6 @@ async def __open_connection(self):
enable_single_use_refresh_tokens=self._oauth_enable_single_use_refresh_tokens,
)
elif self._authenticator == OAUTH_CLIENT_CREDENTIALS:
self._check_oauth_parameters()
if self._role and (self._oauth_scope == ""):
# if role is known then let's inject it into scope
self._oauth_scope = _OAUTH_DEFAULT_SCOPE.format(role=self._role)
Expand All @@ -349,12 +348,7 @@ async def __open_connection(self):
host=self.host, port=self.port
),
scope=self._oauth_scope,
token_cache=(
auth.get_token_cache()
if self._client_store_temporary_credential
else None
),
refresh_token_enabled=self._oauth_enable_refresh_tokens,
connection=self,
)
elif self._authenticator == PROGRAMMATIC_ACCESS_TOKEN:
self.auth_class = AuthByPAT(self._token)
Expand Down
7 changes: 6 additions & 1 deletion src/snowflake/connector/aio/_result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,12 @@ def __init__(
result_chunks: list[JSONResultBatch] | list[ArrowResultBatch],
prefetch_thread_num: int,
) -> None:
super().__init__(cursor, result_chunks, prefetch_thread_num)
super().__init__(
cursor,
result_chunks,
prefetch_thread_num,
use_mp=False, # async code depends on aio rather than multiprocessing
)
self.batches = cast(
Union[list[JSONResultBatch], list[ArrowResultBatch]], self.batches
)
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/connector/aio/auth/_oauth_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def __init__(
token_request_url: str,
redirect_uri: str,
scope: str,
host: str,
pkce_enabled: bool = True,
token_cache: TokenCache | None = None,
refresh_token_enabled: bool = False,
external_browser_timeout: int | None = None,
enable_single_use_refresh_tokens: bool = False,
connection: SnowflakeConnection | None = None,
**kwargs,
) -> None:
"""Initializes an instance with OAuth authorization code parameters."""
Expand All @@ -49,11 +51,13 @@ def __init__(
token_request_url=token_request_url,
redirect_uri=redirect_uri,
scope=scope,
host=host,
pkce_enabled=pkce_enabled,
token_cache=token_cache,
refresh_token_enabled=refresh_token_enabled,
external_browser_timeout=external_browser_timeout,
enable_single_use_refresh_tokens=enable_single_use_refresh_tokens,
connection=connection,
**kwargs,
)

Expand Down
7 changes: 2 additions & 5 deletions src/snowflake/connector/aio/auth/_oauth_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ...auth.oauth_credentials import (
AuthByOauthCredentials as AuthByOauthCredentialsSync,
)
from ...token_cache import TokenCache
from ._by_plugin import AuthByPlugin as AuthByPluginAsync

if TYPE_CHECKING:
Expand All @@ -27,8 +26,7 @@ def __init__(
client_secret: str,
token_request_url: str,
scope: str,
token_cache: TokenCache | None = None,
refresh_token_enabled: bool = False,
connection: SnowflakeConnection | None = None,
**kwargs,
) -> None:
"""Initializes an instance with OAuth client credentials parameters."""
Expand All @@ -42,8 +40,7 @@ def __init__(
client_secret=client_secret,
token_request_url=token_request_url,
scope=scope,
token_cache=token_cache,
refresh_token_enabled=refresh_token_enabled,
connection=connection,
**kwargs,
)

Expand Down
35 changes: 34 additions & 1 deletion src/snowflake/connector/auth/_oauth_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
from typing import TYPE_CHECKING, Any
from urllib.error import HTTPError, URLError

from ..errorcode import ER_FAILED_TO_REQUEST, ER_IDP_CONNECTION_ERROR
from ..errorcode import (
ER_FAILED_TO_REQUEST,
ER_IDP_CONNECTION_ERROR,
ER_NO_CLIENT_ID,
ER_NO_CLIENT_SECRET,
)
from ..errors import Error, ProgrammingError
from ..network import OAUTH_AUTHENTICATOR
from ..secret_detector import SecretDetector
from ..token_cache import TokenCache, TokenKey, TokenType
Expand Down Expand Up @@ -185,6 +191,33 @@ def assertion_content(self) -> str:
"""Returns the token."""
return self._access_token or ""

@staticmethod
def _validate_client_credentials_present(
client_id: str, client_secret: str, connection: SnowflakeConnection
) -> tuple[str, str]:
if client_id is None or client_id == "":
Error.errorhandler_wrapper(
connection,
None,
ProgrammingError,
{
"msg": "Oauth code flow requirement 'client_id' is empty",
"errno": ER_NO_CLIENT_ID,
},
)
if client_secret is None or client_secret == "":
Error.errorhandler_wrapper(
connection,
None,
ProgrammingError,
{
"msg": "Oauth code flow requirement 'client_secret' is empty",
"errno": ER_NO_CLIENT_SECRET,
},
)

return client_id, client_secret

def reauthenticate(
self,
*,
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/connector/auth/by_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class AuthType(Enum):
PAT = "PROGRAMMATIC_ACCESS_TOKEN"
NO_AUTH = "NO_AUTH"
WORKLOAD_IDENTITY = "WORKLOAD_IDENTITY"
PAT_WITH_EXTERNAL_SESSION = "PAT_WITH_EXTERNAL_SESSION"


class AuthByPlugin(ABC):
Expand Down
92 changes: 92 additions & 0 deletions src/snowflake/connector/auth/oauth_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from ..compat import parse_qs, urlparse, urlsplit
from ..constants import OAUTH_TYPE_AUTHORIZATION_CODE
from ..errorcode import (
ER_INVALID_VALUE,
ER_OAUTH_CALLBACK_ERROR,
ER_OAUTH_SERVER_TIMEOUT,
ER_OAUTH_STATE_CHANGED,
ER_UNABLE_TO_OPEN_BROWSER,
)
from ..errors import Error, ProgrammingError
from ..token_cache import TokenCache
from ._http_server import AuthHttpServer
from ._oauth_base import AuthByOAuthBase
Expand All @@ -45,6 +47,8 @@ def _get_query_params(
class AuthByOauthCode(AuthByOAuthBase):
"""Authenticates user by OAuth code flow."""

_LOCAL_APPLICATION_CLIENT_CREDENTIALS = "LOCAL_APPLICATION"

def __init__(
self,
application: str,
Expand All @@ -54,13 +58,27 @@ def __init__(
token_request_url: str,
redirect_uri: str,
scope: str,
host: str,
pkce_enabled: bool = True,
token_cache: TokenCache | None = None,
refresh_token_enabled: bool = False,
external_browser_timeout: int | None = None,
enable_single_use_refresh_tokens: bool = False,
connection: SnowflakeConnection | None = None,
**kwargs,
) -> None:
authentication_url, redirect_uri = self._validate_oauth_code_uris(
authentication_url, redirect_uri, connection
)
client_id, client_secret = self._validate_client_credentials_with_defaults(
client_id,
client_secret,
authentication_url,
token_request_url,
host,
connection,
)

super().__init__(
client_id=client_id,
client_secret=client_secret,
Expand Down Expand Up @@ -385,3 +403,77 @@ def _parse_authorization_redirected_request(
},
)
return parsed.get("code", [None])[0], parsed.get("state", [None])[0]

@staticmethod
def _is_snowflake_as_idp(
authentication_url: str, token_request_url: str, host: str
) -> bool:
return (authentication_url == "" or host in authentication_url) and (
token_request_url == "" or host in token_request_url
)

def _eligible_for_default_client_credentials(
self,
client_id: str,
client_secret: str,
authorization_url: str,
token_request_url: str,
host: str,
) -> bool:
return (
(client_id == "" or client_secret is None)
and (client_secret == "" or client_secret is None)
and self.__class__._is_snowflake_as_idp(
authorization_url, token_request_url, host
)
)

def _validate_client_credentials_with_defaults(
self,
client_id: str,
client_secret: str,
authorization_url: str,
token_request_url: str,
host: str,
connection: SnowflakeConnection,
) -> tuple[str, str] | None:
if self._eligible_for_default_client_credentials(
client_id, client_secret, authorization_url, token_request_url, host
):
return (
self.__class__._LOCAL_APPLICATION_CLIENT_CREDENTIALS,
self.__class__._LOCAL_APPLICATION_CLIENT_CREDENTIALS,
)
else:
self._validate_client_credentials_present(
client_id, client_secret, connection
)
return client_id, client_secret

@staticmethod
def _validate_oauth_code_uris(
authorization_url: str, redirect_uri: str, connection: SnowflakeConnection
) -> tuple[str, str]:
if authorization_url and not authorization_url.startswith("https://"):
Error.errorhandler_wrapper(
connection,
None,
ProgrammingError,
{
"msg": "OAuth supports only authorization urls that use 'https' scheme",
"errno": ER_INVALID_VALUE,
},
)
if redirect_uri and not (
redirect_uri.startswith("http://") or redirect_uri.startswith("https://")
):
Error.errorhandler_wrapper(
connection,
None,
ProgrammingError,
{
"msg": "OAuth supports only authorization urls that use 'http(s)' scheme",
"errno": ER_INVALID_VALUE,
},
)
return authorization_url, redirect_uri
9 changes: 4 additions & 5 deletions src/snowflake/connector/auth/oauth_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import TYPE_CHECKING, Any

from ..constants import OAUTH_TYPE_CLIENT_CREDENTIALS
from ..token_cache import TokenCache
from ._oauth_base import AuthByOAuthBase

if TYPE_CHECKING:
Expand All @@ -27,17 +26,17 @@ def __init__(
client_secret: str,
token_request_url: str,
scope: str,
token_cache: TokenCache | None = None,
refresh_token_enabled: bool = False,
connection: SnowflakeConnection | None = None,
**kwargs,
) -> None:
self._validate_client_credentials_present(client_id, client_secret, connection)
super().__init__(
client_id=client_id,
client_secret=client_secret,
token_request_url=token_request_url,
scope=scope,
token_cache=token_cache,
refresh_token_enabled=refresh_token_enabled,
token_cache=None,
refresh_token_enabled=False,
**kwargs,
)
self._application = application
Expand Down
Loading
Loading