Skip to content

Support Snowflake oauth tokens #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for the `--draft` option when deploying content,
this allows to deploy a new bundle for the content without exposing
it as a the activated one.
- Improved support for Posit Connect deployments
hosted in Snowpark Container Services.

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion docs/overrides/partials/header.html
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
{% endif %}
<div class="md-flex__cell md-flex__cell--shrink left-nav">
<ul class="md-tabs__list">
<li class="md-tabs__item"><a href="{{ base_url }}/changelog/" title="Release Notes" class="md-tabs__link md-source">Release Notes</a></li>
<li class="md-tabs__item"><a href="{{ base_url }}/CHANGELOG/" title="Release Notes" class="md-tabs__link md-source">Release Notes</a></li>
<li class="md-tabs__item"><a href="https://support.posit.co/hc/en-us" title="Posit Support" class="md-tabs__link md-source">Help</a></li>
</ul>
</div>
Expand Down
76 changes: 56 additions & 20 deletions rsconnect/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
TaskStatusV1,
UserRecord,
)
from .snowflake import generate_jwt, get_connection_parameters
from .snowflake import generate_jwt, get_parameters
from .timeouts import get_task_timeout, get_task_timeout_help_message

if TYPE_CHECKING:
Expand Down Expand Up @@ -260,40 +260,62 @@ def __init__(
self.bootstrap_jwt = None

def token_endpoint(self) -> str:
params = get_connection_parameters(self.snowflake_connection_name)
params = get_parameters(self.snowflake_connection_name)

if params is None:
raise RSConnectException("No Snowflake connection found.")

return "https://{}.snowflakecomputing.com/".format(params["account"])

def fmt_payload(self) -> str:
params = get_connection_parameters(self.snowflake_connection_name)
def fmt_payload(self):
params = get_parameters(self.snowflake_connection_name)

if params is None:
raise RSConnectException("No Snowflake connection found.")

spcs_url = urlparse(self.url)
scope = "session:role:{} {}".format(params["role"], spcs_url.netloc)
jwt = generate_jwt(self.snowflake_connection_name)
grant_type = "urn:ietf:params:oauth:grant-type:jwt-bearer"

payload = {"scope": scope, "assertion": jwt, "grant_type": grant_type}
payload = urlencode(payload)
return payload
authenticator = params.get("authenticator")
if authenticator == "SNOWFLAKE_JWT":
spcs_url = urlparse(self.url)
scope = (
"session:role:{} {}".format(params["role"], spcs_url.netloc) if params.get("role") else spcs_url.netloc
)
jwt = generate_jwt(self.snowflake_connection_name)
grant_type = "urn:ietf:params:oauth:grant-type:jwt-bearer"

payload = {"scope": scope, "assertion": jwt, "grant_type": grant_type}
payload = urlencode(payload)
return {
"body": payload,
"headers": {"Content-Type": "application/x-www-form-urlencoded"},
"path": "/oauth/token",
}
elif authenticator == "oauth":
payload = {
"data": {
"AUTHENTICATOR": "OAUTH",
"TOKEN": params["token"],
}
}
return {
"body": payload,
"headers": {
"Content-Type": "application/json",
"Authorization": "Bearer %s" % params["token"],
"X-Snowflake-Authorization-Token-Type": "OAUTH",
},
"path": "/session/v1/login-request",
}
else:
raise NotImplementedError("Unsupported authenticator for SPCS Connect: %s" % authenticator)

def exchange_token(self) -> str:
try:
server = HTTPServer(url=self.token_endpoint())
payload = self.fmt_payload()

response = server.request(
method="POST",
path="/oauth/token",
body=payload,
headers={"Content-Type": "application/x-www-form-urlencoded"},
method="POST", **payload # type: ignore[arg-type] # fmt_payload returns a dict with body and headers
)

response = cast(HTTPResponse, response)

# borrowed from AbstractRemoteServer.handle_bad_response
Expand All @@ -313,10 +335,24 @@ def exchange_token(self) -> str:
if not response.response_body:
raise RSConnectException("Token exchange returned empty response")

# Ensure we return a string
# Ensure response body is decoded to string on the object
if isinstance(response.response_body, bytes):
return response.response_body.decode("utf-8")
return response.response_body
response.response_body = response.response_body.decode("utf-8")

# Try to parse as JSON first
try:
import json

json_data = json.loads(response.response_body)
# If it's JSON, extract the token from data.token
if isinstance(json_data, dict) and "data" in json_data and "token" in json_data["data"]:
return json_data["data"]["token"]
else:
# JSON format doesn't match expected structure, return raw response
return response.response_body
except (json.JSONDecodeError, ValueError):
# Not JSON, return the raw response body
return response.response_body

except RSConnectException as e:
raise RSConnectException(f"Failed to exchange Snowflake token: {str(e)}") from e
Expand Down
44 changes: 30 additions & 14 deletions rsconnect/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,27 +40,43 @@ def list_connections() -> List[Dict[str, Any]]:
raise RSConnectException("Could not list snowflake connections.")


def get_connection_parameters(name: Optional[str] = None) -> Optional[Dict[str, Any]]:
def get_parameters(name: Optional[str] = None) -> Dict[str, Any]:
"""Get Snowflake connection parameters.
Args:
name: The name of the connection to retrieve. If None, returns the default connection.

Returns:
A dictionary of connection parameters.
"""
try:
from snowflake.connector.config_manager import CONFIG_MANAGER
except ImportError:
raise RSConnectException("snowflake-cli is not installed.")
try:
connections = CONFIG_MANAGER["connections"]
if not isinstance(connections, dict):
raise TypeError("connections is not a dictionary")

if name is None:
def_connection_name = CONFIG_MANAGER["default_connection_name"]
if not isinstance(def_connection_name, str):
raise TypeError("default_connection_name is not a string")
params = connections[def_connection_name]
else:
params = connections[name]

connection_list = list_connections()
# return parameters for default connection if configured
# otherwise return named connection
if not isinstance(params, dict):
raise TypeError("connection parameters is not a dictionary")

if not connection_list:
raise RSConnectException("No Snowflake connections found.")
return {str(k): v for k, v in params.items()}

try:
if not name:
return next((x["parameters"] for x in connection_list if x.get("is_default")), None)
else:
return next((x["parameters"] for x in connection_list if x.get("connection_name") == name))
except StopIteration:
raise RSConnectException(f"No Snowflake connection found with name '{name}'.")
except (KeyError, AttributeError) as e:
raise RSConnectException(f"Could not get Snowflake connection: {e}")


def generate_jwt(name: Optional[str] = None) -> str:

_ = get_connection_parameters(name)
_ = get_parameters(name)
connection_name = "" if name is None else name

try:
Expand Down
63 changes: 41 additions & 22 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,41 +526,48 @@ def test_token_endpoint(self, mock_token_endpoint):
endpoint = server.token_endpoint()
assert endpoint == "https://example.snowflakecomputing.com/"

@patch("rsconnect.api.get_connection_parameters")
def test_token_endpoint_with_account(self, mock_get_connection_parameters):
@patch("rsconnect.api.get_parameters")
def test_token_endpoint_with_account(self, mock_get_parameters):
server = SPCSConnectServer("https://spcs.example.com", "example_connection")
mock_get_connection_parameters.return_value = {"account": "test_account"}
mock_get_parameters.return_value = {"account": "test_account"}
endpoint = server.token_endpoint()
assert endpoint == "https://test_account.snowflakecomputing.com/"
mock_get_connection_parameters.assert_called_once_with("example_connection")
mock_get_parameters.assert_called_once_with("example_connection")

@patch("rsconnect.api.get_connection_parameters")
def test_token_endpoint_with_none_params(self, mock_get_connection_parameters):
@patch("rsconnect.api.get_parameters")
def test_token_endpoint_with_none_params(self, mock_get_parameters):
server = SPCSConnectServer("https://spcs.example.com", "example_connection")
mock_get_connection_parameters.return_value = None
mock_get_parameters.return_value = None
with pytest.raises(RSConnectException, match="No Snowflake connection found."):
server.token_endpoint()

@patch("rsconnect.api.get_connection_parameters")
def test_fmt_payload(self, mock_get_connection_parameters):
@patch("rsconnect.api.get_parameters")
def test_fmt_payload(self, mock_get_parameters):
server = SPCSConnectServer("https://spcs.example.com", "example_connection")
mock_get_connection_parameters.return_value = {"account": "test_account", "role": "test_role"}
mock_get_parameters.return_value = {
"account": "test_account",
"role": "test_role",
"authenticator": "SNOWFLAKE_JWT",
}

with patch("rsconnect.api.generate_jwt") as mock_generate_jwt:
mock_generate_jwt.return_value = "mocked_jwt"
payload = server.fmt_payload()

assert "scope=session%3Arole%3Atest_role+spcs.example.com" in payload
assert "assertion=mocked_jwt" in payload
assert "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer" in payload
assert (
payload["body"]
== "scope=session%3Arole%3Atest_role+spcs.example.com&assertion=mocked_jwt&grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer" # noqa
)
assert payload["headers"] == {"Content-Type": "application/x-www-form-urlencoded"}
assert payload["path"] == "/oauth/token"

mock_get_connection_parameters.assert_called_once_with("example_connection")
mock_get_parameters.assert_called_once_with("example_connection")
mock_generate_jwt.assert_called_once_with("example_connection")

@patch("rsconnect.api.get_connection_parameters")
def test_fmt_payload_with_none_params(self, mock_get_connection_parameters):
@patch("rsconnect.api.get_parameters")
def test_fmt_payload_with_none_params(self, mock_get_parameters):
server = SPCSConnectServer("https://spcs.example.com", "example_connection")
mock_get_connection_parameters.return_value = None
mock_get_parameters.return_value = None
with pytest.raises(RSConnectException, match="No Snowflake connection found."):
server.fmt_payload()

Expand All @@ -579,7 +586,11 @@ def test_exchange_token_success(self, mock_fmt_payload, mock_token_endpoint, moc

# Mock the token endpoint and payload
mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/"
mock_fmt_payload.return_value = "mocked_payload"
mock_fmt_payload.return_value = {
"body": "mocked_payload_body",
"headers": {"Content-Type": "application/x-www-form-urlencoded"},
"path": "/oauth/token",
}

# Call the method
result = server.exchange_token()
Expand All @@ -589,9 +600,9 @@ def test_exchange_token_success(self, mock_fmt_payload, mock_token_endpoint, moc
mock_http_server.assert_called_once_with(url="https://example.snowflakecomputing.com/")
mock_server_instance.request.assert_called_once_with(
method="POST",
path="/oauth/token",
body="mocked_payload",
body="mocked_payload_body",
headers={"Content-Type": "application/x-www-form-urlencoded"},
path="/oauth/token",
)

@patch("rsconnect.api.HTTPServer")
Expand All @@ -610,7 +621,11 @@ def test_exchange_token_error_status(self, mock_fmt_payload, mock_token_endpoint

# Mock the token endpoint and payload
mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/"
mock_fmt_payload.return_value = "mocked_payload"
mock_fmt_payload.return_value = {
"body": "mocked_payload_body",
"headers": {"Content-Type": "application/x-www-form-urlencoded"},
"path": "/oauth/token",
}

# Call the method and verify it raises the expected exception
with pytest.raises(RSConnectException, match="Failed to exchange Snowflake token"):
Expand All @@ -631,7 +646,11 @@ def test_exchange_token_empty_response(self, mock_fmt_payload, mock_token_endpoi

# Mock the token endpoint and payload
mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/"
mock_fmt_payload.return_value = "mocked_payload"
mock_fmt_payload.return_value = {
"body": "mocked_payload_body",
"headers": {"Content-Type": "application/x-www-form-urlencoded"},
"path": "/oauth/token",
}

# Call the method and verify it raises the expected exception
with pytest.raises(
Expand Down
Loading
Loading