Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 120 additions & 7 deletions src/posit/connect/external/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,52 @@

import base64
import json
from datetime import datetime

from typing_extensions import TYPE_CHECKING, Dict
from typing_extensions import TYPE_CHECKING, Optional, TypedDict

from ..oauth.oauth import OAuthTokenType

if TYPE_CHECKING:
from ..client import Client


def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, str]:
class Credentials(TypedDict):
aws_access_key_id: str
aws_secret_access_key: str
aws_session_token: str
expiration: datetime


def get_credentials(client: Client, user_session_token: str) -> Credentials:
"""
Get AWS credentials using OAuth token exchange.
Get AWS credentials using OAuth token exchange for an AWS Viewer integration.

According to RFC 8693, the access token must be a base64 encoded JSON object
containing the AWS credentials. This function will decode and deserialize the
access token and return the AWS credentials.
According to RFC 8693, the access token must be a base64-encoded JSON object
containing the AWS credentials. This function will return the decoded and
deserialized AWS credentials.

Examples
--------
```python
from posit.connect import Client
from posit.connect.external.aws import get_aws_credentials
import boto3
from shiny.express import session

client = Client()
session_token = session.http_conn.headers.get("Posit-Connect-User-Session-Token")
credentials = get_aws_credentials(client, user_session_token)
aws_session_expiration = credentials["expiration"]
aws_session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
aws_session_token=credentials["aws_session_token"],
)

s3 = aws_session.resource("s3")
bucket = s3.Bucket("your-bucket-name")
```

Parameters
----------
Expand All @@ -42,8 +72,91 @@ def get_aws_credentials(client: Client, user_session_token: str) -> Dict[str, st
access_token = credentials.get("access_token")
if not access_token:
raise ValueError("No access token found in credentials")
return _decode_access_token(access_token)


def get_content_credentials(
client: Client, content_session_token: Optional[str] = None
) -> Credentials:
"""
Get AWS credentials using OAuth token exchange for an AWS Service Account integration.

According to RFC 8693, the access token must be a base64-encoded JSON object
containing the AWS credentials. This function will return the decoded and
deserialized AWS credentials.

Examples
--------
```python
from posit.connect import Client
from posit.connect.external.aws import get_aws_content_credentials
import boto3

client = Client()
credentials = get_aws_content_credentials(client)
session_expiration = credentials["expiration"]
aws_session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
aws_session_token=credentials["aws_session_token"],
)

s3 = session.resource("s3")
bucket = s3.Bucket("your-bucket-name")
```

Parameters
----------
client : Client
The client to use for making requests
content_session_token : str
The content session token to exchange

Returns
-------
Dict[str, str]
Dictionary containing AWS credentials with keys:
access_key_id, secret_access_key, session_token, and expiration
"""
# Get credentials using OAuth
credentials = client.oauth.get_content_credentials(
content_session_token=content_session_token,
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
)

# Decode base64 access token
access_token = credentials.get("access_token")
if not access_token:
raise ValueError("No access token found in credentials")
return _decode_access_token(access_token)


def _decode_access_token(access_token: str) -> Credentials:
"""
Decode and deserialize an access token containing AWS credentials.

According to RFC 8693, the access token must be a base64-encoded JSON object
containing the AWS credentials. This function will decode and deserialize the
access token and return the AWS credentials.

Parameters
----------
access_token : str
The access token to decode

