Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/posit/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

from requests import Response, Session
from typing_extensions import TYPE_CHECKING, overload

from . import hooks, me
Expand All @@ -14,13 +13,16 @@
from .metrics.metrics import Metrics
from .oauth.oauth import API_KEY_TOKEN_TYPE, OAuth
from .resources import _PaginatedResourceSequence, _ResourceSequence
from .sessions import Session
from .system import System
from .tags import Tags
from .tasks import Tasks
from .users import User, Users
from .vanities import Vanities

if TYPE_CHECKING:
from requests import Response

from .environments import Environments
from .packages import Packages

Expand Down Expand Up @@ -208,6 +210,7 @@ def with_user_session_token(self, token: str) -> Client:
--------
```python
from posit.connect import Client

client = Client().with_user_session_token("my-user-session-token")
```

Expand All @@ -218,13 +221,14 @@ def with_user_session_token(self, token: str) -> Client:

client = Client()


@reactive.calc
def visitor_client():
## read the user session token and generate a new client
user_session_token = session.http_conn.headers.get(
"Posit-Connect-User-Session-Token"
)
user_session_token = session.http_conn.headers.get("Posit-Connect-User-Session-Token")
return client.with_user_session_token(user_session_token)


@render.text
def user_profile():
# fetch the viewer's profile information
Expand Down
103 changes: 103 additions & 0 deletions src/posit/connect/sessions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from urllib.parse import urljoin

import requests


class Session(requests.Session):
"""Custom session that implements CURLOPT_POSTREDIR.

This class mimics the functionality of CURLOPT_POSTREDIR from libcurl by
providing a custom implementation of the POST method. It allows the caller
to control whether the original POST data is preserved on redirects or if the
request should be converted to a GET when a redirect occurs. This is achieved
by disabling automatic redirect handling and manually following the redirect
chain with the desired behavior.

Notes
-----
The custom `post` method in this class:

- Disables automatic redirect handling by setting ``allow_redirects=False``.
- Manually follows redirects up to a specified ``max_redirects``.
- Determines the HTTP method for subsequent requests based on the response
status code and the ``preserve_post`` flag:

- For HTTP status codes 307 and 308, the method and request body are
always preserved as POST.
- For other redirects (e.g., 301, 302, 303), the behavior is determined
by ``preserve_post``:
- If ``preserve_post=True``, the POST method is maintained.
- If ``preserve_post=False``, the method is converted to GET and the
request body is discarded.

Examples
--------
Create a session and send a POST request while preserving POST data on redirects:

>>> session = Session()
>>> response = session.post(
... "https://example.com/api", data={"key": "value"}, preserve_post=True
... )
>>> print(response.status_code)

See Also
--------
requests.Session : The base session class from the requests library.
"""

def post(self, url, data=None, json=None, preserve_post=True, max_redirects=5, **kwargs):
"""
Send a POST request and handle redirects manually.

Parameters
----------
url : str
The URL to send the POST request to.
data : dict, bytes, or file-like object, optional
The form data to send.
json : any, optional
The JSON data to send.
preserve_post : bool, optional
If True, re-send POST data on redirects (mimicking CURLOPT_POSTREDIR);
if False, converts to GET on 301/302/303 responses.
max_redirects : int, optional
Maximum number of redirects to follow.
**kwargs
Additional keyword arguments passed to the request.

Returns
-------
requests.Response
The final response after following redirects.
"""
# Force manual redirect handling by disabling auto redirects.
kwargs["allow_redirects"] = False

# Initial POST request
response = super().post(url, data=data, json=json, **kwargs)
redirect_count = 0

# Manually follow redirects, if any
while response.is_redirect and redirect_count < max_redirects:
redirect_url = response.headers.get("location")
if not redirect_url:
break # No redirect URL; exit loop

redirect_url = urljoin(response.url, redirect_url)

# For 307 and 308 the HTTP spec mandates preserving the method and body.
if response.status_code in (307, 308):
method = "POST"
else:
if preserve_post:
method = "POST"
else:
method = "GET"
data = None
json = None

# Perform the next request in the redirect chain.
response = self.request(method, redirect_url, data=data, json=json, **kwargs)
redirect_count += 1

return response
134 changes: 134 additions & 0 deletions tests/posit/connect/test_sessions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import pytest
import responses

