Skip to content

Commit 6100a72

Browse files
stephentoubjeffhandley
authored andcommitted
Move AIFunctionFactory down to M.E.AI.Abstractions (#6412)
* Remove AIFunctionFactory dependency on M.E.DI This means reverting the recent changes to it that: - Special-cased KeyedServices - Special-cased IServiceProviderIsService - Used ActivatorUtilities.CreateInstance * Move AIFunctionFactory down to M.E.AI.Abstractions * Add CreateInstance delegate to AIFunctionFactoryOptions To enable use of ActivatorUtilities.CreateInstance or alternative. * Add some comments
1 parent 53e2b53 commit 6100a72

File tree

3 files changed

+94
-66
lines changed

3 files changed

+94
-66
lines changed

src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactory.cs

Lines changed: 37 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Collections.Generic;
88
using System.ComponentModel;
99
using System.Diagnostics;
10+
using System.Diagnostics.CodeAnalysis;
1011
using System.IO;
1112
#if !NET
1213
using System.Linq;
@@ -25,7 +26,6 @@
2526
#pragma warning disable CA1031 // Do not catch general exception types
2627
#pragma warning disable S2333 // Redundant modifiers should not be used
2728
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
28-
#pragma warning disable SA1202 // Public members should come before private members
2929

3030
namespace Microsoft.Extensions.AI;
3131

@@ -373,15 +373,16 @@ public static AIFunction Create(MethodInfo method, object? target, string? name
373373
}
374374

375375
/// <summary>
376-
/// Creates an <see cref="AIFunction"/> instance for a method, specified via a <see cref="MethodInfo"/> for
377-
/// an instance method and a <see cref="Func{AIFunctionArguments,Object}"/> for constructing an instance of
378-
/// the receiver object each time the <see cref="AIFunction"/> is invoked.
376+
/// Creates an <see cref="AIFunction"/> instance for a method, specified via an <see cref="MethodInfo"/> for
377+
/// and instance method, along with a <see cref="Type"/> representing the type of the target object to
378+
/// instantiate each time the method is invoked.
379379
/// </summary>
380380
/// <param name="method">The instance method to be represented via the created <see cref="AIFunction"/>.</param>
381-
/// <param name="createInstanceFunc">
382-
/// Callback used on each function invocation to create an instance of the type on which the instance method <paramref name="method"/>
383-
/// will be invoked. If the returned instance is <see cref="IAsyncDisposable"/> or <see cref="IDisposable"/>, it will be disposed of
384-
/// after <paramref name="method"/> completes its invocation.
381+
/// <param name="targetType">
382+
/// The <see cref="Type"/> to construct an instance of on which to invoke <paramref name="method"/> when
383+
/// the resulting <see cref="AIFunction"/> is invoked. <see cref="Activator.CreateInstance(Type)"/> is used,
384+
/// utilizing the type's public parameterless constructor. If an instance can't be constructed, an exception is
385+
/// thrown during the function's invocation.
385386
/// </param>
386387
/// <param name="options">Metadata to use to override defaults inferred from <paramref name="method"/>.</param>
387388
/// <returns>The created <see cref="AIFunction"/> for invoking <paramref name="method"/>.</returns>
@@ -455,16 +456,22 @@ public static AIFunction Create(MethodInfo method, object? target, string? name
455456
/// </para>
456457
/// </remarks>
457458
/// <exception cref="ArgumentNullException"><paramref name="method"/> is <see langword="null"/>.</exception>
458-
/// <exception cref="ArgumentNullException"><paramref name="createInstanceFunc"/> is <see langword="null"/>.</exception>
459+
/// <exception cref="ArgumentNullException"><paramref name="targetType"/> is <see langword="null"/>.</exception>
459460
/// <exception cref="ArgumentException"><paramref name="method"/> represents a static method.</exception>
460461
/// <exception cref="ArgumentException"><paramref name="method"/> represents an open generic method.</exception>
461462
/// <exception cref="ArgumentException"><paramref name="method"/> contains a parameter without a parameter name.</exception>
463+
/// <exception cref="ArgumentException"><paramref name="targetType"/> is not assignable to <paramref name="method"/>'s declaring type.</exception>
462464
/// <exception cref="JsonException">A parameter to <paramref name="method"/> or its return type is not serializable.</exception>
463465
public static AIFunction Create(
464466
MethodInfo method,
465-
Func<AIFunctionArguments, object> createInstanceFunc,
466-
AIFunctionFactoryOptions? options = null) =>
467-
ReflectionAIFunction.Build(method, createInstanceFunc, options ?? _defaultOptions);
467+
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType,
468+
AIFunctionFactoryOptions? options = null)
469+
{
470+
_ = Throw.IfNull(method);
471+
_ = Throw.IfNull(targetType);
472+
473+
return ReflectionAIFunction.Build(method, targetType, options ?? _defaultOptions);
474+
}
468475

469476
private sealed class ReflectionAIFunction : AIFunction
470477
{
@@ -495,11 +502,10 @@ public static ReflectionAIFunction Build(MethodInfo method, object? target, AIFu
495502

496503
public static ReflectionAIFunction Build(
497504
MethodInfo method,
498-
Func<AIFunctionArguments, object> createInstanceFunc,
505+
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType,
499506
AIFunctionFactoryOptions options)
500507
{
501508
_ = Throw.IfNull(method);
502-
_ = Throw.IfNull(createInstanceFunc);
503509

504510
if (method.ContainsGenericParameters)
505511
{
@@ -511,7 +517,13 @@ public static ReflectionAIFunction Build(
511517
Throw.ArgumentException(nameof(method), "The method must be an instance method.");
512518
}
513519

514-
return new(ReflectionAIFunctionDescriptor.GetOrCreate(method, options), createInstanceFunc, options);
520+
if (method.DeclaringType is { } declaringType &&
521+
!declaringType.IsAssignableFrom(targetType))
522+
{
523+
Throw.ArgumentException(nameof(targetType), "The target type must be assignable to the method's declaring type.");
524+
}
525+
526+
return new(ReflectionAIFunctionDescriptor.GetOrCreate(method, options), targetType, options);
515527
}
516528

517529
private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor, object? target, AIFunctionFactoryOptions options)
@@ -523,17 +535,20 @@ private ReflectionAIFunction(ReflectionAIFunctionDescriptor functionDescriptor,
523535

524536
private ReflectionAIFunction(
525537
ReflectionAIFunctionDescriptor functionDescriptor,
526-
Func<AIFunctionArguments, object> createInstanceFunc,
538+
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType,
527539
AIFunctionFactoryOptions options)
528540
{
529541
FunctionDescriptor = functionDescriptor;
530-
CreateInstanceFunc = createInstanceFunc;
542+
TargetType = targetType;
543+
CreateInstance = options.CreateInstance;
531544
AdditionalProperties = options.AdditionalProperties ?? EmptyReadOnlyDictionary<string, object?>.Instance;
532545
}
533546

534547
public ReflectionAIFunctionDescriptor FunctionDescriptor { get; }
535548
public object? Target { get; }
536-
public Func<AIFunctionArguments, object>? CreateInstanceFunc { get; }
549+
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)]
550+
public Type? TargetType { get; }
551+
public Func<Type, AIFunctionArguments, object>? CreateInstance { get; }
537552

538553
public override IReadOnlyDictionary<string, object?> AdditionalProperties { get; }
539554
public override string Name => FunctionDescriptor.Name;
@@ -550,12 +565,14 @@ private ReflectionAIFunction(
550565
object? target = Target;
551566
try
552567
{
553-
if (CreateInstanceFunc is { } func)
568+
if (TargetType is { } targetType)
554569
{
555570
Debug.Assert(target is null, "Expected target to be null when we have a non-null target type");
556571
Debug.Assert(!FunctionDescriptor.Method.IsStatic, "Expected an instance method");
557572

558-
target = func(arguments);
573+
target = CreateInstance is not null ?
574+
CreateInstance(targetType, arguments) :
575+
Activator.CreateInstance(targetType);
559576
if (target is null)
560577
{
561578
Throw.InvalidOperationException("Unable to create an instance of the target type.");
@@ -1088,34 +1105,6 @@ public override void Flush()
10881105
{
10891106
}
10901107

1091-
public override Task FlushAsync(CancellationToken cancellationToken) =>
1092-
Task.CompletedTask;
1093-
1094-
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
1095-
WriteAsync(new ReadOnlyMemory<byte>(buffer, offset, count), cancellationToken).AsTask();
1096-
1097-
#if NET
1098-
public override
1099-
#else
1100-
private
1101-
#endif
1102-
ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
1103-
{
1104-
EnsureNotDisposed();
1105-
1106-
if (cancellationToken.IsCancellationRequested)
1107-
{
1108-
return new ValueTask(Task.FromCanceled(cancellationToken));
1109-
}
1110-
1111-
EnsureCapacity(_position + buffer.Length);
1112-
1113-
buffer.Span.CopyTo(_buffer.AsSpan(_position));
1114-
_position += buffer.Length;
1115-
1116-
return default;
1117-
}
1118-
11191108
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
11201109
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
11211110
public override void SetLength(long value) => throw new NotSupportedException();

src/Libraries/Microsoft.Extensions.AI.Abstractions/Functions/AIFunctionFactoryOptions.cs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,24 @@ public AIFunctionFactoryOptions()
106106
/// </remarks>
107107
public Func<object?, Type?, CancellationToken, ValueTask<object?>>? MarshalResult { get; set; }
108108

109+
/// <summary>
110+
/// Gets or sets a delegate used with <see cref="AIFunctionFactory.Create(MethodInfo, Type, AIFunctionFactoryOptions?)"/> to create the receiver instance.
111+
/// </summary>
112+
/// <remarks>
113+
/// <para>
114+
/// <see cref="AIFunctionFactory.Create(MethodInfo, Type, AIFunctionFactoryOptions?)"/> creates <see cref="AIFunction"/> instances that invoke an
115+
/// instance method on the specified <see cref="Type"/>. This delegate is used to create the instance of the type that will be used to invoke the method.
116+
/// By default if <see cref="CreateInstance"/> is <see langword="null"/>, <see cref="Activator.CreateInstance(Type)"/> is used. If
117+
/// <see cref="CreateInstance"/> is non-<see langword="null"/>, the delegate is invoked with the <see cref="Type"/> to be instantiated and the
118+
/// <see cref="AIFunctionArguments"/> provided to the <see cref="AIFunction.InvokeAsync"/> method.
119+
/// </para>
120+
/// <para>
121+
/// Each created instance will be used for a single invocation. If the object is <see cref="IAsyncDisposable"/> or <see cref="IDisposable"/>, it will
122+
/// be disposed of after the invocation completes.
123+
/// </para>
124+
/// </remarks>
125+
public Func<Type, AIFunctionArguments, object>? CreateInstance { get; set; }
126+
109127
/// <summary>Provides configuration options produced by the <see cref="ConfigureParameterBinding"/> delegate.</summary>
110128
public readonly record struct ParameterBindingOptions
111129
{

test/Libraries/Microsoft.Extensions.AI.Tests/Functions/AIFunctionFactoryTest.cs

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ public void InvalidArguments_Throw()
2929
Assert.Throws<ArgumentNullException>("method", () => AIFunctionFactory.Create(method: null!, target: new object()));
3030
Assert.Throws<ArgumentNullException>("method", () => AIFunctionFactory.Create(method: null!, target: new object(), name: "myAiFunk"));
3131
Assert.Throws<ArgumentNullException>("target", () => AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, (object?)null));
32-
Assert.Throws<ArgumentNullException>("createInstanceFunc", () =>
33-
AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, (Func<AIFunctionArguments, object>)null!));
32+
Assert.Throws<ArgumentNullException>("targetType", () => AIFunctionFactory.Create(typeof(AIFunctionFactoryTest).GetMethod(nameof(InvalidArguments_Throw))!, (Type)null!));
3433
Assert.Throws<ArgumentException>("method", () => AIFunctionFactory.Create(typeof(List<>).GetMethod("Add")!, new List<int>()));
3534
}
3635

@@ -313,12 +312,16 @@ public async Task Create_NoInstance_UsesActivatorUtilitiesWhenServicesAvailable(
313312

314313
AIFunction func = AIFunctionFactory.Create(
315314
typeof(MyFunctionTypeWithOneArg).GetMethod(nameof(MyFunctionTypeWithOneArg.InstanceMethod))!,
316-
static arguments =>
315+
typeof(MyFunctionTypeWithOneArg),
316+
new()
317317
{
318-
Assert.NotNull(arguments.Services);
319-
return ActivatorUtilities.CreateInstance(arguments.Services, typeof(MyFunctionTypeWithOneArg));
320-
},
321-
new() { MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result) });
318+
CreateInstance = (type, arguments) =>
319+
{
320+
Assert.NotNull(arguments.Services);
321+
return ActivatorUtilities.CreateInstance(arguments.Services, type);
322+
},
323+
MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
324+
});
322325

323326
Assert.NotNull(func);
324327
var result = (Tuple<MyFunctionTypeWithOneArg, MyArgumentType>?)await func.InvokeAsync(new() { Services = sp });
@@ -327,41 +330,55 @@ public async Task Create_NoInstance_UsesActivatorUtilitiesWhenServicesAvailable(
327330
}
328331

329332
[Fact]
330-
public async Task Create_CreateInstanceReturnsNull_ThrowsDuringInvocation()
333+
public async Task Create_NoInstance_UsesActivatorWhenServicesUnavailable()
331334
{
332335
AIFunction func = AIFunctionFactory.Create(
333-
typeof(MyFunctionTypeWithOneArg).GetMethod(nameof(MyFunctionTypeWithOneArg.InstanceMethod))!,
334-
static _ => null!);
336+
typeof(MyFunctionTypeWithNoArgs).GetMethod(nameof(MyFunctionTypeWithNoArgs.InstanceMethod))!,
337+
typeof(MyFunctionTypeWithNoArgs),
338+
new()
339+
{
340+
MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
341+
});
335342

336343
Assert.NotNull(func);
337-
await Assert.ThrowsAsync<InvalidOperationException>(async () => await func.InvokeAsync());
344+
Assert.Equal("42", await func.InvokeAsync());
338345
}
339346

340347
[Fact]
341-
public async Task Create_WrongConstructedType_ThrowsDuringInvocation()
348+
public async Task Create_NoInstance_ThrowsWhenCantConstructInstance()
342349
{
350+
var sp = new ServiceCollection().BuildServiceProvider();
351+
343352
AIFunction func = AIFunctionFactory.Create(
344353
typeof(MyFunctionTypeWithOneArg).GetMethod(nameof(MyFunctionTypeWithOneArg.InstanceMethod))!,
345-
static _ => new MyFunctionTypeWithNoArgs());
354+
typeof(MyFunctionTypeWithOneArg));
346355

347356
Assert.NotNull(func);
348-
await Assert.ThrowsAsync<TargetException>(async () => await func.InvokeAsync());
357+
await Assert.ThrowsAsync<MissingMethodException>(async () => await func.InvokeAsync(new() { Services = sp }));
349358
}
350359

