Skip to content

Commit

Permalink
[ACR] Fix mypy errors (Azure#27852)
Browse files Browse the repository at this point in the history
  • Loading branch information
YalinLi0312 authored Feb 7, 2023
1 parent d1d6de3 commit a332ad5
Show file tree
Hide file tree
Showing 25 changed files with 233 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import Dict, Any
from typing import Any, Union, Optional

from ._exchange_client import ExchangeClientAuthenticationPolicy
from ._generated import ContainerRegistry
Expand All @@ -23,7 +23,7 @@ class AnonymousACRExchangeClient(object):
"""

def __init__(self, endpoint, **kwargs): # pylint: disable=missing-client-constructor-parameter-credential
# type: (str, Dict[str, Any]) -> None
# type: (str, Any) -> None
if not endpoint.startswith("https://") and not endpoint.startswith("http://"):
endpoint = "https://" + endpoint
self._endpoint = endpoint
Expand All @@ -36,21 +36,18 @@ def __init__(self, endpoint, **kwargs): # pylint: disable=missing-client-constr
)

def get_acr_access_token(self, challenge, **kwargs):
# type: (str, Dict[str, Any]) -> str
# type: (str, Any) -> Optional[str]
parsed_challenge = _parse_challenge(challenge)
parsed_challenge["grant_type"] = TokenGrantType.PASSWORD
return self.exchange_refresh_token_for_access_token(
None,
"",
service=parsed_challenge["service"],
scope=parsed_challenge["scope"],
grant_type=TokenGrantType.PASSWORD,
**kwargs
)

def exchange_refresh_token_for_access_token(
self, refresh_token=None, service=None, scope=None, grant_type=TokenGrantType.PASSWORD, **kwargs
):
# type: (str, str, str, str, Dict[str, Any]) -> str
def exchange_refresh_token_for_access_token(self, refresh_token, service, scope, grant_type, **kwargs):
# type: (str, str, str, Union[str, TokenGrantType], Any) -> Optional[str]
access_token = self._client.authentication.exchange_acr_refresh_token_for_acr_access_token(
service=service, scope=scope, refresh_token=refresh_token, grant_type=grant_type, **kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Licensed under the MIT License.
# ------------------------------------

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union, Optional
from io import SEEK_SET, UnsupportedOperation

from azure.core.pipeline.policies import HTTPPolicy
Expand All @@ -22,11 +22,11 @@ class ContainerRegistryChallengePolicy(HTTPPolicy):
"""Authentication policy for ACR which accepts a challenge"""

def __init__(self, credential, endpoint, **kwargs):
# type: (TokenCredential, str, **Any) -> None
# type: (Optional[TokenCredential], str, **Any) -> None
super(ContainerRegistryChallengePolicy, self).__init__()
self._credential = credential
if self._credential is None:
self._exchange_client = AnonymousACRExchangeClient(endpoint, **kwargs)
self._exchange_client = AnonymousACRExchangeClient(endpoint, **kwargs) # type: Union[AnonymousACRExchangeClient, ACRExchangeClient] # pylint: disable=line-too-long
else:
self._exchange_client = ACRExchangeClient(endpoint, self._credential, **kwargs)

Expand Down Expand Up @@ -77,7 +77,8 @@ def on_challenge(self, request, response, challenge):
# pylint:disable=unused-argument,no-self-use

access_token = self._exchange_client.get_acr_access_token(challenge)
request.http_request.headers["Authorization"] = "Bearer " + access_token
if access_token is not None:
request.http_request.headers["Authorization"] = "Bearer " + access_token
return access_token is not None

def __enter__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Licensed under the MIT License.
# ------------------------------------
from enum import Enum
from typing import TYPE_CHECKING, Dict, Any, Optional
from typing import TYPE_CHECKING, Any, Optional

from azure.core import CaseInsensitiveEnumMeta
from azure.core.pipeline.transport import HttpTransport
Expand Down Expand Up @@ -38,7 +38,7 @@ class ContainerRegistryBaseClient(object):
"""

def __init__(self, endpoint, credential, **kwargs):
# type: (str, Optional[TokenCredential], Dict[str, Any]) -> None
# type: (str, Optional[TokenCredential], Any) -> None
self._auth_policy = ContainerRegistryChallengePolicy(credential, endpoint, **kwargs)
self._client = ContainerRegistry(
credential=credential,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Licensed under the MIT License.
# ------------------------------------
from io import BytesIO
from typing import TYPE_CHECKING, Any, IO, Optional, overload, Union
from typing import TYPE_CHECKING, Any, Dict, IO, Optional, overload, Union, cast, Tuple
from azure.core.exceptions import (
ClientAuthenticationError,
ResourceNotFoundError,
Expand All @@ -13,10 +13,11 @@
map_error,
)
from azure.core.paging import ItemPaged
from azure.core.pipeline import PipelineResponse
from azure.core.tracing.decorator import distributed_trace

from ._base_client import ContainerRegistryBaseClient
from ._generated.models import AcrErrors, OCIManifest
from ._generated.models import AcrErrors, OCIManifest, ManifestWrapper
from ._helpers import (
_compute_digest,
_is_tag,
Expand All @@ -37,10 +38,15 @@

if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
from typing import Dict

def _return_response(pipeline_response, deserialized, response_headers):
return pipeline_response, deserialized, response_headers
def _return_response_and_deserialized(pipeline_response, deserialized, _):
return pipeline_response, deserialized

def _return_deserialized(_, deserialized, __):
return deserialized

def _return_response_headers(_, __, response_headers):
return response_headers


class ContainerRegistryClient(ContainerRegistryBaseClient):
Expand Down Expand Up @@ -141,7 +147,7 @@ def list_repository_names(self, **kwargs):
"""
n = kwargs.pop("results_per_page", None)
last = kwargs.pop("last", None)
cls = kwargs.pop("cls", None) # type: ClsType["_models.Repositories"]
cls = kwargs.pop("cls", None)
error_map = {401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError}
error_map.update(kwargs.pop("error_map", {}))
accept = "application/json"
Expand Down Expand Up @@ -618,11 +624,11 @@ def update_manifest_properties(self, *args, **kwargs):
can_write=False,
)
"""
repository = args[0]
tag_or_digest = args[1]
repository = str(args[0])
tag_or_digest = str(args[1])
properties = None
if len(args) == 3:
properties = args[2]
properties = cast(ArtifactManifestProperties, args[2])
else:
properties = ArtifactManifestProperties()

Expand Down Expand Up @@ -692,11 +698,11 @@ def update_tag_properties(self, *args, **kwargs):
can_write=False,
)
"""
repository = args[0]
tag = args[1]
repository = str(args[0])
tag = str(args[1])
properties = None
if len(args) == 3:
properties = args[2]
properties = cast(ArtifactTagProperties, args[2])
else:
properties = ArtifactTagProperties()

Expand All @@ -709,7 +715,7 @@ def update_tag_properties(self, *args, **kwargs):
self._client.container_registry.update_tag_attributes(
repository, tag, value=properties._to_generated(), **kwargs # pylint: disable=protected-access
),
repository=repository,
repository=repository
)

@overload
Expand Down Expand Up @@ -740,12 +746,11 @@ def update_repository_properties(self, *args, **kwargs):
:rtype: ~azure.containerregistry.RepositoryProperties
:raises: ~azure.core.exceptions.ResourceNotFoundError
"""
repository, properties = None, None
repository = str(args[0])
properties = None
if len(args) == 2:
repository = args[0]
properties = args[1]
properties = cast(RepositoryProperties, args[1])
else:
repository = args[0]
properties = RepositoryProperties()

properties.can_delete = kwargs.pop("can_delete", properties.can_delete)
Expand Down Expand Up @@ -777,20 +782,21 @@ def upload_manifest(
If the digest in the response does not match the digest of the uploaded manifest.
"""
try:
data = manifest
if isinstance(manifest, OCIManifest):
data = _serialize_manifest(manifest)
else:
data = manifest
tag_or_digest = tag
if tag is None:
if tag_or_digest is None:
tag_or_digest = _compute_digest(data)

_, _, response_headers = self._client.container_registry.create_manifest(
response_headers = self._client.container_registry.create_manifest(
name=repository,
reference=tag_or_digest,
payload=data,
content_type=OCI_MANIFEST_MEDIA_TYPE,
headers={"Accept": OCI_MANIFEST_MEDIA_TYPE},
cls=_return_response,
cls=_return_response_headers,
**kwargs
)

Expand All @@ -817,15 +823,24 @@ def upload_blob(self, repository, data, **kwargs):
:raises ValueError: If the parameter repository or data is None.
"""
try:
_, _, start_upload_response_headers = self._client.container_registry_blob.start_upload(
repository, cls=_return_response, **kwargs
)
_, _, upload_chunk_response_headers = self._client.container_registry_blob.upload_chunk(
start_upload_response_headers['Location'], data, cls=_return_response, **kwargs
)
start_upload_response_headers = cast(Dict[str, str], self._client.container_registry_blob.start_upload(
repository, cls=_return_response_headers, **kwargs
))
upload_chunk_response_headers = cast(Dict[str, str], self._client.container_registry_blob.upload_chunk(
start_upload_response_headers['Location'],
data,
cls=_return_response_headers,
**kwargs
))
digest = _compute_digest(data)
_, _, complete_upload_response_headers = self._client.container_registry_blob.complete_upload(
digest=digest, next_link=upload_chunk_response_headers['Location'], cls=_return_response, **kwargs
complete_upload_response_headers = cast(
Dict[str, str],
self._client.container_registry_blob.complete_upload(
digest=digest,
next_link=upload_chunk_response_headers['Location'],
cls=_return_response_headers,
**kwargs
)
)
except ValueError:
if repository is None or data is None:
Expand All @@ -847,15 +862,18 @@ def download_manifest(self, repository, tag_or_digest, **kwargs):
If the requested digest does not match the digest of the received manifest.
"""
try:
response, manifest_wrapper, _ = self._client.container_registry.get_manifest(
name=repository,
reference=tag_or_digest,
headers={"Accept": OCI_MANIFEST_MEDIA_TYPE},
cls=_return_response,
**kwargs
response, manifest_wrapper = cast(
Tuple[PipelineResponse, ManifestWrapper],
self._client.container_registry.get_manifest(
name=repository,
reference=tag_or_digest,
headers={"Accept": OCI_MANIFEST_MEDIA_TYPE},
cls=_return_response_and_deserialized,
**kwargs
)
)
digest = response.http_response.headers['Docker-Content-Digest']
manifest = OCIManifest.deserialize(manifest_wrapper.serialize())
manifest = OCIManifest.deserialize(cast(ManifestWrapper, manifest_wrapper).serialize())
manifest_stream = _serialize_manifest(manifest)
except ValueError:
if repository is None or tag_or_digest is None:
Expand All @@ -878,8 +896,8 @@ def download_blob(self, repository, digest, **kwargs):
:raises ValueError: If the parameter repository or digest is None.
"""
try:
_, deserialized, _ = self._client.container_registry_blob.get_blob(
repository, digest, cls=_return_response, **kwargs
deserialized = self._client.container_registry_blob.get_blob( # type: ignore
repository, digest, cls=_return_deserialized, **kwargs
)
except ValueError:
if repository is None or digest is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import TYPE_CHECKING, Dict, Any
from typing import TYPE_CHECKING, Any, Optional

from azure.core.pipeline.policies import SansIOHTTPPolicy

Expand Down Expand Up @@ -43,7 +43,7 @@ class ACRExchangeClient(object):
"""

def __init__(self, endpoint, credential, **kwargs):
# type: (str, TokenCredential, Dict[str, Any]) -> None
# type: (str, TokenCredential, Any) -> None
if not endpoint.startswith("https://") and not endpoint.startswith("http://"):
endpoint = "https://" + endpoint
self._endpoint = endpoint
Expand All @@ -56,36 +56,36 @@ def __init__(self, endpoint, credential, **kwargs):
**kwargs
)
self._credential = credential
self._refresh_token = None
self._expiration_time = 0
self._refresh_token = None # type: Optional[str]
self._expiration_time = 0 # type: float

def get_acr_access_token(self, challenge, **kwargs):
# type: (str, Dict[str, Any]) -> str
# type: (str, Any) -> Optional[str]
parsed_challenge = _parse_challenge(challenge)
refresh_token = self.get_refresh_token(parsed_challenge["service"], **kwargs)
return self.exchange_refresh_token_for_access_token(
refresh_token, service=parsed_challenge["service"], scope=parsed_challenge["scope"], **kwargs
)

def get_refresh_token(self, service, **kwargs):
# type: (str, Dict[str, Any]) -> str
# type: (str, Any) -> str
if not self._refresh_token or self._expiration_time - time.time() > 300:
self._refresh_token = self.exchange_aad_token_for_refresh_token(service, **kwargs)
self._expiration_time = _parse_exp_time(self._refresh_token)
return self._refresh_token

def exchange_aad_token_for_refresh_token(self, service=None, **kwargs):
# type: (str, Dict[str, Any]) -> str
def exchange_aad_token_for_refresh_token(self, service, **kwargs):
# type: (str, Any) -> str
refresh_token = self._client.authentication.exchange_aad_access_token_for_acr_refresh_token(
grant_type=PostContentSchemaGrantType.ACCESS_TOKEN,
service=service,
access_token=self._credential.get_token(*self.credential_scopes).token,
**kwargs
)
return refresh_token.refresh_token
return refresh_token.refresh_token if refresh_token.refresh_token is not None else ""

def exchange_refresh_token_for_access_token(self, refresh_token, service=None, scope=None, **kwargs):
# type: (str, str, str, Dict[str, Any]) -> str
def exchange_refresh_token_for_access_token(self, refresh_token, service, scope, **kwargs):
# type: (str, str, str, Any) -> Optional[str]
access_token = self._client.authentication.exchange_acr_refresh_token_for_acr_access_token(
service=service, scope=scope, refresh_token=refresh_token, **kwargs
)
Expand Down
Loading

0 comments on commit a332ad5

Please sign in to comment.