Skip to content

Commit

Permalink
Add AzureSasCredential (Azure#15946)
Browse files Browse the repository at this point in the history
**In this PR:**
- Add `AzureSasCredential` per Azure/azure-sdk#1954
- `AzureSasCredential` is the name that has been settled on the end of discussion.
- Add `AzureSasCredentialPolicy` that appends SAS to query

**Remarks:**
- Some service (like storage in the Portal) present SAS with leading "?". This has to be stripped before appending
- The validation if serviceUri already contain sas (mentioned [here](Azure/azure-sdk#1954 (comment))) will be responsibility of service clients:
    - the format varies between services (i.e. Event Grid SAS and Storage SAS are vastly different)
    - it would be good to fail fast (at client creation) rather than late (at request send).

**References**
- [.NET PR](Azure/azure-sdk-for-net#17636)
  • Loading branch information
kasobol-msft authored and rakshith91 committed Jan 8, 2021
1 parent 0903f7f commit 673d518
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 7 deletions.
6 changes: 5 additions & 1 deletion sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@

# Release History

## 1.9.1 (Unreleased)
## 1.10.0 (Unreleased)

### Features

- Added `AzureSasCredential` and its respective policy. #15946


## 1.9.0 (2020-11-09)
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# regenerated.
# --------------------------------------------------------------------------

VERSION = "1.9.1"
VERSION = "1.10.0"
42 changes: 41 additions & 1 deletion sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_token(self, *scopes, **kwargs):

AccessToken = namedtuple("AccessToken", ["token", "expires_on"])

__all__ = ["AzureKeyCredential", "AccessToken"]
__all__ = ["AzureKeyCredential", "AzureSasCredential", "AccessToken"]


class AzureKeyCredential(object):
Expand Down Expand Up @@ -71,3 +71,43 @@ def update(self, key):
if not isinstance(key, six.string_types):
raise TypeError("The key used for updating must be a string.")
self._key = key


class AzureSasCredential(object):
"""Credential type used for authenticating to an Azure service.
It provides the ability to update the shared access signature without creating a new client.
:param str signature: The shared access signature used to authenticate to an Azure service
:raises: TypeError
"""

def __init__(self, signature):
# type: (str) -> None
if not isinstance(signature, six.string_types):
raise TypeError("signature must be a string.")
self._signature = signature # type: str

@property
def signature(self):
# type () -> str
"""The value of the configured shared access signature.
:rtype: str
"""
return self._signature

def update(self, signature):
# type: (str) -> None
"""Update the shared access signature.
This can be used when you've regenerated your shared access signature and want
to update long-lived clients.
:param str signature: The shared access signature used to authenticate to an Azure service
:raises: ValueError or TypeError
"""
if not signature:
raise ValueError("The signature used for updating can not be None or empty")
if not isinstance(signature, six.string_types):
raise TypeError("The signature used for updating must be a string.")
self._signature = signature
3 changes: 2 additions & 1 deletion sdk/core/azure-core/azure/core/pipeline/policies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# --------------------------------------------------------------------------

from ._base import HTTPPolicy, SansIOHTTPPolicy, RequestHistory
from ._authentication import BearerTokenCredentialPolicy, AzureKeyCredentialPolicy
from ._authentication import BearerTokenCredentialPolicy, AzureKeyCredentialPolicy, AzureSasCredentialPolicy
from ._custom_hook import CustomHookPolicy
from ._redirect import RedirectPolicy
from ._retry import RetryPolicy, RetryMode
Expand All @@ -45,6 +45,7 @@
'SansIOHTTPPolicy',
'BearerTokenCredentialPolicy',
'AzureKeyCredentialPolicy',
'AzureSasCredentialPolicy',
'HeadersPolicy',
'UserAgentPolicy',
'NetworkTraceLoggingPolicy',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Dict, Optional
from azure.core.credentials import AccessToken, TokenCredential, AzureKeyCredential
from azure.core.credentials import AccessToken, TokenCredential, AzureKeyCredential, AzureSasCredential
from azure.core.pipeline import PipelineRequest


Expand Down Expand Up @@ -114,3 +114,34 @@ def __init__(self, credential, name, **kwargs): # pylint: disable=unused-argume

def on_request(self, request):
request.http_request.headers[self._name] = self._credential.key


class AzureSasCredentialPolicy(SansIOHTTPPolicy):
"""Adds a shared access signature to query for the provided credential.
:param credential: The credential used to authenticate requests.
:type credential: ~azure.core.credentials.AzureSasCredential
:raises: ValueError or TypeError
"""
def __init__(self, credential, **kwargs): # pylint: disable=unused-argument
# type: (AzureSasCredential, **Any) -> None
super(AzureSasCredentialPolicy, self).__init__()
if not credential:
raise ValueError("credential can not be None")
self._credential = credential

def on_request(self, request):
url = request.http_request.url
query = request.http_request.query
signature = self._credential.signature
if signature.startswith("?"):
signature = signature[1:]
if query:
if signature not in url:
url = url + "&" + signature
else:
if url.endswith("?"):
url = url + signature
else:
url = url + "?" + signature
request.http_request.url = url
44 changes: 42 additions & 2 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import time

import azure.core
from azure.core.credentials import AccessToken, AzureKeyCredential
from azure.core.credentials import AccessToken, AzureKeyCredential, AzureSasCredential
from azure.core.exceptions import ServiceRequestError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy, AzureKeyCredentialPolicy
from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy, AzureKeyCredentialPolicy, AzureSasCredentialPolicy
from azure.core.pipeline.transport import HttpRequest

import pytest
Expand Down Expand Up @@ -190,3 +190,43 @@ def test_azure_key_credential_updates():
api_key = "new"
credential.update(api_key)
assert credential.key == api_key

@pytest.mark.parametrize("sas,url,expected_url", [
("sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"),
("?sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"),
("sig=test_signature", "https://test_sas_credential?sig=test_signature", "https://test_sas_credential?sig=test_signature"),
("?sig=test_signature", "https://test_sas_credential?sig=test_signature", "https://test_sas_credential?sig=test_signature"),
("sig=test_signature", "https://test_sas_credential?", "https://test_sas_credential?sig=test_signature"),
("?sig=test_signature", "https://test_sas_credential?", "https://test_sas_credential?sig=test_signature"),
("sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"),
("?sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"),
])
def test_azure_sas_credential_policy(sas, url, expected_url):
"""Tests to see if we can create an AzureSasCredentialPolicy"""

def verify_authorization(request):
assert request.url == expected_url

transport=Mock(send=verify_authorization)
credential = AzureSasCredential(sas)
credential_policy = AzureSasCredentialPolicy(credential=credential)
pipeline = Pipeline(transport=transport, policies=[credential_policy])

pipeline.run(HttpRequest("GET", url))

def test_azure_sas_credential_updates():
"""Tests AzureSasCredential updates"""
sas = "original"

credential = AzureSasCredential(sas)
assert credential.signature == sas

sas = "new"
credential.update(sas)
assert credential.signature == sas

def test_azure_sas_credential_policy_raises():
"""Tests AzureSasCredential and AzureSasCredentialPolicy raises with non-string input parameters."""
sas = 1234
with pytest.raises(TypeError):
credential = AzureSasCredential(sas)

0 comments on commit 673d518

Please sign in to comment.