Skip to content

Commit

Permalink
Small refactor of AWS Signer classes for both sync and async clients (#…
Browse files Browse the repository at this point in the history
…866)

* made custom headers be available to async aws signer

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* updated changelog

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* added tests for using host header for AWS request signature on both sync and async clients

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* added documentation guide about aws auth when accessing via tunnel

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* small refactor of AWS Signer classes on sync and async clients; improved testing on them as well

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* changelog

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* fixed test

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

* lint fix

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>

---------

Signed-off-by: Bruno Murino <brunomurino@users.noreply.github.com>
  • Loading branch information
brunomurino authored Dec 3, 2024
1 parent 87aebcd commit 7815c6a
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 124 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Added sync and async sample that uses `search_after` parameter ([859](https://github.com/opensearch-project/opensearch-py/pull/859))
### Updated APIs
### Changed
- Small refactor of AWS Signer classes for both sync and async clients ([866](https://github.com/opensearch-project/opensearch-py/pull/866))
### Deprecated
### Removed
### Fixed
Expand Down
77 changes: 9 additions & 68 deletions opensearchpy/helpers/asyncsigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
# GitHub history for details.

from typing import Any, Dict, Optional, Union
from urllib.parse import parse_qs, urlencode, urlparse

from opensearchpy.helpers.signer import AWSV4Signer


class AWSV4SignerAsyncAuth:
Expand All @@ -17,33 +18,21 @@ class AWSV4SignerAsyncAuth:
"""

def __init__(self, credentials: Any, region: str, service: str = "es") -> None:
if not credentials:
raise ValueError("Credentials cannot be empty")
self.credentials = credentials

if not region:
raise ValueError("Region cannot be empty")
self.region = region

if not service:
raise ValueError("Service name cannot be empty")
self.service = service
self.signer = AWSV4Signer(credentials, region, service)

def __call__(
self,
method: str,
url: str,
query_string: Optional[str] = None,
body: Optional[Union[str, bytes]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
return self._sign_request(method, url, query_string, body, headers)
return self._sign_request(method=method, url=url, body=body, headers=headers)

def _sign_request(
self,
method: str,
url: str,
query_string: Optional[str],
body: Optional[Union[str, bytes]],
headers: Optional[Dict[str, str]],
) -> Dict[str, str]:
Expand All @@ -53,58 +42,10 @@ def _sign_request(
:return: signed headers
"""

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

signature_host = self._fetch_url(url, headers or dict())

# create an AWS request object and sign it using SigV4Auth
aws_request = AWSRequest(
updated_headers = self.signer.sign(
method=method,
url=signature_host,
data=body,
)

# credentials objects expose access_key, secret_key and token attributes
# via @property annotations that call _refresh() on every access,
# creating a race condition if the credentials expire before secret_key
# is called but after access_key- the end result is the access_key doesn't
# correspond to the secret_key used to sign the request. To avoid this,
# get_frozen_credentials() which returns non-refreshing credentials is
# called if it exists.
credentials = (
self.credentials.get_frozen_credentials()
if hasattr(self.credentials, "get_frozen_credentials")
and callable(self.credentials.get_frozen_credentials)
else self.credentials
url=url,
body=body,
headers=headers,
)

sig_v4_auth = SigV4Auth(credentials, self.service, self.region)
sig_v4_auth.add_auth(aws_request)
aws_request.headers["X-Amz-Content-SHA256"] = sig_v4_auth.payload(aws_request)

# copy the headers from AWS request object into the prepared_request
return dict(aws_request.headers.items())

def _fetch_url(self, url: str, headers: Optional[Dict[str, str]]) -> str:
"""
This is a util method that helps in reconstructing the request url.
:param prepared_request: unsigned request
:return: reconstructed url
"""
parsed_url = urlparse(url)
path = parsed_url.path or "/"

# fetch the query string if present in the request
querystring = ""
if parsed_url.query:
querystring = "?" + urlencode(
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
)

# fetch the host information from headers
headers = {key.lower(): value for key, value in (headers or dict()).items()}
location = headers.get("host") or parsed_url.netloc

# construct the url and return
return parsed_url.scheme + "://" + location + path + querystring
return updated_headers
78 changes: 42 additions & 36 deletions opensearchpy/helpers/signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.

from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional
from urllib.parse import parse_qs, urlencode, urlparse

import requests
Expand All @@ -31,7 +31,9 @@ def __init__(self, credentials, region: str, service: str = "es") -> Any: # typ
raise ValueError("Service name cannot be empty")
self.service = service

def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:
def sign(
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
"""
This method signs the request and returns headers.
:param method: HTTP method
Expand All @@ -43,8 +45,10 @@ def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

signature_host = self._fetch_url(url, headers or dict())

# create an AWS request object and sign it using SigV4Auth
aws_request = AWSRequest(method=method.upper(), url=url, data=body)
aws_request = AWSRequest(method=method.upper(), url=signature_host, data=body)

# credentials objects expose access_key, secret_key and token attributes
# via @property annotations that call _refresh() on every access,
Expand All @@ -69,6 +73,30 @@ def sign(self, method: str, url: str, body: Any) -> Dict[str, str]:

return headers

@staticmethod
def _fetch_url(url: str, headers: Optional[Dict[str, str]]) -> str:
"""
This is a util method that helps in reconstructing the request url.
:param prepared_request: unsigned request
:return: reconstructed url
"""
parsed_url = urlparse(url)
path = parsed_url.path or "/"

# fetch the query string if present in the request
querystring = ""
if parsed_url.query:
querystring = "?" + urlencode(
parse_qs(parsed_url.query, keep_blank_values=True), doseq=True
)

# fetch the host information from headers
headers = {key.lower(): value for key, value in (headers or dict()).items()}
location = headers.get("host") or parsed_url.netloc

# construct the url and return
return parsed_url.scheme + "://" + location + path + querystring


class RequestsAWSV4SignerAuth(requests.auth.AuthBase):
"""
Expand All @@ -89,40 +117,16 @@ def _sign_request(self, prepared_request): # type: ignore
:return: signed request
"""

prepared_request.headers.update(
self.signer.sign(
prepared_request.method,
self._fetch_url(prepared_request),
prepared_request.body,
)
updated_headers = self.signer.sign(
method=prepared_request.method,
url=prepared_request.url,
body=prepared_request.body,
headers=prepared_request.headers,
)

return prepared_request

def _fetch_url(self, prepared_request: requests.PreparedRequest) -> str:
"""
This is a util method that helps in reconstructing the request url.
:param prepared_request: unsigned request
:return: reconstructed url
"""
url = urlparse(prepared_request.url)
path = url.path or "/"

# fetch the query string if present in the request
querystring = ""
if url.query:
querystring = "?" + urlencode(
parse_qs(url.query, keep_blank_values=True), doseq=True # type: ignore
)
prepared_request.headers.update(updated_headers)

# fetch the host information from headers
headers = {
key.lower(): value for key, value in prepared_request.headers.items()
}
location = headers.get("host") or url.netloc

# construct the url and return
return url.scheme + "://" + location + path + querystring # type: ignore
return prepared_request


# Deprecated: use RequestsAWSV4SignerAuth
Expand All @@ -135,5 +139,7 @@ def __init__(self, credentials, region, service: str = "es") -> None: # type: i
self.signer = AWSV4Signer(credentials, region, service)
self.service = service # tools like LangChain rely on this, see https://github.com/opensearch-project/opensearch-py/issues/600

def __call__(self, method: str, url: str, body: Any) -> Dict[str, str]:
return self.signer.sign(method, url, body)
def __call__(
self, method: str, url: str, body: Any, headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
return self.signer.sign(method, url, body, headers)
20 changes: 11 additions & 9 deletions test_opensearchpy/test_async/test_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import uuid
from typing import Any, Collection, Dict, Mapping, Optional, Tuple, Union
from unittest.mock import Mock
from unittest.mock import Mock, patch

import pytest
from _pytest.mark.structures import MarkDecorator
Expand Down Expand Up @@ -81,15 +81,18 @@ async def test_aws_signer_async_fetch_url_with_querystring(self) -> None:
region = "us-west-2"
service = "aoss"

from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth

auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
from botocore.awsrequest import AWSRequest

signature_host = auth._fetch_url(
"http://localhost/?foo=bar", headers={"host": "otherhost"}
)
from opensearchpy.helpers.asyncsigner import AWSV4SignerAsyncAuth

assert signature_host == "http://otherhost/?foo=bar"
with patch(
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
) as mock_aws_request:
auth = AWSV4SignerAsyncAuth(self.mock_session(), region, service)
auth("GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"})
mock_aws_request.assert_called_with(
method="GET", url="http://otherhost:443/?foo=bar", data=None
)


class TestAsyncSignerWithFrozenCredentials(TestAsyncSigner):
Expand Down Expand Up @@ -155,7 +158,6 @@ def _sign_request(
self,
method: str,
url: str,
query_string: Optional[str] = None,
body: Optional[Union[str, bytes]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
Expand Down
29 changes: 18 additions & 11 deletions test_opensearchpy/test_connection/test_requests_http_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,22 +457,27 @@ def mock_session(self) -> Any:

return dummy_session

def test_aws_signer_fetch_url_with_querystring(self) -> None:
def test_aws_signer_url_with_querystring_and_custom_header(self) -> None:
region = "us-west-2"

import requests
from botocore.awsrequest import AWSRequest

from opensearchpy.helpers.signer import RequestsAWSV4SignerAuth

auth = RequestsAWSV4SignerAuth(self.mock_session(), region)

prepared_request = requests.Request(
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
).prepare()
with patch(
"botocore.awsrequest.AWSRequest", side_effect=AWSRequest
) as mock_aws_request:

signature_host = auth._fetch_url(prepared_request)
auth = RequestsAWSV4SignerAuth(self.mock_session(), region)
prepared_request = requests.Request(
"GET", "http://localhost/?foo=bar", headers={"host": "otherhost:443"}
).prepare()
auth(prepared_request)

assert signature_host == "http://otherhost:443/?foo=bar"
mock_aws_request.assert_called_with(
method="GET", url="http://otherhost:443/?foo=bar", data=None
)

def test_aws_signer_as_http_auth(self) -> None:
region = "us-west-2"
Expand Down Expand Up @@ -525,9 +530,11 @@ def test_aws_signer_signs_with_query_string(self, mock_sign: Any) -> None:
).prepare()
auth(prepared_request)
self.assertEqual(mock_sign.call_count, 1)
self.assertEqual(
mock_sign.call_args[0],
("GET", "http://localhost/?key1=value1&key2=value2", None),
mock_sign.assert_called_with(
method="GET",
url="http://localhost/?key1=value1&key2=value2",
body=None,
headers={},
)

def test_aws_signer_consitent_url(self) -> None:
Expand Down

0 comments on commit 7815c6a

Please sign in to comment.