Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Text Translation] Add support for AAD authentication #34883

Merged
merged 12 commits into from
Mar 25, 2024
1 change: 1 addition & 0 deletions sdk/translation/azure-ai-translation-text/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 1.0.0b2 (Unreleased)

### Features Added
- Added support for AAD authentication.

### Breaking Changes

Expand Down
2 changes: 1 addition & 1 deletion sdk/translation/azure-ai-translation-text/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "python",
"TagPrefix": "python/translation/azure-ai-translation-text",
"Tag": "python/translation/azure-ai-translation-text_498977d118"
"Tag": "python/translation/azure-ai-translation-text_afde2bdc8c"
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from ._patch import TextTranslationClient
from ._patch import TextTranslationClient, TranslatorCredential, TranslatorAADCredential
from ._version import VERSION

__version__ = VERSION


from ._patch import TranslatorCredential
from ._patch import patch_sdk as _patch_sdk

__all__ = [
"TranslatorCredential",
"TranslatorAADCredential",
"TextTranslationClient",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------

from typing import Union, Optional
from typing import Union, Optional, Any
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import SansIOHTTPPolicy, BearerTokenCredentialPolicy, AzureKeyCredentialPolicy
from azure.core.credentials import TokenCredential, AzureKeyCredential
Expand Down Expand Up @@ -56,6 +56,37 @@ def on_request(self, request: PipelineRequest) -> None:
request.http_request.headers["Ocp-Apim-Subscription-Key"] = self.credential.key
request.http_request.headers["Ocp-Apim-Subscription-Region"] = self.credential.region

class TranslatorAADCredential:
"""Credential for Translator Service when using AAD authentication.
:param tokenCredential: An object which can provide an access token for the Translator Resource, such as a credential from
:mod:`azure.identity`
:type tokenCredential: ~azure.core.credentials.TokenCredential
:param str resourceId: Azure Resource Id of the Translation Resource.
:param str region: Azure Region of the Translation Resource.
"""

def __init__(self, tokenCredential: TokenCredential, resourceId: str, region: str) -> None:
self.tokenCredential = tokenCredential
self.resourceId = resourceId
self.region = region

class TranslatorAADAuthenticationPolicy(BearerTokenCredentialPolicy):
MikeyMCZ marked this conversation as resolved.
Show resolved Hide resolved
"""Translator AAD Authentication Policy. Adds headers that are required by Translator Service
when global endpoint is used with AAD policy.
Ocp-Apim-Subscription-Region header contains region of the Translator resource.
Ocp-Apim-ResourceId header contains Azure resource Id - Translator resource.
:param credential: Translator AAD Credentials used to access Translator Resource for global Translator endpoint.
MikeyMCZ marked this conversation as resolved.
Show resolved Hide resolved
MikeyMCZ marked this conversation as resolved.
Show resolved Hide resolved
:type tokenCredential: ~azure.core.credentials.TokenCredential
MikeyMCZ marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, credential: TranslatorAADCredential, **kwargs: Any)-> None:
super(TranslatorAADAuthenticationPolicy, self).__init__(credential.tokenCredential, "https://cognitiveservices.azure.com/.default", **kwargs)
self.translatorCredential = credential

def on_request(self, request: PipelineRequest) -> None:
request.http_request.headers["Ocp-Apim-ResourceId"] = self.translatorCredential.resourceId
request.http_request.headers["Ocp-Apim-Subscription-Region"] = self.translatorCredential.region
super().on_request(request)

def get_translation_endpoint(endpoint, api_version):
if not endpoint:
Expand All @@ -74,6 +105,9 @@ def set_authentication_policy(credential, kwargs):
if isinstance(credential, TranslatorCredential):
if not kwargs.get("authentication_policy"):
kwargs["authentication_policy"] = TranslatorAuthenticationPolicy(credential)
elif isinstance(credential, TranslatorAADCredential):
if not kwargs.get("authentication_policy"):
kwargs["authentication_policy"] = TranslatorAADAuthenticationPolicy(credential)
elif isinstance(credential, AzureKeyCredential):
if not kwargs.get("authentication_policy"):
kwargs["authentication_policy"] = AzureKeyCredentialPolicy(
Expand Down Expand Up @@ -122,7 +156,7 @@ class TextTranslationClient(ServiceClientGenerated):
https://api.cognitive.microsofttranslator.com). Required.
:type endpoint: str
:param credential: Credential used to authenticate with the Translator service
:type credential: Union[AzureKeyCredential , TokenCredential , TranslatorCredential]
:type credential: Union[AzureKeyCredential , TokenCredential , TranslatorCredential, TranslatorAADCredential]
catalinaperalta marked this conversation as resolved.
Show resolved Hide resolved
:keyword api_version: Default value is "3.0". Note that overriding this default value may
result in unsupported behavior.
:paramtype api_version: str
Expand All @@ -131,7 +165,7 @@ class TextTranslationClient(ServiceClientGenerated):
def __init__(
self,
*,
credential: Optional[Union[AzureKeyCredential, TokenCredential, TranslatorCredential]] = None,
credential: Optional[Union[AzureKeyCredential, TokenCredential, TranslatorCredential, TranslatorAADCredential]] = None,
endpoint: Optional[str] = None,
api_version="3.0",
**kwargs
Expand All @@ -144,4 +178,4 @@ def __init__(
super().__init__(endpoint=translation_endpoint, api_version=api_version, **kwargs)


__all__ = ["TextTranslationClient", "TranslatorCredential"]
__all__ = ["TextTranslationClient", "TranslatorCredential", "TranslatorAADCredential"]
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from ._patch import TextTranslationClient
from ._patch import TextTranslationClient, AsyncTranslatorAADCredential


from ._patch import patch_sdk as _patch_sdk

__all__ = [
"TextTranslationClient",
"AsyncTranslatorAADCredential"
]


Expand Down
MikeyMCZ marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""
from typing import Union, Optional
from typing import Union, Optional, Any
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AzureKeyCredentialPolicy
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
Expand All @@ -24,11 +25,46 @@ def patch_sdk():
https://aka.ms/azsdk/python/dpcodegen/python/customize
"""

class AsyncTranslatorAADCredential:
"""Credential for Translator Service when using AAD authentication.
:param tokenCredential: An object which can provide an access token for the Translator Resource, such as a credential from
:mod:`azure.identity`
:type tokenCredential: ~azure.core.credentials.TokenCredential
:param str resourceId: Azure Resource Id of the Translation Resource.
:param str region: Azure Region of the Translation Resource.
"""

def __init__(self, tokenCredential: AsyncTokenCredential, resourceId: str, region: str) -> None:
self.tokenCredential = tokenCredential
self.resourceId = resourceId
self.region = region

class AsyncTranslatorAADAuthenticationPolicy(AsyncBearerTokenCredentialPolicy):
"""Translator AAD Authentication Policy. Adds headers that are required by Translator Service
when global endpoint is used with AAD policy.
Ocp-Apim-Subscription-Region header contains region of the Translator resource.
Ocp-Apim-ResourceId header contains Azure resource Id - Translator resource.
:param credential: Translator AAD Credentials used to access Translator Resource for global Translator endpoint.
:type tokenCredential: ~azure.core.credentials.TokenCredential
MikeyMCZ marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, credential: AsyncTranslatorAADCredential, **kwargs: Any)-> None:
super(AsyncTranslatorAADAuthenticationPolicy, self).__init__(credential.tokenCredential, "https://cognitiveservices.azure.com/.default", **kwargs)
self.translatorCredential = credential

async def on_request(self, request: PipelineRequest) -> None:
request.http_request.headers["Ocp-Apim-ResourceId"] = self.translatorCredential.resourceId
request.http_request.headers["Ocp-Apim-Subscription-Region"] = self.translatorCredential.region
await super().on_request(request)


def set_authentication_policy(credential, kwargs):
if isinstance(credential, TranslatorCredential):
if not kwargs.get("authentication_policy"):
kwargs["authentication_policy"] = TranslatorAuthenticationPolicy(credential)
elif isinstance(credential, AsyncTranslatorAADCredential):
if not kwargs.get("authentication_policy"):
kwargs["authentication_policy"] = AsyncTranslatorAADAuthenticationPolicy(credential)
elif isinstance(credential, AzureKeyCredential):
if not kwargs.get("authentication_policy"):
kwargs["authentication_policy"] = AzureKeyCredentialPolicy(
Expand Down Expand Up @@ -77,15 +113,15 @@ class TextTranslationClient(ServiceClientGenerated):
https://api.cognitive.microsofttranslator.com). Required.
:type endpoint: str
:param credential: Credential used to authenticate with the Translator service
:type credential: Union[AzureKeyCredential , AsyncTokenCredential , TranslatorCredential]
:type credential: Union[AzureKeyCredential , AsyncTokenCredential , TranslatorCredential, AsyncTranslatorAADCredential]
:keyword api_version: Default value is "3.0". Note that overriding this default value may
result in unsupported behavior.
:paramtype api_version: str
"""

def __init__(
self,
credential: Union[AzureKeyCredential, AsyncTokenCredential, TranslatorCredential],
credential: Union[AzureKeyCredential, AsyncTokenCredential, TranslatorCredential, AsyncTranslatorAADCredential],
*,
endpoint: Optional[str] = None,
api_version="3.0",
Expand All @@ -99,4 +135,4 @@ def __init__(
super().__init__(endpoint=translation_endpoint, api_version=api_version, **kwargs)


__all__ = ["TextTranslationClient"]
__all__ = ["TextTranslationClient", "AsyncTranslatorAADCredential"]
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
-e ../../../tools/azure-sdk-tools
-e ../../../tools/azure-devtools
-e ../../core/azure-core
-e ../../identity/azure-identity
msrestazure
aiohttp>=3.0
4 changes: 4 additions & 0 deletions sdk/translation/azure-ai-translation-text/tests/preparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,8 @@
text_translation_custom_endpoint="https://fakeCustomEndpoint.cognitiveservices.azure.com",
text_translation_apikey="fakeapikey",
text_translation_region="fakeregion",
text_translation_aadClientId="fakeAADClientId",
text_translation_aadTenantId="fakeAADTenantId",
text_translation_aadSecret="fakeAADSecret",
text_translation_aadResourceId="fakeResourceId"
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_autodetect(self, **kwargs):
response = client.find_sentence_boundaries(request_body=input_text_elements)
assert response is not None
assert response[0].detected_language.language == "en"
assert response[0].detected_language.score == 1
assert response[0].detected_language.score > 0.9
assert response[0].sent_len[0] == 11

@TextTranslationPreparer()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def test_autodetect(self, **kwargs):
response = await client.find_sentence_boundaries(request_body=input_text_elements)
assert response is not None
assert response[0].detected_language.language == "en"
assert response[0].detected_language.score == 1
assert response[0].detected_language.score > 0.9
assert response[0].sent_len[0] == 11

@TextTranslationPreparer()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,24 @@ def test_token(self, **kwargs):
assert len(response[0].translations) == 1
assert response[0].detected_language.language == "en"
assert response[0].detected_language.score == 1

@pytest.mark.live_test_only
@TextTranslationPreparer()
@recorded_by_proxy
def test_translate_aad(self, **kwargs):
aadRegion = "westus2"
aadResourceId = kwargs.get("text_translation_aadresourceid")
token_credential = self.get_mt_credential(False)
client = self.create_text_translation_client_with_aad(token_credential, aadRegion, aadResourceId)

source_language = "es"
target_languages = ["cs"]
input_text_elements = ["Hola mundo"]
response = client.translate(
request_body=input_text_elements, to=target_languages, from_parameter=source_language
)

assert len(response) == 1
assert len(response[0].translations) == 1
assert response[0].translations[0].to == "cs"
assert response[0].translations[0].text is not None
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------

import os
import pytest
from devtools_testutils.aio import recorded_by_proxy_async
from azure.ai.translation.text.models import TextType, ProfanityAction, ProfanityMarker
Expand Down Expand Up @@ -324,3 +325,24 @@ async def test_token(self, **kwargs):
assert len(response[0].translations) == 1
assert response[0].detected_language.language == "en"
assert response[0].detected_language.score == 1

@pytest.mark.live_test_only
@TextTranslationPreparer()
@recorded_by_proxy_async
async def test_translate_aad(self, **kwargs):
aadRegion = "westus2"
aadResourceId = kwargs.get("text_translation_aadresourceid")
token_credential = self.get_mt_credential(True)
client = self.create_async_text_translation_client_with_aad(token_credential, aadRegion, aadResourceId)

source_language = "es"
target_languages = ["cs"]
input_text_elements = ["Hola mundo"]
response = await client.translate(
request_body=input_text_elements, to=target_languages, from_parameter=source_language
)

assert len(response) == 1
assert len(response[0].translations) == 1
assert response[0].translations[0].to == "cs"
assert response[0].translations[0].text is not None
38 changes: 36 additions & 2 deletions sdk/translation/azure-ai-translation-text/tests/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
# Licensed under the MIT License.
# ------------------------------------

import os
from devtools_testutils.fake_credentials_async import AsyncFakeCredential
from azure.core.credentials import AccessToken
from devtools_testutils import AzureRecordedTestCase
from azure.ai.translation.text import TextTranslationClient, TranslatorCredential
from azure.ai.translation.text import TextTranslationClient, TranslatorCredential, TranslatorAADCredential

from static_access_token_credential import StaticAccessTokenCredential


class TextTranslationTest(AzureRecordedTestCase):
def create_getlanguage_client(self, endpoint):
client = TextTranslationClient(endpoint=endpoint, credential=None)
Expand All @@ -23,6 +25,11 @@ def create_client_token(self, endpoint, apikey, region):
credential = StaticAccessTokenCredential(apikey, region)
client = TextTranslationClient(endpoint=endpoint, credential=credential)
return client

def create_text_translation_client_with_aad(self, innerCredential, aadRegion, aadResourceId):
credential = TranslatorAADCredential(innerCredential, aadResourceId, aadRegion)
text_translator = TextTranslationClient(credential=credential)
return text_translator

def create_async_getlanguage_client(self, endpoint):
from azure.ai.translation.text.aio import TextTranslationClient as TextTranslationClientAsync
Expand All @@ -43,3 +50,30 @@ def create_async_client_token(self, endpoint, apikey, region):

client = TextTranslationClientAsync(endpoint=endpoint, credential=credential)
return client

def create_async_text_translation_client_with_aad(self, innerCredential, aadRegion, aadResourceId):
from azure.ai.translation.text.aio import TextTranslationClient as TextTranslationClientAsync, AsyncTranslatorAADCredential
credential = AsyncTranslatorAADCredential(innerCredential, aadResourceId, aadRegion)
text_translator = TextTranslationClientAsync(credential=credential)
return text_translator

def get_mt_credential(self, is_async, **kwargs):
# Return live credentials only in live mode
if self.is_live:
from azure.identity import ClientSecretCredential

if is_async:
from azure.identity.aio import ClientSecretCredential


tenant_id = os.environ.get("AZURE_TENANT_ID", getattr(os.environ, "TENANT_ID", None))
client_id = os.environ.get("AZURE_CLIENT_ID", getattr(os.environ, "CLIENT_ID", None))
secret = os.environ.get("AZURE_CLIENT_SECRET", getattr(os.environ, "CLIENT_SECRET", None))
return ClientSecretCredential(tenant_id=tenant_id, client_id=client_id, client_secret=secret)

# For playback tests, return credentials that will accept playback `get_token` calls
else:
if is_async:
return AsyncFakeCredential()
else:
return self.settings.get_azure_core_credentials()
5 changes: 5 additions & 0 deletions sdk/translation/test-resources.json
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
"azureDocTranslationUrl": "[if(parameters('useStaticStorageResource'), concat('https://', variables('docTranslationBaseName'), '.ppe', parameters('cognitiveServicesEndpointSuffix')), concat('https://', variables('docTranslationBaseName'), parameters('cognitiveServicesEndpointSuffix')))]",
"azureStorageEndpoint": "[concat('https://', parameters('blobStorageAccount'), '.blob.core.windows.net/')]",
"cognitiveServiceUserRoleId": "[concat('/subscriptions/', subscription().subscriptionId, '/providers/Microsoft.Authorization/roleDefinitions/a97b65f3-24c7-4388-baec-2e87135dc908')]",
"txtAadResourceIdValue": "[parameters('text_translation_aadResourceId')]",
"encryption": {
"services": {
"blob": {
Expand Down Expand Up @@ -379,6 +380,10 @@
"TEXT_TRANSLATION_REGION": {
"type": "string",
"value": "[variables('txtRegionValue')]"
},
"TEXT_TRANSLATION_AAD_RESOURCE_ID": {
"type": "string",
"value": "[variables('txtAadResourceIdValue')]"
}
}
}
Loading