from posit.connect.sessions import Session


@responses.activate
def test_post_no_redirect():
url = "https://connect.example.com/api"
responses.add(responses.POST, url, json={"result": "ok"}, status=200)

session = Session()
response = session.post(url, data={"key": "value"})

assert response.status_code == 200
assert len(responses.calls) == 1
# Confirm that the request method was POST.
assert responses.calls[0].request.method == "POST"


@responses.activate
def test_post_with_redirect_preserve():
initial_url = "http://connect.example.com/api"
redirect_url = "http://connect.example.com/redirect"

responses.add(responses.POST, initial_url, status=302, headers={"location": "/redirect"})
responses.add(responses.POST, redirect_url, json={"result": "redirected"}, status=200)

session = Session()
response = session.post(initial_url, data={"key": "value"}, preserve_post=True)

assert response.status_code == 200
assert len(responses.calls) == 2

# Both calls should use the POST method.
assert responses.calls[0].request.method == "POST"
assert responses.calls[1].request.method == "POST"


@responses.activate
def test_post_with_redirect_no_preserve():
initial_url = "http://connect.example.com/api"
redirect_url = "http://connect.example.com/redirect"

responses.add(responses.POST, initial_url, status=302, headers={"location": "/redirect"})
responses.add(responses.GET, redirect_url, json={"result": "redirected"}, status=200)

session = Session()
response = session.post(initial_url, data={"key": "value"}, preserve_post=False)

assert response.status_code == 200
assert len(responses.calls) == 2
# The initial call is a POST, but the follow-up should be a GET since preserve_post is False
assert responses.calls[0].request.method == "POST"
assert responses.calls[1].request.method == "GET"


@pytest.mark.parametrize("status_code", [307, 308])
@responses.activate
def test_post_redirect_307_308(status_code):
initial_url = "http://connect.example.com/api"
redirect_url = "http://connect.example.com/redirect"

# For 307 and 308 redirects, the HTTP spec mandates preserving the method.
responses.add(
responses.POST, initial_url, status=status_code, headers={"location": "/redirect"}
)
responses.add(responses.POST, redirect_url, json={"result": "redirected"}, status=200)

session = Session()
# Even with preserve_post=False, a 307 or 308 redirect should use POST.
response = session.post(initial_url, data={"key": "value"}, preserve_post=False)

assert response.status_code == 200
assert len(responses.calls) == 2
# Confirm that the method for the redirect is still POST.
assert responses.calls[1].request.method == "POST"


@responses.activate
def test_post_redirect_max_redirects():
initial_url = "http://connect.example.com/api"
redirect1_url = "http://connect.example.com/redirect1"
redirect2_url = "http://connect.example.com/redirect2"

# Build a chain of 3 redirects.
responses.add(responses.POST, initial_url, status=302, headers={"location": "/redirect1"})
responses.add(responses.POST, redirect1_url, status=302, headers={"location": "/redirect2"})
responses.add(responses.POST, redirect2_url, status=302, headers={"location": "/redirect3"})

session = Session()
# Limit to 2 redirects; thus, the third redirect response should not be followed.
response = session.post(
initial_url, data={"key": "value"}, max_redirects=2, preserve_post=True
)

# The calls should include: initial, first redirect, and second redirect.
assert len(responses.calls) == 3
# The final response is the one from the second redirect.
assert response.status_code == 302
# The Location header should point to the third URL.
assert response.headers.get("location") == "/redirect3"


@responses.activate
def test_post_redirect_no_location():
url = "http://connect.example.com/api"
# Simulate a redirect response that lacks a Location header.
responses.add(responses.POST, url, status=302, headers={})

session = Session()
response = session.post(url, data={"key": "value"})

# The loop should break immediately since there is no location to follow.
assert len(responses.calls) == 1
assert response.status_code == 302


@responses.activate
def test_post_redirect_location_none_explicit():
url = "http://connect.example.com/api"

# Use a callback to explicitly return a None for the "location" header.
def request_callback(request):
return (302, {"location": ""}, "Redirect without location")

responses.add_callback(responses.POST, url, callback=request_callback)

session = Session()
response = session.post(url, data={"key": "value"})

# The redirect loop should break since location is None.
assert len(responses.calls) == 1
assert response.status_code == 302