Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
_DOMAIN_NAME_MAP,
_OAUTH_DEFAULT_SCOPE,
ENV_VAR_PARTNER,
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT,
PARAMETER_AUTOCOMMIT,
PARAMETER_CLIENT_PREFETCH_THREADS,
PARAMETER_CLIENT_REQUEST_MFA_TOKEN,
Expand Down Expand Up @@ -242,6 +243,10 @@ def _get_private_bytes_from_file(
"internal_application_version": (CLIENT_VERSION, (type(None), str)),
"disable_ocsp_checks": (False, bool),
"ocsp_fail_open": (True, bool), # fail open on ocsp issues, default true
"ocsp_root_certs_dict_lock_timeout": (
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT, # no timeout
int,
),
"inject_client_pause": (0, int), # snowflake internal
"session_parameters": (None, (type(None), dict)), # snowflake session parameters
"autocommit": (None, (type(None), bool)), # snowflake
Expand Down Expand Up @@ -443,6 +448,7 @@ class SnowflakeConnection:
validates the TLS certificate but doesn't check revocation status with OCSP provider.
ocsp_fail_open: Whether or not the connection is in fail open mode. Fail open mode decides if TLS certificates
continue to be validated. Revoked certificates are blocked. Any other exceptions are disregarded.
ocsp_root_certs_dict_lock_timeout: Timeout for the OCSP root certs dict lock in seconds. Default value is -1, which means no timeout.
session_id: The session ID of the connection.
user: The user name used in the connection.
host: The host name the connection attempts to connect to.
Expand Down Expand Up @@ -1545,6 +1551,8 @@ def __config(self, **kwargs):
WORKLOAD_IDENTITY_AUTHENTICATOR,
PROGRAMMATIC_ACCESS_TOKEN,
PAT_WITH_EXTERNAL_SESSION,
OAUTH_AUTHORIZATION_CODE,
OAUTH_CLIENT_CREDENTIALS,
}

if not (self._master_token and self._session_token):
Expand Down
3 changes: 3 additions & 0 deletions src/snowflake/connector/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ class FileHeader(NamedTuple):

HTTP_HEADER_VALUE_OCTET_STREAM = "application/octet-stream"

# OCSP
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT: int = -1


