Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run mypy in azure-identity CI #15832

Merged
merged 5 commits into from
Jan 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions eng/tox/mypy_hard_failure_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
MYPY_HARD_FAILURE_OPTED = [
"azure-core",
"azure-eventhub",
"azure-identity",
"azure-servicebus",
"azure-ai-textanalytics",
"azure-ai-formrecognizer",
Expand Down
15 changes: 8 additions & 7 deletions sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from time import struct_time
from typing import Any, Dict, Iterable, Mapping, Optional, Union
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.transport import HttpTransport
from azure.core.pipeline.policies import HTTPPolicy
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy

PolicyListType = List[Union[HTTPPolicy, SansIOHTTPPolicy]]


class AuthnClientBase(ABC):
Expand Down Expand Up @@ -166,10 +168,9 @@ def _parse_app_service_expires_on(expires_on):

raise ValueError("'{}' doesn't match the expected format".format(expires_on))

# TODO: public, factor out of request_token
def _prepare_request(
self,
method="POST", # type: Optional[str]
method="POST", # type: str
headers=None, # type: Optional[Mapping[str, str]]
form_data=None, # type: Optional[Mapping[str, str]]
params=None, # type: Optional[Dict[str, str]]
Expand Down Expand Up @@ -200,7 +201,7 @@ class AuthnClient(AuthnClientBase):
def __init__(
self,
config=None, # type: Optional[Configuration]
policies=None, # type: Optional[Iterable[HTTPPolicy]]
policies=None, # type: Optional[PolicyListType]
transport=None, # type: Optional[HttpTransport]
**kwargs # type: Any
):
Expand All @@ -217,13 +218,13 @@ def __init__(
]
if not transport:
transport = RequestsTransport(**kwargs)
self._pipeline = Pipeline(transport=transport, policies=policies)
self._pipeline = Pipeline(transport=transport, policies=policies) # type: Pipeline
super(AuthnClient, self).__init__(**kwargs)

def request_token(
self,
scopes, # type: Iterable[str]
method="POST", # type: Optional[str]
method="POST", # type: str
headers=None, # type: Optional[Mapping[str, str]]
form_data=None, # type: Optional[Mapping[str, str]]
params=None, # type: Optional[Dict[str, str]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ def __init__(self, **kwargs):

client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = ManagedIdentityClient(**client_args)
else:
self._client = None
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="App Service managed identity configuration not found in environment"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.transport import HttpRequest, HttpResponse
from azure.core.pipeline.transport import HttpRequest
from azure.core.pipeline.policies import (
DistributedTracingPolicy,
HttpLoggingPolicy,
Expand All @@ -28,6 +27,7 @@
from typing import Any, List, Optional, Union
from azure.core.configuration import Configuration
from azure.core.credentials import AccessToken
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import SansIOHTTPPolicy

PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy]
Expand All @@ -40,24 +40,21 @@ def __init__(self, **kwargs):

url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT)
imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT)
if not (url and imds):
# Azure Arc managed identity isn't available in this environment
self._client = None
return

identity_config = kwargs.pop("_identity_config", None) or {}
config = _get_configuration()

self._client = ManagedIdentityClient(
_identity_config=identity_config,
policies=_get_policies(config),
request_factory=functools.partial(_get_request, url),
**kwargs
)
self._available = url and imds
if self._available:
identity_config = kwargs.pop("_identity_config", None) or {}
config = _get_configuration()

self._client = ManagedIdentityClient(
_identity_config=identity_config,
policies=_get_policies(config),
request_factory=functools.partial(_get_request, url),
**kwargs
)

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="Azure Arc managed identity configuration not found in environment"
)
Expand Down Expand Up @@ -125,7 +122,7 @@ class ArcChallengeAuthPolicy(HTTPPolicy):
"""Policy for handling Azure Arc's challenge authentication"""

