diff --git a/sdk/core/azure-core/CHANGELOG.md b/sdk/core/azure-core/CHANGELOG.md index 3792824df348..1755ee3f5f35 100644 --- a/sdk/core/azure-core/CHANGELOG.md +++ b/sdk/core/azure-core/CHANGELOG.md @@ -1,7 +1,11 @@ # Release History -## 1.9.1 (Unreleased) +## 1.10.0 (Unreleased) + +### Features + +- Added `AzureSasCredential` and its respective policy. #15946 ## 1.9.0 (2020-11-09) diff --git a/sdk/core/azure-core/azure/core/_version.py b/sdk/core/azure-core/azure/core/_version.py index d2ee06e7dac8..c50ed6ccf7c1 100644 --- a/sdk/core/azure-core/azure/core/_version.py +++ b/sdk/core/azure-core/azure/core/_version.py @@ -9,4 +9,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "1.9.1" +VERSION = "1.10.0" diff --git a/sdk/core/azure-core/azure/core/credentials.py b/sdk/core/azure-core/azure/core/credentials.py index 6103db764cb2..21aeb433ac81 100644 --- a/sdk/core/azure-core/azure/core/credentials.py +++ b/sdk/core/azure-core/azure/core/credentials.py @@ -30,7 +30,7 @@ def get_token(self, *scopes, **kwargs): AccessToken = namedtuple("AccessToken", ["token", "expires_on"]) -__all__ = ["AzureKeyCredential", "AccessToken"] +__all__ = ["AzureKeyCredential", "AzureSasCredential", "AccessToken"] class AzureKeyCredential(object): @@ -71,3 +71,43 @@ def update(self, key): if not isinstance(key, six.string_types): raise TypeError("The key used for updating must be a string.") self._key = key + + +class AzureSasCredential(object): + """Credential type used for authenticating to an Azure service. + It provides the ability to update the shared access signature without creating a new client. + + :param str signature: The shared access signature used to authenticate to an Azure service + :raises: TypeError + """ + + def __init__(self, signature): + # type: (str) -> None + if not isinstance(signature, six.string_types): + raise TypeError("signature must be a string.") + self._signature = signature # type: str + + @property + def signature(self): + # type () -> str + """The value of the configured shared access signature. + + :rtype: str + """ + return self._signature + + def update(self, signature): + # type: (str) -> None + """Update the shared access signature. + + This can be used when you've regenerated your shared access signature and want + to update long-lived clients. + + :param str signature: The shared access signature used to authenticate to an Azure service + :raises: ValueError or TypeError + """ + if not signature: + raise ValueError("The signature used for updating can not be None or empty") + if not isinstance(signature, six.string_types): + raise TypeError("The signature used for updating must be a string.") + self._signature = signature diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py b/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py index 6b74f0bc5a1a..a0e81b13cef5 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/__init__.py @@ -25,7 +25,7 @@ # -------------------------------------------------------------------------- from ._base import HTTPPolicy, SansIOHTTPPolicy, RequestHistory -from ._authentication import BearerTokenCredentialPolicy, AzureKeyCredentialPolicy +from ._authentication import BearerTokenCredentialPolicy, AzureKeyCredentialPolicy, AzureSasCredentialPolicy from ._custom_hook import CustomHookPolicy from ._redirect import RedirectPolicy from ._retry import RetryPolicy, RetryMode @@ -45,6 +45,7 @@ 'SansIOHTTPPolicy', 'BearerTokenCredentialPolicy', 'AzureKeyCredentialPolicy', + 'AzureSasCredentialPolicy', 'HeadersPolicy', 'UserAgentPolicy', 'NetworkTraceLoggingPolicy', diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py index 1e2ba66b07c0..929920033cdf 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import from typing import Any, Dict, Optional - from azure.core.credentials import AccessToken, TokenCredential, AzureKeyCredential + from azure.core.credentials import AccessToken, TokenCredential, AzureKeyCredential, AzureSasCredential from azure.core.pipeline import PipelineRequest @@ -114,3 +114,34 @@ def __init__(self, credential, name, **kwargs): # pylint: disable=unused-argume def on_request(self, request): request.http_request.headers[self._name] = self._credential.key + + +class AzureSasCredentialPolicy(SansIOHTTPPolicy): + """Adds a shared access signature to query for the provided credential. + + :param credential: The credential used to authenticate requests. + :type credential: ~azure.core.credentials.AzureSasCredential + :raises: ValueError or TypeError + """ + def __init__(self, credential, **kwargs): # pylint: disable=unused-argument + # type: (AzureSasCredential, **Any) -> None + super(AzureSasCredentialPolicy, self).__init__() + if not credential: + raise ValueError("credential can not be None") + self._credential = credential + + def on_request(self, request): + url = request.http_request.url + query = request.http_request.query + signature = self._credential.signature + if signature.startswith("?"): + signature = signature[1:] + if query: + if signature not in url: + url = url + "&" + signature + else: + if url.endswith("?"): + url = url + signature + else: + url = url + "?" + signature + request.http_request.url = url diff --git a/sdk/core/azure-core/tests/test_authentication.py b/sdk/core/azure-core/tests/test_authentication.py index fb56819b6451..ddfcc314fcd9 100644 --- a/sdk/core/azure-core/tests/test_authentication.py +++ b/sdk/core/azure-core/tests/test_authentication.py @@ -6,10 +6,10 @@ import time import azure.core -from azure.core.credentials import AccessToken, AzureKeyCredential +from azure.core.credentials import AccessToken, AzureKeyCredential, AzureSasCredential from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import Pipeline -from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy, AzureKeyCredentialPolicy +from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy, AzureKeyCredentialPolicy, AzureSasCredentialPolicy from azure.core.pipeline.transport import HttpRequest import pytest @@ -190,3 +190,43 @@ def test_azure_key_credential_updates(): api_key = "new" credential.update(api_key) assert credential.key == api_key + +@pytest.mark.parametrize("sas,url,expected_url", [ + ("sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"), + ("?sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"), + ("sig=test_signature", "https://test_sas_credential?sig=test_signature", "https://test_sas_credential?sig=test_signature"), + ("?sig=test_signature", "https://test_sas_credential?sig=test_signature", "https://test_sas_credential?sig=test_signature"), + ("sig=test_signature", "https://test_sas_credential?", "https://test_sas_credential?sig=test_signature"), + ("?sig=test_signature", "https://test_sas_credential?", "https://test_sas_credential?sig=test_signature"), + ("sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"), + ("?sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"), +]) +def test_azure_sas_credential_policy(sas, url, expected_url): + """Tests to see if we can create an AzureSasCredentialPolicy""" + + def verify_authorization(request): + assert request.url == expected_url + + transport=Mock(send=verify_authorization) + credential = AzureSasCredential(sas) + credential_policy = AzureSasCredentialPolicy(credential=credential) + pipeline = Pipeline(transport=transport, policies=[credential_policy]) + + pipeline.run(HttpRequest("GET", url)) + +def test_azure_sas_credential_updates(): + """Tests AzureSasCredential updates""" + sas = "original" + + credential = AzureSasCredential(sas) + assert credential.signature == sas + + sas = "new" + credential.update(sas) + assert credential.signature == sas + +def test_azure_sas_credential_policy_raises(): + """Tests AzureSasCredential and AzureSasCredentialPolicy raises with non-string input parameters.""" + sas = 1234 + with pytest.raises(TypeError): + credential = AzureSasCredential(sas)