Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
BS3TextFieldWidget,
)
from flask_babel import lazy_gettext
from wtforms import BooleanField, PasswordField, StringField
from wtforms import BooleanField, IntegerField, PasswordField, StringField

return {
"account": StringField(lazy_gettext("Account"), widget=BS3TextFieldWidget()),
Expand All @@ -130,6 +130,10 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
"insecure_mode": BooleanField(
label=lazy_gettext("Insecure mode"), description="Turns off OCSP certificate checks"
),
"proxy_host": StringField(lazy_gettext("Proxy Host"), widget=BS3TextFieldWidget()),
"proxy_port": IntegerField(lazy_gettext("Proxy Port")),
"proxy_user": StringField(lazy_gettext("Proxy User"), widget=BS3TextFieldWidget()),
"proxy_password": PasswordField(lazy_gettext("Proxy Password"), widget=BS3PasswordFieldWidget()),
}

@classmethod
Expand All @@ -152,6 +156,10 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"token_endpoint": "token endpoint",
"refresh_token": "refresh token",
"scope": "scope",
"proxy_host": "proxy.example.com",
"proxy_port": "8080",
"proxy_user": "proxy_username",
"proxy_password": "proxy_password",
},
indent=1,
),
Expand All @@ -166,6 +174,10 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"private_key_file": "Path of snowflake private key (PEM Format)",
"private_key_content": "Content to snowflake private key (PEM format)",
"insecure_mode": "insecure mode",
"proxy_host": "Proxy server hostname",
"proxy_port": "Proxy server port",
"proxy_user": "Proxy username (optional)",
"proxy_password": "Proxy password (optional)",
},
}

Expand Down Expand Up @@ -431,6 +443,21 @@ def _get_static_conn_params(self) -> dict[str, str | None]:
if ocsp_fail_open is not None:
conn_config["ocsp_fail_open"] = _try_to_boolean(ocsp_fail_open)

# Add proxy configuration if specified
proxy_host = self._get_field(extra_dict, "proxy_host")
proxy_port = self._get_field(extra_dict, "proxy_port")
proxy_user = self._get_field(extra_dict, "proxy_user")
proxy_password = self._get_field(extra_dict, "proxy_password")

if proxy_host:
conn_config["proxy_host"] = proxy_host
if proxy_port:
conn_config["proxy_port"] = int(proxy_port) if isinstance(proxy_port, str) else proxy_port
if proxy_user:
conn_config["proxy_user"] = proxy_user
if proxy_password:
conn_config["proxy_password"] = proxy_password

return conn_config

def _get_valid_oauth_token(
Expand Down Expand Up @@ -524,6 +551,10 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:
"client_store_temporary_credential",
"json_result_force_utf8_decoding",
"ocsp_fail_open",
"proxy_host",
"proxy_port",
"proxy_user",
"proxy_password",
]
}
)
Expand Down
131 changes: 131 additions & 0 deletions providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,3 +1309,134 @@ def test_oauth_token_refresh_after_expiry(self, mock_timezone_utcnow, mock_reque

# Ensure refresh actually happened
assert mock_requests_post.call_count == 2

def test_get_conn_params_with_proxy_host_only(self):
"""Test proxy configuration with only host specified."""
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"

with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn_params = hook._get_conn_params()

assert conn_params["proxy_host"] == "proxy.example.com"
assert "proxy_port" not in conn_params
assert "proxy_user" not in conn_params
assert "proxy_password" not in conn_params

def test_get_conn_params_with_proxy_host_and_port(self):
"""Test proxy configuration with host and port."""
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
connection_kwargs["extra"]["proxy_port"] = "8080"

with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn_params = hook._get_conn_params()

assert conn_params["proxy_host"] == "proxy.example.com"
assert conn_params["proxy_port"] == 8080
assert "proxy_user" not in conn_params
assert "proxy_password" not in conn_params

def test_get_conn_params_with_proxy_port_as_int(self):
"""Test proxy configuration with port as integer."""
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
connection_kwargs["extra"]["proxy_port"] = 8080 # Integer instead of string

with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn_params = hook._get_conn_params()

assert conn_params["proxy_host"] == "proxy.example.com"
assert conn_params["proxy_port"] == 8080
assert isinstance(conn_params["proxy_port"], int)

def test_get_conn_params_with_proxy_full_config(self):
"""Test proxy configuration with all parameters."""
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
connection_kwargs["extra"]["proxy_port"] = "8080"
connection_kwargs["extra"]["proxy_user"] = "proxy_username"
connection_kwargs["extra"]["proxy_password"] = "proxy_password"

with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn_params = hook._get_conn_params()

assert conn_params["proxy_host"] == "proxy.example.com"
assert conn_params["proxy_port"] == 8080
assert conn_params["proxy_user"] == "proxy_username"
assert conn_params["proxy_password"] == "proxy_password"

def test_get_conn_params_with_proxy_backcompat_prefix(self):
"""Test proxy configuration with backcompat prefix."""
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["extra__snowflake__proxy_host"] = "proxy.example.com"
connection_kwargs["extra"]["extra__snowflake__proxy_port"] = "8080"
connection_kwargs["extra"]["extra__snowflake__proxy_user"] = "proxy_username"
connection_kwargs["extra"]["extra__snowflake__proxy_password"] = "proxy_password"

with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
conn_params = hook._get_conn_params()

assert conn_params["proxy_host"] == "proxy.example.com"
assert conn_params["proxy_port"] == 8080
assert conn_params["proxy_user"] == "proxy_username"
assert conn_params["proxy_password"] == "proxy_password"

def test_get_conn_with_proxy_should_call_connect(self):
"""Test that proxy parameters are passed to connector.connect()."""
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
connection_kwargs["extra"]["proxy_port"] = "8080"
connection_kwargs["extra"]["proxy_user"] = "proxy_user"
connection_kwargs["extra"]["proxy_password"] = "proxy_pass"

with (
mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
mock.patch("airflow.providers.snowflake.hooks.snowflake.connector") as mock_connector,
):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
hook.get_conn()

call_args = mock_connector.connect.call_args[1]
assert call_args["proxy_host"] == "proxy.example.com"
assert call_args["proxy_port"] == 8080
assert call_args["proxy_user"] == "proxy_user"
assert call_args["proxy_password"] == "proxy_pass"

def test_sqlalchemy_uri_excludes_proxy_params(self):
"""Test that proxy parameters are excluded from SQLAlchemy URI."""
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
connection_kwargs["extra"]["proxy_port"] = "8080"

with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
uri = hook.get_uri()

# Proxy parameters should NOT appear in the URI
assert "proxy_host" not in uri
assert "proxy_port" not in uri
assert "proxy.example.com" not in uri
assert "8080" not in uri

def test_get_sqlalchemy_engine_with_proxy(self):
"""Test get_sqlalchemy_engine does not include proxy params in URI but passes to connect_args if needed."""
connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
connection_kwargs["extra"]["proxy_host"] = "proxy.example.com"
connection_kwargs["extra"]["proxy_port"] = "8080"

with (
mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()),
mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine,
):
hook = SnowflakeHook(snowflake_conn_id="test_conn")
hook.get_sqlalchemy_engine()

# Check that the URI doesn't contain proxy params
called_uri = mock_create_engine.call_args[0][0]
assert "proxy_host" not in str(called_uri)
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def to_dict(self) -> dict[str, Any]: ...
"password",
"private_key",
"proxy",
"proxy_password",
"proxies",
"secret",
"token",
Expand Down