@unique
class OCSPMode(Enum):
Expand Down
7 changes: 7 additions & 0 deletions src/snowflake/connector/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
HTTP_HEADER_CONTENT_TYPE,
HTTP_HEADER_SERVICE_NAME,
HTTP_HEADER_USER_AGENT,
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT,
)
from .description import (
CLIENT_NAME,
Expand Down Expand Up @@ -337,6 +338,12 @@ def __init__(
ssl_wrap_socket.FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME = (
self._connection._ocsp_response_cache_filename if self._connection else None
)
# OCSP root timeout
ssl_wrap_socket.FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT = (
self._connection._ocsp_root_certs_dict_lock_timeout
if self._connection
else OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT
)

# This is to address the issue where requests hangs
_ = "dummy".encode("idna").decode("utf-8")
Expand Down
131 changes: 73 additions & 58 deletions src/snowflake/connector/ocsp_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from . import constants
from .backoff_policies import exponential_backoff
from .cache import CacheEntry, SFDictCache, SFDictFileCache
from .constants import OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT
from .telemetry import TelemetryField, generate_telemetry_data_dict
from .url_util import extract_top_level_domain_from_hostname, url_encode_str
from .util_text import _base64_bytes_to_str
Expand Down Expand Up @@ -1037,6 +1038,7 @@ def __init__(
use_ocsp_cache_server=None,
use_post_method: bool = True,
use_fail_open: bool = True,
root_certs_dict_lock_timeout: int = OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT,
**kwargs,
) -> None:
self.test_mode = os.getenv("SF_OCSP_TEST_MODE", None)
Expand All @@ -1045,6 +1047,7 @@ def __init__(
logger.debug("WARNING - DRIVER CONFIGURED IN TEST MODE")

self._use_post_method = use_post_method
self._root_certs_dict_lock_timeout = root_certs_dict_lock_timeout
self.OCSP_CACHE_SERVER = OCSPServer(
top_level_domain=extract_top_level_domain_from_hostname(
kwargs.pop("hostname", None)
Expand Down Expand Up @@ -1415,67 +1418,79 @@ def _check_ocsp_response_cache_server(

def _lazy_read_ca_bundle(self) -> None:
"""Reads the local cabundle file and cache it in memory."""
with SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK:
if SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
# return if already loaded
return

lock_acquired = SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.acquire(
timeout=self._root_certs_dict_lock_timeout
)
if lock_acquired:
try:
ca_bundle = environ.get("REQUESTS_CA_BUNDLE") or environ.get(
"CURL_CA_BUNDLE"
)
if ca_bundle and path.exists(ca_bundle):
# if the user/application specifies cabundle.
self.read_cert_bundle(ca_bundle)
else:
import sys

# This import that depends on these libraries is to import certificates from them,
# we would like to have these as up to date as possible.
from requests import certs
if SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
# return if already loaded
return

if (
hasattr(certs, "__file__")
and path.exists(certs.__file__)
and path.exists(
path.join(path.dirname(certs.__file__), "cacert.pem")
)
):
# if cacert.pem exists next to certs.py in request
# package.
ca_bundle = path.join(
path.dirname(certs.__file__), "cacert.pem"
)
try:
ca_bundle = environ.get("REQUESTS_CA_BUNDLE") or environ.get(
"CURL_CA_BUNDLE"
)
if ca_bundle and path.exists(ca_bundle):
# if the user/application specifies cabundle.
self.read_cert_bundle(ca_bundle)
elif hasattr(sys, "_MEIPASS"):
# if pyinstaller includes cacert.pem
cabundle_candidates = [
["botocore", "vendored", "requests", "cacert.pem"],
["requests", "cacert.pem"],
["cacert.pem"],
]
for filename in cabundle_candidates:
ca_bundle = path.join(sys._MEIPASS, *filename)
if path.exists(ca_bundle):
self.read_cert_bundle(ca_bundle)
break
else:
logger.error("No cabundle file is found in _MEIPASS")
try:
import certifi

self.read_cert_bundle(certifi.where())
except Exception:
logger.debug("no certifi is installed. ignored.")

except Exception as e:
logger.error("Failed to read ca_bundle: %s", e)

if not SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
logger.error(
"No CA bundle file is found in the system. "
"Set REQUESTS_CA_BUNDLE to the file."
)
else:
import sys

# This import that depends on these libraries is to import certificates from them,
# we would like to have these as up to date as possible.
from requests import certs

if (
hasattr(certs, "__file__")
and path.exists(certs.__file__)
and path.exists(
path.join(path.dirname(certs.__file__), "cacert.pem")
)
):
# if cacert.pem exists next to certs.py in request
# package.
ca_bundle = path.join(
path.dirname(certs.__file__), "cacert.pem"
)
self.read_cert_bundle(ca_bundle)
elif hasattr(sys, "_MEIPASS"):
# if pyinstaller includes cacert.pem
cabundle_candidates = [
["botocore", "vendored", "requests", "cacert.pem"],
["requests", "cacert.pem"],
["cacert.pem"],
]
for filename in cabundle_candidates:
ca_bundle = path.join(sys._MEIPASS, *filename)
if path.exists(ca_bundle):
self.read_cert_bundle(ca_bundle)
break
else:
logger.error("No cabundle file is found in _MEIPASS")
try:
import certifi

self.read_cert_bundle(certifi.where())
except Exception:
logger.debug("no certifi is installed. ignored.")

except Exception as e:
logger.error("Failed to read ca_bundle: %s", e)

if not SnowflakeOCSP.ROOT_CERTIFICATES_DICT:
logger.error(
"No CA bundle file is found in the system. "
"Set REQUESTS_CA_BUNDLE to the file."
)
finally:
SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK.release()
else:
logger.info(
"Failed to acquire lock for ROOT_CERTIFICATES_DICT_LOCK. "
"Skipping reading CA bundle."
)
return

@staticmethod
def _calculate_tolerable_validity(this_update: float, next_update: float) -> int:
Expand Down
6 changes: 5 additions & 1 deletion src/snowflake/connector/ssl_wrap_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import certifi
import OpenSSL.SSL

from .constants import OCSPMode
from .constants import OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT, OCSPMode
from .errorcode import ER_OCSP_RESPONSE_CERT_STATUS_REVOKED
from .errors import OperationalError
from .session_manager import SessionManager
Expand All @@ -31,6 +31,9 @@

DEFAULT_OCSP_MODE: OCSPMode = OCSPMode.FAIL_OPEN
FEATURE_OCSP_MODE: OCSPMode = DEFAULT_OCSP_MODE
FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT: int = (
OCSP_ROOT_CERTS_DICT_LOCK_TIMEOUT_DEFAULT_NO_TIMEOUT
)

"""
OCSP Response cache file name
Expand Down Expand Up @@ -179,6 +182,7 @@ def ssl_wrap_socket_with_ocsp(*args: Any, **kwargs: Any) -> WrappedSocket:
ocsp_response_cache_uri=FEATURE_OCSP_RESPONSE_CACHE_FILE_NAME,
use_fail_open=FEATURE_OCSP_MODE == OCSPMode.FAIL_OPEN,
hostname=server_hostname,
root_certs_dict_lock_timeout=FEATURE_ROOT_CERTS_DICT_LOCK_TIMEOUT,
).validate(server_hostname, ret.connection)
if not v:
raise OperationalError(
Expand Down
83 changes: 83 additions & 0 deletions test/integ/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import warnings
import weakref
from unittest import mock
from unittest.mock import MagicMock, PropertyMock, patch
from uuid import uuid4

import pytest
Expand Down Expand Up @@ -1585,6 +1586,88 @@ def test_ocsp_mode_insecure_mode_and_disable_ocsp_checks_mismatch_ocsp_enabled(
assert "snowflake.connector.ocsp_snowflake" not in caplog.text


@pytest.mark.skipolddriver
def test_root_certs_dict_lock_timeout_fail_open(conn_cnx):
"""Test OCSP root certificates lock timeout with fail-open mode and side effect mock."""

override_config = {
"ocsp_fail_open": True,
"ocsp_root_certs_dict_lock_timeout": 0.1,
}

with patch(
"snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK"
) as mock_lock:
snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT = {}

mock_lock.acquire = MagicMock(return_value=False)
mock_lock.release = MagicMock()

with conn_cnx(**override_config) as conn:
try:
with conn.cursor() as cur:
assert cur.execute("select 1").fetchall() == [(1,)]

if mock_lock.acquire.called:
mock_lock.acquire.assert_called_with(timeout=0.1)
assert conn._ocsp_root_certs_dict_lock_timeout == 0.1
finally:
conn.close()


@pytest.mark.skipolddriver
@pytest.mark.parametrize(
"ocsp_fail_open,timeout_value,expected_timeout",
[
(False, 1, 1), # fail-close mode with 1 second timeout
(True, 2, 2), # fail-open mode with 2 second timeout
],
)
def test_root_certs_dict_lock_timeout_with_property_mock(
conn_cnx, ocsp_fail_open, timeout_value, expected_timeout
):
"""Test OCSP root certificates lock timeout with property mock for different configurations."""
config = {
"ocsp_fail_open": ocsp_fail_open,
"ocsp_root_certs_dict_lock_timeout": timeout_value,
}

with patch(
"snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT_LOCK"
) as mock_lock:
snowflake.connector.ocsp_snowflake.SnowflakeOCSP.ROOT_CERTIFICATES_DICT = {}

type(mock_lock).acquire = PropertyMock(return_value=lambda timeout: False)
type(mock_lock).release = PropertyMock(return_value=lambda: None)

with conn_cnx(**config) as conn:
with conn.cursor() as cur:
assert cur.execute("select 1").fetchall() == [(1,)]

assert conn._ocsp_root_certs_dict_lock_timeout == expected_timeout
conn.close()


@pytest.mark.skipolddriver
@pytest.mark.parametrize(
"config,expected_timeout",
[
({"ocsp_fail_open": True, "ocsp_root_certs_dict_lock_timeout": 0.001}, 0.001),
({"ocsp_fail_open": True}, -1), # no timeout specified, should default to -1
],
)
def test_root_certs_dict_lock_timeout_basic_config(conn_cnx, config, expected_timeout):
"""Test OCSP root certificates lock timeout basic configuration without mocking."""
with conn_cnx(**config) as conn:
try:
with conn.cursor() as cur:
assert cur.execute("select 1").fetchall() == [(1,)]

assert conn._ocsp_root_certs_dict_lock_timeout == expected_timeout
finally:
conn.close()


@pytest.mark.skipolddriver
def test_ocsp_mode_insecure_mode_deprecation_warning(conn_cnx):
with warnings.catch_warnings(record=True) as w:
Expand Down
Loading
Loading