Skip to content

Commit 19b1278

Browse files
authored
Move AIFunction members down to AITool and add CodeInterpreterTool (#5898)
Various AI services (Gemini, Bedrock, OpenAI, etc.) now support server-side code interpreting, where the model can generate code it can then execute in a sandbox. For good reason, they all require opting in. We can model that well as an AITool-derived type.
1 parent a03ee5b commit 19b1278

File tree

9 files changed

+140
-44
lines changed

9 files changed

+140
-44
lines changed
Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,55 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Collections.Generic;
5+
using System.Diagnostics;
6+
using System.Text;
7+
using Microsoft.Shared.Collections;
8+
49
namespace Microsoft.Extensions.AI;
510

11+
#pragma warning disable S1694 // An abstract class should have both abstract and concrete methods
12+
613
/// <summary>Represents a tool that can be specified to an AI service.</summary>
7-
public class AITool
14+
[DebuggerDisplay("{DebuggerDisplay,nq}")]
15+
public abstract class AITool
816
{
917
/// <summary>Initializes a new instance of the <see cref="AITool"/> class.</summary>
1018
protected AITool()
1119
{
1220
}
21+
22+
/// <summary>Gets the name of the tool.</summary>
23+
public virtual string Name => GetType().Name;
24+
25+
/// <summary>Gets a description of the tool, suitable for use in describing the purpose to a model.</summary>
26+
public virtual string Description => string.Empty;
27+
28+
/// <summary>Gets any additional properties associated with the tool.</summary>
29+
public virtual IReadOnlyDictionary<string, object?> AdditionalProperties => EmptyReadOnlyDictionary<string, object?>.Instance;
30+
31+
/// <inheritdoc/>
32+
public override string ToString() => Name;
33+
34+
/// <summary>Gets the string to display in the debugger for this instance.</summary>
35+
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
36+
private string DebuggerDisplay
37+
{
38+
get
39+
{
40+
StringBuilder sb = new(Name);
41+
42+
if (Description is string description && !string.IsNullOrEmpty(description))
43+
{
44+
_ = sb.Append(" (").Append(description).Append(')');
45+
}
46+
47+
foreach (var entry in AdditionalProperties)
48+
{
49+
_ = sb.Append(", ").Append(entry.Key).Append(" = ").Append(entry.Value);
50+
}
51+
52+
return sb.ToString();
53+
}
54+
}
1355
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
namespace Microsoft.Extensions.AI;
5+
6+
/// <summary>Represents a tool that can be specified to an AI service to enable it to execute code it generates.</summary>
7+
/// <remarks>
8+
/// This tool does not itself implement code interpration. It is a marker that can be used to inform a service
9+
/// that the service is allowed to execute its generated code if the service is capable of doing so.
10+
/// </remarks>
11+
public class CodeInterpreterTool : AITool
12+
{
13+
/// <summary>Initializes a new instance of the <see cref="CodeInterpreterTool"/> class.</summary>
14+
public CodeInterpreterTool()
15+
{
16+
}
17+
}

src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunction.cs

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Collections.Generic;
5-
using System.Diagnostics;
65
using System.Reflection;
76
using System.Text.Json;
87
using System.Threading;
@@ -12,15 +11,8 @@
1211
namespace Microsoft.Extensions.AI;
1312

1413
/// <summary>Represents a function that can be described to an AI service and invoked.</summary>
15-
[DebuggerDisplay("{DebuggerDisplay,nq}")]
1614
public abstract class AIFunction : AITool
1715
{
18-
/// <summary>Gets the name of the function.</summary>
19-
public abstract string Name { get; }
20-
21-
/// <summary>Gets a description of the function, suitable for use in describing the purpose to a model.</summary>
22-
public abstract string Description { get; }
23-
2416
/// <summary>Gets a JSON Schema describing the function and its input parameters.</summary>
2517
/// <remarks>
2618
/// <para>
@@ -56,11 +48,8 @@ public abstract class AIFunction : AITool
5648
/// </remarks>
5749
public virtual MethodInfo? UnderlyingMethod => null;
5850

59-
/// <summary>Gets any additional properties associated with the function.</summary>
60-
public virtual IReadOnlyDictionary<string, object?> AdditionalProperties => EmptyReadOnlyDictionary<string, object?>.Instance;
61-
6251
/// <summary>Gets a <see cref="JsonSerializerOptions"/> that can be used to marshal function parameters.</summary>
63-
public virtual JsonSerializerOptions? JsonSerializerOptions => AIJsonUtilities.DefaultOptions;
52+
public virtual JsonSerializerOptions JsonSerializerOptions => AIJsonUtilities.DefaultOptions;
6453

6554
/// <summary>Invokes the <see cref="AIFunction"/> and returns its result.</summary>
6655
/// <param name="arguments">The arguments to pass to the function's invocation.</param>
@@ -75,18 +64,11 @@ public abstract class AIFunction : AITool
7564
return InvokeCoreAsync(arguments, cancellationToken);
7665
}
7766

78-
/// <inheritdoc/>
79-
public override string ToString() => Name;
80-
8167
/// <summary>Invokes the <see cref="AIFunction"/> and returns its result.</summary>
8268
/// <param name="arguments">The arguments to pass to the function's invocation.</param>
8369
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests.</param>
8470
/// <returns>The result of the function's execution.</returns>
8571
protected abstract Task<object?> InvokeCoreAsync(
8672
IEnumerable<KeyValuePair<string, object?>> arguments,
8773
CancellationToken cancellationToken);
88-
89-
/// <summary>Gets the string to display in the debugger for this instance.</summary>
90-
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
91-
private string DebuggerDisplay => string.IsNullOrWhiteSpace(Description) ? Name : $"{Name} ({Description})";
9274
}

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIAssistantClient.cs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,25 @@ private static (RunCreationOptions RunOptions, List<FunctionResultContent>? Tool
212212
{
213213
foreach (AITool tool in tools)
214214
{
215-
if (tool is AIFunction aiFunction)
215+
switch (tool)
216216
{
217-
bool? strict =
218-
aiFunction.AdditionalProperties.TryGetValue("Strict", out object? strictObj) &&
219-
strictObj is bool strictValue ?
220-
strictValue : null;
221-
222-
var functionParameters = BinaryData.FromBytes(
223-
JsonSerializer.SerializeToUtf8Bytes(
224-
JsonSerializer.Deserialize(aiFunction.JsonSchema, OpenAIJsonContext.Default.OpenAIChatToolJson)!,
225-
OpenAIJsonContext.Default.OpenAIChatToolJson));
226-
227-
runOptions.ToolsOverride.Add(ToolDefinition.CreateFunction(aiFunction.Name, aiFunction.Description, functionParameters, strict));
217+
case AIFunction aiFunction:
218+
bool? strict =
219+
aiFunction.AdditionalProperties.TryGetValue("Strict", out object? strictObj) &&
220+
strictObj is bool strictValue ?
221+
strictValue : null;
222+
223+
var functionParameters = BinaryData.FromBytes(
224+
JsonSerializer.SerializeToUtf8Bytes(
225+
JsonSerializer.Deserialize(aiFunction.JsonSchema, OpenAIJsonContext.Default.OpenAIChatToolJson)!,
226+
OpenAIJsonContext.Default.OpenAIChatToolJson));
227+
228+
runOptions.ToolsOverride.Add(ToolDefinition.CreateFunction(aiFunction.Name, aiFunction.Description, functionParameters, strict));
229+
break;
230+
231+
case CodeInterpreterTool:
232+
runOptions.ToolsOverride.Add(ToolDefinition.CreateCodeInterpreter());
233+
break;
228234
}
229235
}
230236
}

src/Libraries/Microsoft.Extensions.AI.OpenAI/OpenAIModelMapper.ChatCompletion.cs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,10 @@ public static ChatOptions FromOpenAIOptions(ChatCompletionOptions? options)
240240
{
241241
foreach (ChatTool tool in tools)
242242
{
243-
result.Tools ??= [];
244-
result.Tools.Add(FromOpenAIChatTool(tool));
243+
if (FromOpenAIChatTool(tool) is { } convertedTool)
244+
{
245+
(result.Tools ??= []).Add(convertedTool);
246+
}
245247
}
246248

247249
using var toolChoiceJson = JsonDocument.Parse(JsonModelHelpers.Serialize(options.ToolChoice).ToMemory());
@@ -407,17 +409,24 @@ public static ChatCompletionOptions ToOpenAIOptions(ChatOptions? options)
407409
return result;
408410
}
409411

410-
private static AITool FromOpenAIChatTool(ChatTool chatTool)
412+
private static AITool? FromOpenAIChatTool(ChatTool chatTool)
411413
{
412-
AdditionalPropertiesDictionary additionalProperties = [];
413-
if (chatTool.FunctionSchemaIsStrict is bool strictValue)
414+
switch (chatTool.Kind)
414415
{
415-
additionalProperties["Strict"] = strictValue;
416-
}
416+
case ChatToolKind.Function:
417+
AdditionalPropertiesDictionary additionalProperties = [];
418+
if (chatTool.FunctionSchemaIsStrict is bool strictValue)
419+
{
420+
additionalProperties["Strict"] = strictValue;
421+
}
422+
423+
OpenAIChatToolJson openAiChatTool = JsonSerializer.Deserialize(chatTool.FunctionParameters.ToMemory().Span, OpenAIJsonContext.Default.OpenAIChatToolJson)!;
424+
JsonElement schema = JsonSerializer.SerializeToElement(openAiChatTool, OpenAIJsonContext.Default.OpenAIChatToolJson);
425+
return new MetadataOnlyAIFunction(chatTool.FunctionName, chatTool.FunctionDescription, schema, additionalProperties);
417426

418-
OpenAIChatToolJson openAiChatTool = JsonSerializer.Deserialize(chatTool.FunctionParameters.ToMemory().Span, OpenAIJsonContext.Default.OpenAIChatToolJson)!;
419-
JsonElement schema = JsonSerializer.SerializeToElement(openAiChatTool, OpenAIJsonContext.Default.OpenAIChatToolJson);
420-
return new MetadataOnlyAIFunction(chatTool.FunctionName, chatTool.FunctionDescription, schema, additionalProperties);
427+
default:
428+
return null;
429+
}
421430
}
422431

423432
private sealed class MetadataOnlyAIFunction(string name, string description, JsonElement schema, IReadOnlyDictionary<string, object?> additionalProps) : AIFunction

src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ static bool IsAsyncMethod(MethodInfo method)
340340
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
341341

342342
// Create a marshaller that simply looks up the parameter by name in the arguments dictionary.
343-
return (IReadOnlyDictionary<string, object?> arguments, AIFunctionContext? _) =>
343+
return (arguments, _) =>
344344
{
345345
// If the parameter has an argument specified in the dictionary, return that argument.
346346
if (arguments.TryGetValue(parameter.Name, out object? value))

src/Libraries/Microsoft.Extensions.AI/Functions/AIFunctionFactoryOptions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public AIFunctionFactoryOptions()
4949
public string? Description { get; set; }
5050

5151
/// <summary>
52-
/// Gets or sets additional values to store on the resulting <see cref="AIFunction.AdditionalProperties" /> property.
52+
/// Gets or sets additional values to store on the resulting <see cref="AITool.AdditionalProperties" /> property.
5353
/// </summary>
5454
/// <remarks>
5555
/// This property can be used to provide arbitrary information about the function.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using Xunit;
5+
6+
namespace Microsoft.Extensions.AI;
7+
8+
public class AIToolTests
9+
{
10+
[Fact]
11+
public void Constructor_Roundtrips()
12+
{
13+
DerivedAITool tool = new();
14+
Assert.Equal(nameof(DerivedAITool), tool.Name);
15+
Assert.Equal(nameof(DerivedAITool), tool.ToString());
16+
Assert.Empty(tool.Description);
17+
Assert.Empty(tool.AdditionalProperties);
18+
}
19+
20+
private sealed class DerivedAITool : AITool;
21+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using Xunit;
5+
6+
namespace Microsoft.Extensions.AI;
7+
8+
public class CodeInterpreterToolTests
9+
{
10+
[Fact]
11+
public void Constructor_Roundtrips()
12+
{
13+
var tool = new CodeInterpreterTool();
14+
Assert.Equal(nameof(CodeInterpreterTool), tool.Name);
15+
Assert.Empty(tool.Description);
16+
Assert.Empty(tool.AdditionalProperties);
17+
Assert.Equal(nameof(CodeInterpreterTool), tool.ToString());
18+
}
19+
}

0 commit comments

Comments
 (0)