Returns
-------
Credentials
Dictionary containing AWS credentials with keys:
access_key_id, secret_access_key, session_token, and expiration
"""
decoded_bytes = base64.b64decode(access_token)
decoded_str = decoded_bytes.decode("utf-8")
aws_credentials = json.loads(decoded_str)

return aws_credentials
return Credentials(
aws_access_key_id=aws_credentials["accessKeyId"],
aws_secret_access_key=aws_credentials["secretAccessKey"],
aws_session_token=aws_credentials["sessionToken"],
expiration=datetime.strptime(aws_credentials["expiration"], "%Y-%m-%dT%H:%M:%SZ"),
)
8 changes: 7 additions & 1 deletion src/posit/connect/oauth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,19 @@ def get_credentials(
response = self._ctx.client.post(self._path, data=data)
return Credentials(**response.json())

def get_content_credentials(self, content_session_token: Optional[str] = None) -> Credentials:
def get_content_credentials(
self,
content_session_token: Optional[str] = None,
requested_token_type: Optional[str | OAuthTokenType] = None,
) -> Credentials:
"""Perform an oauth credential exchange with a content-session-token."""
# craft a credential exchange request
data = {}
data["grant_type"] = GRANT_TYPE
data["subject_token_type"] = OAuthTokenType.CONTENT_SESSION_TOKEN
data["subject_token"] = content_session_token or _get_content_session_token()
if requested_token_type:
data["requested_token_type"] = requested_token_type

response = self._ctx.client.post(self._path, data=data)
return Credentials(**response.json())
Expand Down
81 changes: 72 additions & 9 deletions tests/posit/connect/external/test_aws.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from datetime import datetime

import pytest
import responses

from posit.connect import Client
from posit.connect.external.aws import get_aws_credentials
from posit.connect.external.aws import (
Credentials,
_decode_access_token,
get_content_credentials,
get_credentials,
)

aws_creds = {
"accessKeyId": "abc123",
"secretAccessKey": "def456",
"sessionToken": "ghi789",
"expiration": "2025-01-01T00:00:00Z",
}
aws_creds = Credentials(
aws_access_key_id="abc123",
aws_secret_access_key="def456",
aws_session_token="ghi789",
expiration=datetime(2025, 1, 1, 0, 0, 0, 0),
)

encoded_aws_creds = "eyJhY2Nlc3NLZXlJZCI6ICJhYmMxMjMiLCAic2VjcmV0QWNjZXNzS2V5IjogImRlZjQ1NiIsICJzZXNzaW9uVG9rZW4iOiAiZ2hpNzg5IiwgImV4cGlyYXRpb24iOiAiMjAyNS0wMS0wMVQwMDowMDowMFoifQ=="

Expand Down Expand Up @@ -38,7 +45,7 @@ def test_get_aws_credentials(self):

c = Client(api_key="12345", url="https://connect.example/")
c._ctx.version = None
response = get_aws_credentials(c, "cit")
response = get_credentials(c, "cit")

assert response == aws_creds

Expand All @@ -63,6 +70,62 @@ def test_get_aws_credentials_no_token(self):
c._ctx.version = None

with pytest.raises(ValueError) as e:
get_aws_credentials(c, "cit")
get_credentials(c, "cit")

assert e.match("No access token found in credentials")

@responses.activate
def test_get_aws_content_credentials(self):
responses.post(
"https://connect.example/__api__/v1/oauth/integrations/credentials",
match=[
responses.matchers.urlencoded_params_matcher(
{
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token_type": "urn:posit:connect:content-session-token",
"subject_token": "cit",
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
}
)
],
json={
"access_token": encoded_aws_creds,
"issued_token_type": "urn:ietf:params:aws:token-type:credentials",
"token_type": "aws_credentials",
},
)

c = Client(api_key="12345", url="https://connect.example/")
c._ctx.version = None
response = get_content_credentials(c, "cit")

assert response == aws_creds

@responses.activate
def test_get_aws_content_credentials_no_token(self):
responses.post(
"https://connect.example/__api__/v1/oauth/integrations/credentials",
match=[
responses.matchers.urlencoded_params_matcher(
{
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token_type": "urn:posit:connect:content-session-token",
"subject_token": "cit",
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
}
)
],
json={},
)

c = Client(api_key="12345", url="https://connect.example/")
c._ctx.version = None

with pytest.raises(ValueError) as e:
get_content_credentials(c, "cit")

assert e.match("No access token found in credentials")

def test_decode_access_token(self):
decoded_creds = _decode_access_token(encoded_aws_creds)
assert decoded_creds == aws_creds
36 changes: 19 additions & 17 deletions tests/posit/connect/oauth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_get_credentials(self):
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token_type": "urn:posit:connect:user-session-token",
"subject_token": "cit",
# no requested token type set
},
),
],
Expand All @@ -41,7 +42,7 @@ def test_get_credentials(self):
assert creds.get("access_token") == "viewer-token"

@responses.activate
def test_get_credentials_api_key(self):
def test_get_credentials_with_requested_token_type(self):
responses.post(
"https://connect.example/__api__/v1/oauth/integrations/credentials",
match=[
Expand All @@ -68,34 +69,32 @@ def test_get_credentials_api_key(self):
assert creds.get("token_type") == "Key"

@responses.activate
def test_get_credentials_aws(self):
def test_get_content_credentials(self):
responses.post(
"https://connect.example/__api__/v1/oauth/integrations/credentials",
match=[
responses.matchers.urlencoded_params_matcher(
{
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token_type": "urn:posit:connect:user-session-token",
"subject_token_type": "urn:posit:connect:content-session-token",
"subject_token": "cit",
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
# no requested token type set
},
),
],
json={
"access_token": "encoded-aws-creds",
"issued_token_type": "urn:ietf:params:aws:token-type:credentials",
"token_type": "aws_credentials",
"access_token": "content-token",
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
"token_type": "Bearer",
},
)
c = Client(api_key="12345", url="https://connect.example/")
c._ctx.version = None
creds = c.oauth.get_credentials("cit", OAuthTokenType.AWS_CREDENTIALS)
assert creds.get("access_token") == "encoded-aws-creds"
assert creds.get("issued_token_type") == "urn:ietf:params:aws:token-type:credentials"
assert creds.get("token_type") == "aws_credentials"
creds = c.oauth.get_content_credentials("cit")
assert creds.get("access_token") == "content-token"

@responses.activate
def test_get_content_credentials(self):
def test_get_content_credentials_with_requested_token_type(self):
responses.post(
"https://connect.example/__api__/v1/oauth/integrations/credentials",
match=[
Expand All @@ -104,19 +103,22 @@ def test_get_content_credentials(self):
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token_type": "urn:posit:connect:content-session-token",
"subject_token": "cit",
"requested_token_type": "urn:ietf:params:aws:token-type:credentials",
},
),
],
json={
"access_token": "content-token",
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
"token_type": "Bearer",
"access_token": "encoded-aws-creds",
"issued_token_type": "urn:ietf:params:aws:token-type:credentials",
"token_type": "aws_credentials",
},
)
c = Client(api_key="12345", url="https://connect.example/")
c._ctx.version = None
creds = c.oauth.get_content_credentials("cit")
assert creds.get("access_token") == "content-token"
creds = c.oauth.get_content_credentials("cit", OAuthTokenType.AWS_CREDENTIALS)
assert creds.get("access_token") == "encoded-aws-creds"
assert creds.get("issued_token_type") == "urn:ietf:params:aws:token-type:credentials"
assert creds.get("token_type") == "aws_credentials"

@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
@responses.activate
Expand Down