351360
[Fact]
352361
public void Create_NoInstance_ThrowsForStaticMethod()
353362
{
354363
Assert.Throws<ArgumentException>("method", () => AIFunctionFactory.Create(
355364
typeof(MyFunctionTypeWithNoArgs).GetMethod(nameof(MyFunctionTypeWithNoArgs.StaticMethod))!,
356-
static _ => new MyFunctionTypeWithNoArgs()));
365+
typeof(MyFunctionTypeWithNoArgs)));
366+
}
367+
368+
[Fact]
369+
public void Create_NoInstance_ThrowsForMismatchedMethod()
370+
{
371+
Assert.Throws<ArgumentException>("targetType", () => AIFunctionFactory.Create(
372+
typeof(MyFunctionTypeWithNoArgs).GetMethod(nameof(MyFunctionTypeWithNoArgs.InstanceMethod))!,
373+
typeof(MyFunctionTypeWithOneArg)));
357374
}
358375

359376
[Fact]
360377
public async Task Create_NoInstance_DisposableInstanceCreatedDisposedEachInvocation()
361378
{
362379
AIFunction func = AIFunctionFactory.Create(
363380
typeof(DisposableService).GetMethod(nameof(DisposableService.GetThis))!,
364-
static _ => new DisposableService(),
381+
typeof(DisposableService),
365382
new()
366383
{
367384
MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
@@ -380,7 +397,7 @@ public async Task Create_NoInstance_AsyncDisposableInstanceCreatedDisposedEachIn
380397
{
381398
AIFunction func = AIFunctionFactory.Create(
382399
typeof(AsyncDisposableService).GetMethod(nameof(AsyncDisposableService.GetThis))!,
383-
static _ => new AsyncDisposableService(),
400+
typeof(AsyncDisposableService),
384401
new()
385402
{
386403
MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
@@ -399,7 +416,7 @@ public async Task Create_NoInstance_DisposableAndAsyncDisposableInstanceCreatedD
399416
{
400417
AIFunction func = AIFunctionFactory.Create(
401418
typeof(DisposableAndAsyncDisposableService).GetMethod(nameof(DisposableAndAsyncDisposableService.GetThis))!,
402-
static _ => new DisposableAndAsyncDisposableService(),
419+
typeof(DisposableAndAsyncDisposableService),
403420
new()
404421
{
405422
MarshalResult = (result, type, cancellationToken) => new ValueTask<object?>(result),
@@ -804,7 +821,11 @@ public ValueTask DisposeAsync()
804821

805822
private sealed class MyFunctionTypeWithNoArgs
806823
{
824+
private string _value = "42";
825+
807826
public static void StaticMethod() => throw new NotSupportedException();
827+
828+
public string InstanceMethod() => _value;
808829
}
809830

810831
private sealed class MyFunctionTypeWithOneArg(MyArgumentType arg)

0 commit comments

Comments
 (0)