def send(self, request):
# type: (PipelineRequest) -> HttpResponse
# type: (PipelineRequest) -> PipelineResponse
request.http_request.headers["Metadata"] = "true"
response = self.next.send(request)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Any
from azure.core.credentials import AccessToken
from typing import Any, List
from azure.core.credentials import AccessToken, TokenCredential

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(self, **kwargs):
exclude_cli_credential = kwargs.pop("exclude_cli_credential", False)
exclude_interactive_browser_credential = kwargs.pop("exclude_interactive_browser_credential", True)

credentials = []
credentials = [] # type: List[TokenCredential]
if not exclude_environment_credential:
credentials.append(EnvironmentCredential(authority=authority, **kwargs))
if not exclude_managed_identity_credential:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Optional, Type
from azure.core.credentials import TokenCredential

_LOGGER = logging.getLogger(__name__)

Expand All @@ -52,7 +53,7 @@ class ManagedIdentityCredential(object):

def __init__(self, **kwargs):
# type: (**Any) -> None
self._credential = None
self._credential = None # type: Optional[TokenCredential]
if os.environ.get(EnvironmentVariables.MSI_ENDPOINT):
if os.environ.get(EnvironmentVariables.MSI_SECRET):
_LOGGER.info("%s will use App Service managed identity", self.__class__.__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ def __init__(self, **kwargs):

client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = ManagedIdentityClient(**client_args)
else:
self._client = None
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="Service Fabric managed identity configuration not found in environment"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Silently acquire a token from MSAL. Requires an AuthenticationRecord."""

# self._auth_record and ._app will not be None when this method is called by get_token
# but should either be None anyway (and to satisfy mypy) we raise
if self._app is None or self._auth_record is None:
raise CredentialUnavailableError("Initialization failed")

result = None

accounts_for_user = self._app.get_accounts(username=self._auth_record.username)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def get_default_authority():


def validate_tenant_id(tenant_id):
"""Raise ValueError if tenant_id is empty or contains a character invalid for a tenant id"""
# type: (str) -> None
"""Raise ValueError if tenant_id is empty or contains a character invalid for a tenant id"""
if not tenant_id or any(c not in VALID_TENANT_ID_CHARACTERS for c in tenant_id):
raise ValueError(
"Invalid tenant id provided. You can locate your tenant id by following the instructions here: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ class GetTokenMixin(ABC):
def __init__(self, *args, **kwargs):
# type: (*Any, **Any) -> None
self._last_request_time = 0
super(GetTokenMixin, self).__init__(*args, **kwargs)

# https://github.com/python/mypy/issues/5887
super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore

@abc.abstractmethod
def _acquire_token_silently(self, *scopes):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class _SECRET_SCHEMA(ct.Structure):
_libsecret.secret_password_lookup_sync.restype = ct.c_char_p
_libsecret.secret_password_free.argtypes = [ct.c_char_p]
except OSError:
_libsecret = None
_libsecret = None # type: ignore


def _get_user_settings_path():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy]


class ManagedIdentityClient(object):
class ManagedIdentityClientBase(ABC):
# pylint:disable=missing-client-constructor-parameter-credential
def __init__(self, request_factory, client_id=None, **kwargs):
# type: (Callable[[str, dict], HttpRequest], Optional[str], **Any) -> None
Expand All @@ -55,24 +55,6 @@ def __init__(self, request_factory, client_id=None, **kwargs):

self._request_factory = request_factory

def get_cached_token(self, *scopes):
# type: (*str) -> Optional[AccessToken]
resource = _scopes_to_resource(*scopes)
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource])
for token in tokens:
if token["expires_on"] > time.time():
return AccessToken(token["secret"], token["expires_on"])
return None

def request_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
resource = _scopes_to_resource(*scopes)
request = self._request_factory(resource, self._identity_config)
request_time = int(time.time())
response = self._pipeline.run(request)
token = self._process_response(response, request_time)
return token

def _process_response(self, response, request_time):
# type: (PipelineResponse, int) -> AccessToken

Expand Down Expand Up @@ -102,6 +84,34 @@ def _process_response(self, response, request_time):

return token

def get_cached_token(self, *scopes):
# type: (*str) -> Optional[AccessToken]
resource = _scopes_to_resource(*scopes)
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource])
for token in tokens:
if token["expires_on"] > time.time():
return AccessToken(token["secret"], token["expires_on"])
return None

@abc.abstractmethod
def request_token(self, *scopes, **kwargs):
pass

@abc.abstractmethod
def _build_pipeline(self, config, policies=None, transport=None, **kwargs):
pass


class ManagedIdentityClient(ManagedIdentityClientBase):
def request_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
resource = _scopes_to_resource(*scopes)
request = self._request_factory(resource, self._identity_config)
request_time = int(time.time())
response = self._pipeline.run(request)
token = self._process_response(response, request_time)
return token

def _build_pipeline(self, config, policies=None, transport=None, **kwargs): # pylint:disable=no-self-use
# type: (Configuration, Optional[List[PolicyType]], Optional[HttpTransport], **Any) -> Pipeline
if policies is None: # [] is a valid policy list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class _CREDENTIAL(ct.Structure):

_PCREDENTIAL = ct.POINTER(_CREDENTIAL)

_advapi = ct.WinDLL("advapi32")
_advapi = ct.WinDLL("advapi32") # type: ignore
_advapi.CredReadW.argtypes = [wt.LPCWSTR, wt.DWORD, wt.DWORD, ct.POINTER(_PCREDENTIAL)]
_advapi.CredReadW.restype = wt.BOOL
_advapi.CredFree.argtypes = [_PCREDENTIAL]
Expand Down
12 changes: 7 additions & 5 deletions sdk/identity/azure-identity/azure/identity/aio/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
from .._internal.user_agent import USER_AGENT

if TYPE_CHECKING:
from typing import Any, Dict, Iterable, Mapping, Optional
from azure.core.pipeline.policies import HTTPPolicy
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport

PolicyListType = List[Union[AsyncHTTPPolicy, SansIOHTTPPolicy]]


class AsyncAuthnClient(AuthnClientBase): # pylint:disable=async-client-bad-name
"""Async authentication client"""
Expand All @@ -35,7 +37,7 @@ class AsyncAuthnClient(AuthnClientBase): # pylint:disable=async-client-bad-name
def __init__(
self,
config: "Optional[Configuration]" = None,
policies: "Optional[Iterable[HTTPPolicy]]" = None,
policies: "Optional[PolicyListType]" = None,
transport: "Optional[AsyncHttpTransport]" = None,
**kwargs: "Any"
) -> None:
Expand All @@ -51,7 +53,7 @@ def __init__(
]
if not transport:
transport = AioHttpTransport(**kwargs)
self._pipeline = AsyncPipeline(transport=transport, policies=policies)
self._pipeline = AsyncPipeline(transport=transport, policies=policies) # type: AsyncPipeline
super().__init__(**kwargs)

async def __aenter__(self):
Expand All @@ -67,7 +69,7 @@ async def close(self) -> None:
async def request_token( # pylint:disable=invalid-overridden-method
self,
scopes: "Iterable[str]",
method: "Optional[str]" = "POST",
method: str = "POST",
headers: "Optional[Mapping[str, str]]" = None,
form_data: "Optional[Mapping[str, str]]" = None,
params: "Optional[Dict[str, str]]" = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ def __init__(self, **kwargs: "Any") -> None:

client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = AsyncManagedIdentityClient(**client_args)
else:
self._client = None
self._available = False

async def get_token( # pylint:disable=invalid-overridden-method
self, *scopes: str, **kwargs: "Any"
) -> "AccessToken":
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="App Service managed identity configuration not found in environment"
)
Expand Down
Loading