Skip to content

Commit 1afb899

Browse files
msyyciscai-msft
authored andcommitted
[Python] pass auth flows into credential policy (#6549)
fix #6448 pending on new release of corehttp: Azure/azure-sdk-for-python#40084 SDK API diff is here: Azure/autorest.python#3062 --------- Co-authored-by: iscai-msft <isabellavcai@gmail.com>
1 parent 1229b39 commit 1afb899

File tree

6 files changed

+55
-4
lines changed

6 files changed

+55
-4
lines changed

packages/http-client-python/emitter/src/types.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ export function getType(
9292
case "enumvalue":
9393
return emitEnumMember(type, emitEnum(context, type.enumType));
9494
case "credential":
95-
return emitCredential(type);
95+
return emitCredential(context, type);
9696
case "bytes":
9797
case "boolean":
9898
case "plainDate":
@@ -143,7 +143,10 @@ function emitMultiPartFile(
143143
});
144144
}
145145

146-
function emitCredential(credential: SdkCredentialType): Record<string, any> {
146+
function emitCredential(
147+
context: PythonSdkContext,
148+
credential: SdkCredentialType,
149+
): Record<string, any> {
147150
let credential_type: Record<string, any> = {};
148151
const scheme = credential.scheme;
149152
if (scheme.type === "oauth2") {
@@ -152,6 +155,7 @@ function emitCredential(credential: SdkCredentialType): Record<string, any> {
152155
policy: {
153156
type: "BearerTokenCredentialPolicy",
154157
credentialScopes: [],
158+
flows: (context.emitContext.options as any).flavor === "azure" ? [] : scheme.flows,
155159
},
156160
};
157161
for (const flow of scheme.flows) {

packages/http-client-python/generator/pygen/codegen/models/credential_types.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,20 @@ def __init__(
4848
yaml_data: Dict[str, Any],
4949
code_model: "CodeModel",
5050
credential_scopes: List[str],
51+
flows: Optional[Dict[str, Any]] = None,
5152
) -> None:
5253
super().__init__(yaml_data, code_model)
5354
self.credential_scopes = credential_scopes
55+
self.flows = flows
5456

5557
def call(self, async_mode: bool) -> str:
5658
policy_name = f"{'Async' if async_mode else ''}BearerTokenCredentialPolicy"
57-
return f"policies.{policy_name}(self.credential, *self.credential_scopes, **kwargs)"
59+
auth_flows = f"auth_flows={self.flows}, " if self.flows else ""
60+
return f"policies.{policy_name}(self.credential, *self.credential_scopes, {auth_flows}**kwargs)"
5861

5962
@classmethod
6063
def from_yaml(cls, yaml_data: Dict[str, Any], code_model: "CodeModel") -> "BearerTokenCredentialPolicyType":
61-
return cls(yaml_data, code_model, yaml_data["credentialScopes"])
64+
return cls(yaml_data, code_model, yaml_data["credentialScopes"], yaml_data.get("flows"))
6265

6366

6467
class ARMChallengeAuthenticationPolicyType(BearerTokenCredentialPolicyType):

packages/http-client-python/generator/test/generic_mock_api_tests/asynctests/test_authentication_async.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ class FakeCredential:
3333
async def get_token(*scopes):
3434
return core_library.credentials.AccessToken(token="".join(scopes), expires_on=1800)
3535

36+
@staticmethod
37+
async def get_token_info(*scopes, **kwargs):
38+
return core_library.credentials.AccessTokenInfo(token="".join(scopes), expires_on=1800)
39+
3640
return FakeCredential()
3741

3842

packages/http-client-python/generator/test/generic_mock_api_tests/test_authentication.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ class FakeCredential:
3333
def get_token(*scopes):
3434
return core_library.credentials.AccessToken(token="".join(scopes), expires_on=1800)
3535

36+
@staticmethod
37+
def get_token_info(*scopes, **kwargs):
38+
return core_library.credentials.AccessTokenInfo(token="".join(scopes), expires_on=1800)
39+
3640
return FakeCredential()
3741

3842

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
import pytest
7+
from authentication.oauth2.aio import OAuth2Client
8+
9+
10+
@pytest.mark.asyncio
11+
async def test_oauth2_auth_flows():
12+
oauth2_client = OAuth2Client("fake_credential")
13+
assert oauth2_client._config.authentication_policy._auth_flows == [
14+
{
15+
"authorizationUrl": "https://login.microsoftonline.com/common/oauth2/authorize",
16+
"scopes": [{"value": "https://security.microsoft.com/.default"}],
17+
"type": "implicit",
18+
}
19+
]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
from authentication.oauth2 import OAuth2Client
7+
8+
9+
def test_oauth2_auth_flows():
10+
oauth2_client = OAuth2Client("fake_credential")
11+
assert oauth2_client._config.authentication_policy._auth_flows == [
12+
{
13+
"authorizationUrl": "https://login.microsoftonline.com/common/oauth2/authorize",
14+
"scopes": [{"value": "https://security.microsoft.com/.default"}],
15+
"type": "implicit",
16+
}
17+
]

0 commit comments

Comments
 (0)