Skip to content

[#6889] CQA to support TokenCredential instead of key #6892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Bot.Builder.AI.QnA.Models;
using Microsoft.Bot.Builder.AI.QnA.Utils;
using Newtonsoft.Json;

Expand Down Expand Up @@ -48,9 +47,9 @@ public CustomQuestionAnswering(QnAMakerEndpoint endpoint, QnAMakerOptions option
throw new ArgumentException(nameof(endpoint.Host));
}

if (string.IsNullOrEmpty(endpoint.EndpointKey))
if (string.IsNullOrEmpty(endpoint.EndpointKey) && string.IsNullOrEmpty(endpoint.ManagedIdentityClientId))
{
throw new ArgumentException(nameof(endpoint.EndpointKey));
throw new ArgumentException("Either the EndpointKey or the ManagedIdentityCliendId must be provided");
}

if (_endpoint.Host.EndsWith("v2.0", StringComparison.Ordinal) || _endpoint.Host.EndsWith("v3.0", StringComparison.Ordinal))
Expand Down
147 changes: 123 additions & 24 deletions libraries/Microsoft.Bot.Builder.AI.QnA/Dialogs/QnAMakerDialog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public class QnAMakerDialog : WaterfallDialog
/// </summary>
/// <param name="dialogId">The ID of the <see cref="Dialog"/>.</param>
/// <param name="knowledgeBaseId">The ID of the QnA Maker knowledge base to query.</param>
/// <param name="endpointKey">The QnA Maker endpoint key to use to query the knowledge base.</param>
/// <param name="endpointKey">**Deprecated - use WithEndpointKey() instead**.The QnA Maker endpoint key to use to query the knowledge base.</param>
/// <param name="hostName">The QnA Maker host URL for the knowledge base, starting with "https://" and
/// ending with "/qnamaker".</param>
/// <param name="noAnswer">The activity to send the user when QnA Maker does not find an answer.</param>
Expand Down Expand Up @@ -121,36 +121,38 @@ public QnAMakerDialog(
[CallerFilePath] string sourceFilePath = "",
[CallerLineNumber] int sourceLineNumber = 0,
bool useTeamsAdaptiveCard = false)
: base(dialogId)
: this(
dialogId,
knowledgeBaseId,
hostName,
noAnswer,
threshold,
activeLearningCardTitle,
cardNoMatchText,
top,
cardNoMatchResponse,
strictFilters,
filters,
qnAServiceType,
sourceFilePath,
sourceLineNumber,
useTeamsAdaptiveCard,
httpClient)
{
this.RegisterSourceLocation(sourceFilePath, sourceLineNumber);
this.KnowledgeBaseId = knowledgeBaseId ?? throw new ArgumentNullException(nameof(knowledgeBaseId));
this.HostName = hostName ?? throw new ArgumentNullException(nameof(hostName));
this.EndpointKey = endpointKey ?? throw new ArgumentNullException(nameof(endpointKey));
this.Threshold = threshold;
this.Top = top;
this.ActiveLearningCardTitle = activeLearningCardTitle;
this.CardNoMatchText = cardNoMatchText;
this.StrictFilters = strictFilters;
this.NoAnswer = new BindToActivity(noAnswer ?? MessageFactory.Text(DefaultNoAnswer));
this.CardNoMatchResponse = new BindToActivity(cardNoMatchResponse ?? MessageFactory.Text(DefaultCardNoMatchResponse));
Filters = filters;
QnAServiceType = qnAServiceType;
this.HttpClient = httpClient;
this.UseTeamsAdaptiveCard = useTeamsAdaptiveCard;
if (!string.IsNullOrWhiteSpace(endpointKey))
{
Console.WriteLine(
"Providing an endpointKey in the QnAMakerDialog constructor is deprecated, use WithEndpointKey() method instead and provide 'null' or 'empty' value in the constructor.");

// add waterfall steps
this.AddStep(CallGenerateAnswerAsync);
this.AddStep(CallTrainAsync);
this.AddStep(CheckForMultiTurnPromptAsync);
this.AddStep(DisplayQnAResultAsync);
EndpointKey = endpointKey;
}
}

/// <summary>
/// Initializes a new instance of the <see cref="QnAMakerDialog"/> class.
/// </summary>
/// <param name="knowledgeBaseId">The ID of the QnA Maker knowledge base to query.</param>
/// <param name="endpointKey">The QnA Maker endpoint key to use to query the knowledge base.</param>
/// <param name="endpointKey">**Deprecated - use WithEndpointKey() instead**.The QnA Maker endpoint key to use to query the knowledge base.</param>
/// <param name="hostName">The QnA Maker host URL for the knowledge base, starting with "https://" and
/// ending with "/qnamaker".</param>
/// <param name="noAnswer">The activity to send the user when QnA Maker does not find an answer.</param>
Expand Down Expand Up @@ -232,6 +234,47 @@ public QnAMakerDialog([CallerFilePath] string sourceFilePath = "", [CallerLineNu
this.AddStep(DisplayQnAResultAsync);
}

internal QnAMakerDialog(
string dialogId,
string knowledgeBaseId,
string hostName,
Activity noAnswer = null,
float threshold = DefaultThreshold,
string activeLearningCardTitle = DefaultCardTitle,
string cardNoMatchText = DefaultCardNoMatchText,
int top = DefaultTopN,
Activity cardNoMatchResponse = null,
Metadata[] strictFilters = null,
Filters filters = null,
ServiceType qnAServiceType = ServiceType.QnAMaker,
[CallerFilePath] string sourceFilePath = "",
[CallerLineNumber] int sourceLineNumber = 0,
bool useTeamsAdaptiveCard = false,
HttpClient httpClient = null)
: base(dialogId)
{
RegisterSourceLocation(sourceFilePath, sourceLineNumber);
KnowledgeBaseId = knowledgeBaseId ?? throw new ArgumentNullException(nameof(knowledgeBaseId));
HostName = hostName ?? throw new ArgumentNullException(nameof(hostName));
Threshold = threshold;
Top = top;
ActiveLearningCardTitle = activeLearningCardTitle;
CardNoMatchText = cardNoMatchText;
StrictFilters = strictFilters;
NoAnswer = new BindToActivity(noAnswer ?? MessageFactory.Text(DefaultNoAnswer));
CardNoMatchResponse = new BindToActivity(cardNoMatchResponse ?? MessageFactory.Text(DefaultCardNoMatchResponse));
Filters = filters;
QnAServiceType = qnAServiceType;
HttpClient = httpClient;
UseTeamsAdaptiveCard = useTeamsAdaptiveCard;

// add waterfall steps
AddStep(CallGenerateAnswerAsync);
AddStep(CallTrainAsync);
AddStep(CheckForMultiTurnPromptAsync);
AddStep(DisplayQnAResultAsync);
}

/// <summary>
/// Gets or sets the <see cref="HttpClient"/> instance to use for requests to the QnA Maker service.
/// </summary>
Expand Down Expand Up @@ -266,6 +309,15 @@ public QnAMakerDialog([CallerFilePath] string sourceFilePath = "", [CallerLineNu
[JsonProperty("endpointKey")]
public StringExpression EndpointKey { get; set; }

/// <summary>
/// Gets or sets the ClientId of the Managed Identity resource. Access control (IAM) role `Cognitive Services User` must be assigned in the Language resource to the Managed Identity resource.
/// </summary>
/// <value>
/// The ClientId of the Managed Identity resource.
/// </value>
[JsonProperty("managedIdentityClientId")]
public StringExpression ManagedIdentityClientId { get; set; }

/// <summary>
/// Gets or sets the threshold for answers returned, based on score.
/// </summary>
Expand Down Expand Up @@ -417,6 +469,44 @@ public QnAMakerDialog([CallerFilePath] string sourceFilePath = "", [CallerLineNu
[JsonProperty("qnAServiceType")]
public EnumExpression<ServiceType> QnAServiceType { get; set; } = ServiceType.QnAMaker;

/// <summary>
/// Uses the provided QnA Maker EndpointKey to authenticate against the resource to query the knowledge base.
/// </summary>
/// <param name="endpointKey">The QnA Maker endpoint key to use to query the knowledge base.</param>
public void WithEndpointKey(string endpointKey)
{
if (string.IsNullOrWhiteSpace(endpointKey))
{
throw new ArgumentNullException(nameof(endpointKey));
}

if (ManagedIdentityClientId != null)
{
throw new ArgumentException("Cannot set EndpointKey when ManagedIdentityClientId is already set");
}

EndpointKey = endpointKey;
}

/// <summary>
/// Uses the provided QnA Maker ManagedIdentityClientId to authenticate against the resource to query the knowledge base.
/// </summary>
/// <param name="managedIdentityClientId">The QnA Maker managed identity client id to use to query the knowledge base.</param>
public void WithManagedIdentityClientId(string managedIdentityClientId)
{
if (string.IsNullOrWhiteSpace(managedIdentityClientId))
{
throw new ArgumentNullException(nameof(managedIdentityClientId));
}

if (EndpointKey != null)
{
throw new ArgumentException("Cannot set ManagedIdentityClientId when EndpointKey is already set");
}

ManagedIdentityClientId = managedIdentityClientId;
}

/// <summary>
/// Called when the dialog is started and pushed onto the dialog stack.
/// </summary>
Expand Down Expand Up @@ -532,9 +622,18 @@ protected virtual async Task<IQnAMakerClient> GetQnAMakerClientAsync(DialogConte

var httpClient = dc.Context.TurnState.Get<HttpClient>() ?? HttpClient;

var endpointKey = EndpointKey?.GetValue(dc.State);
var managedIdentityClientId = ManagedIdentityClientId?.GetValue(dc.State);

if (string.IsNullOrWhiteSpace(endpointKey) && string.IsNullOrWhiteSpace(managedIdentityClientId))
{
throw new ArgumentException("An authorization method is required. Either EndpointKey or ManagedIdentityClientId must be set");
}

var endpoint = new QnAMakerEndpoint
{
EndpointKey = this.EndpointKey.GetValue(dc.State),
EndpointKey = endpointKey,
ManagedIdentityClientId = managedIdentityClientId,
Host = this.HostName.GetValue(dc.State),
KnowledgeBaseId = KnowledgeBaseId.GetValue(dc.State),
QnAServiceType = QnAServiceType.GetValue(dc.State)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

<ItemGroup>
<!-- Force System.Text.Json to a safe version. -->
<PackageReference Include="Azure.Identity" Version="1.13.2" />
<PackageReference Include="System.Text.Json" Version="8.0.5" />
<PackageReference Include="Microsoft.Bot.Configuration" Condition=" '$(ReleasePackageVersion)' == '' " Version="$(LocalPackageVersion)" />
<PackageReference Include="Microsoft.Bot.Configuration" Condition=" '$(ReleasePackageVersion)' != '' " Version="$(ReleasePackageVersion)" />
Expand Down
9 changes: 9 additions & 0 deletions libraries/Microsoft.Bot.Builder.AI.QnA/QnAMakerEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,14 @@ public QnAMakerEndpoint(QnAMakerService service)
/// </value>
[JsonProperty("host")]
public string Host { get; set; }

/// <summary>
/// Gets or sets the ClientId of the Managed Identity resource. Access control (IAM) role `Cognitive Services User` must be assigned in the Language resource to the Managed Identity resource.
/// </summary>
/// <value>
/// The ClientId of the Managed Identity resource.
/// </value>
[JsonProperty("managedIdentityClientId")]
public string ManagedIdentityClientId { get; set; }
}
}
32 changes: 28 additions & 4 deletions libraries/Microsoft.Bot.Builder.AI.QnA/Utils/HttpRequestUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Runtime.Versioning;
using System.Text;
using System.Threading.Tasks;
using Azure.Identity;

namespace Microsoft.Bot.Builder.AI.QnA
{
Expand Down Expand Up @@ -60,7 +61,7 @@ public async Task<HttpResponseMessage> ExecuteHttpRequestAsync(string requestUrl
{
request.Content = new StringContent(payloadBody, Encoding.UTF8, "application/json");

SetHeaders(request, endpoint);
await SetHeadersAsync(request, endpoint);

var response = await _httpClient.SendAsync(request).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
Expand All @@ -69,10 +70,33 @@ public async Task<HttpResponseMessage> ExecuteHttpRequestAsync(string requestUrl
}
}

private static void SetHeaders(HttpRequestMessage request, QnAMakerEndpoint endpoint)
private static async Task SetHeadersAsync(HttpRequestMessage request, QnAMakerEndpoint endpoint)
{
request.Headers.Add("Authorization", $"EndpointKey {endpoint.EndpointKey}");
request.Headers.Add("Ocp-Apim-Subscription-Key", endpoint.EndpointKey);
if (!string.IsNullOrWhiteSpace(endpoint.EndpointKey))
{
request.Headers.Add("Authorization", $"EndpointKey {endpoint.EndpointKey}");
request.Headers.Add("Ocp-Apim-Subscription-Key", endpoint.EndpointKey);
}
else if (!string.IsNullOrWhiteSpace(endpoint.ManagedIdentityClientId))
{
try
{
var client = new ManagedIdentityCredential(endpoint.ManagedIdentityClientId);
var accessToken = await client.GetTokenAsync(new Azure.Core.TokenRequestContext(["https://cognitiveservices.azure.com/.default"]));
request.Headers.Add("Authorization", $"Bearer {accessToken.Token}");
}
catch (Exception ex)
{
throw new InvalidOperationException(
$"Failed to acquire token using Managed Identity Client ID '{endpoint.ManagedIdentityClientId}'. " +
$"Ensure the Managed Identity exists and has the 'Cognitive Services User' role assigned.", ex);
}
}
else
{
throw new ArgumentNullException(nameof(endpoint), "Either EndpointKey or ManagedIdentityClientId must be provided.");
}

request.Headers.UserAgent.Add(botBuilderInfo);
request.Headers.UserAgent.Add(platformInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2220,7 +2220,7 @@ public class LanguageServiceTestDialog : ComponentDialog, IDialogDependencies
public LanguageServiceTestDialog(string knowledgeBaseId, string endpointKey, string hostName, HttpClient httpClient)
: base(nameof(LanguageServiceTestDialog))
{
AddDialog(new QnAMakerDialog(knowledgeBaseId, endpointKey, hostName, httpClient: httpClient));
AddDialog(new QnAMakerDialog(knowledgeBaseId, endpointKey: endpointKey, hostName, httpClient: httpClient));
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion tests/Microsoft.Bot.Builder.AI.QnA.Tests/QnAMakerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1986,7 +1986,7 @@ public class QnaMakerTestDialog : ComponentDialog, IDialogDependencies
public QnaMakerTestDialog(string knowledgeBaseId, string endpointKey, string hostName, HttpClient httpClient)
: base(nameof(QnaMakerTestDialog))
{
AddDialog(new QnAMakerDialog(knowledgeBaseId, endpointKey, hostName, httpClient: httpClient));
AddDialog(new QnAMakerDialog(knowledgeBaseId, endpointKey: endpointKey, hostName, httpClient: httpClient));
}

/// <summary>
Expand Down
Loading