Skip to content

Fix issues in GetKeyedService() and GetKeyedServices() with AnyKey #113137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 14, 2025
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,84 @@ public abstract partial class KeyedDependencyInjectionSpecificationTests
{
protected abstract IServiceProvider CreateServiceProvider(IServiceCollection collection);

[Fact]
public void CombinationalRegistration()
{
Service service1 = new();
Service service2 = new();
Service keyedService1 = new();
Service keyedService2 = new();
Service anykeyService1 = new();
Service anykeyService2 = new();
Service nullkeyService1 = new();
Service nullkeyService2 = new();

ServiceCollection serviceCollection = new();
serviceCollection.AddSingleton<IService>(service1);
serviceCollection.AddSingleton<IService>(service2);
serviceCollection.AddKeyedSingleton<IService>(null, nullkeyService1);
serviceCollection.AddKeyedSingleton<IService>(null, nullkeyService2);
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, anykeyService1);
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, anykeyService2);
serviceCollection.AddKeyedSingleton<IService>("keyedService", keyedService1);
serviceCollection.AddKeyedSingleton<IService>("keyedService", keyedService2);

IServiceProvider provider = CreateServiceProvider(serviceCollection);

/*
* Table for what results are included:
*
* Query | Keyed? | Unkeyed? | AnyKey? | null key?
* -------------------------------------------------------------------
* GetServices(Type) | no | yes | no | yes
* GetService(Type) | no | yes | no | yes
*
* GetKeyedServices(null) | no | yes | no | yes
* GetKeyedService(null) | no | yes | no | yes
*
* GetKeyedServices(AnyKey) | yes | no | no | no
* GetKeyedService(AnyKey) | throw | throw | throw | throw
*
* GetKeyedServices(key) | yes | no | no | no
* GetKeyedService(key) | yes | no | yes | no
*
* Summary:
* - A null key is the same as unkeyed. This allows the KeyServices APIs to support both keyed and unkeyed.
* - AnyKey is a special case of Keyed.
* - AnyKey registrations are not returned with GetKeyedServices(AnyKey) and GetKeyedService(AnyKey) always throws.
* - For IEnumerable, the ordering of the results are in registration order.
* - For a singleton resolve, the last match wins.
*/

// Unkeyed (which is really keyed by Type).
Assert.Equal(
new[] { service1, service2, nullkeyService1, nullkeyService2 },
provider.GetServices<IService>());

Assert.Equal(nullkeyService2, provider.GetService<IService>());

// Null key.
Assert.Equal(
new[] { service1, service2, nullkeyService1, nullkeyService2 },
provider.GetKeyedServices<IService>(null));

Assert.Equal(nullkeyService2, provider.GetKeyedService<IService>(null));

// AnyKey.
Assert.Equal(
new[] { keyedService1, keyedService2 },
provider.GetKeyedServices<IService>(KeyedService.AnyKey));

Assert.Throws<InvalidOperationException>(() => provider.GetKeyedService<IService>(KeyedService.AnyKey));

// Keyed.
Assert.Equal(
new[] { keyedService1, keyedService2 },
provider.GetKeyedServices<IService>("keyedService"));

Assert.Equal(keyedService2, provider.GetKeyedService<IService>("keyedService"));
}

[Fact]
public void ResolveKeyedService()
{
Expand Down Expand Up @@ -158,10 +236,75 @@ public void ResolveKeyedServicesAnyKeyWithAnyKeyRegistration()
_ = provider.GetKeyedService<IService>("something-else");
_ = provider.GetKeyedService<IService>("something-else-again");

// Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey
// Return all services registered with a non null key, but not the one "created" with KeyedService.AnyKey,
// nor the KeyedService.AnyKey registration
var allServices = provider.GetKeyedServices<IService>(KeyedService.AnyKey).ToList();
Assert.Equal(5, allServices.Count);
Assert.Equal(new[] { service1, service2, service3, service4 }, allServices.Skip(1));
Assert.Equal(4, allServices.Count);
Assert.Equal(new[] { service1, service2, service3, service4 }, allServices);

var someKeyedServices = provider.GetKeyedServices<IService>("service").ToList();
Assert.Equal(new[] { service2, service3, service4 }, someKeyedServices);

var unkeyedServices = provider.GetServices<IService>().ToList();
Assert.Equal(new[] { service5, service6 }, unkeyedServices);
}

[Fact]
public void ResolveKeyedServicesAnyKeyConsistency()
{
var serviceCollection = new ServiceCollection();
var service = new Service("first-service");
serviceCollection.AddKeyedSingleton<IService>("first-service", service);

var provider1 = CreateServiceProvider(serviceCollection);
Assert.Throws<InvalidOperationException>(() => provider1.GetKeyedService<IService>(KeyedService.AnyKey));
// We don't return KeyedService.AnyKey registration when listing services
Assert.Equal(new[] { service }, provider1.GetKeyedServices<IService>(KeyedService.AnyKey));

var provider2 = CreateServiceProvider(serviceCollection);
Assert.Equal(new[] { service }, provider2.GetKeyedServices<IService>(KeyedService.AnyKey));
Assert.Throws<InvalidOperationException>(() => provider2.GetKeyedService<IService>(KeyedService.AnyKey));
}

[Fact]
public void ResolveKeyedServicesAnyKeyConsistencyWithAnyKeyRegistration()
{
var serviceCollection = new ServiceCollection();
var service = new Service("first-service");
var any = new Service("any");
serviceCollection.AddKeyedSingleton<IService>("first-service", service);
serviceCollection.AddKeyedSingleton<IService>(KeyedService.AnyKey, (sp, key) => any);

var provider1 = CreateServiceProvider(serviceCollection);
Assert.Equal(new[] { service }, provider1.GetKeyedServices<IService>(KeyedService.AnyKey));

// Check twice in different order to check caching
var provider2 = CreateServiceProvider(serviceCollection);
Assert.Equal(new[] { service }, provider2.GetKeyedServices<IService>(KeyedService.AnyKey));
Assert.Same(any, provider2.GetKeyedService<IService>(new object()));

Assert.Throws<InvalidOperationException>(() => provider2.GetKeyedService<IService>(KeyedService.AnyKey));
}

[Fact]
public void ResolveKeyedServicesAnyKeyOrdering()
{
var serviceCollection = new ServiceCollection();
var service1 = new Service();
var service2 = new Service();
var service3 = new Service();

serviceCollection.AddKeyedSingleton<IService>("A-service", service1);
serviceCollection.AddKeyedSingleton<IService>("B-service", service2);
serviceCollection.AddKeyedSingleton<IService>("A-service", service3);

var provider = CreateServiceProvider(serviceCollection);

// The order should be in registration order, and not grouped by key for example.
// Although this isn't necessarily a requirement, it is the current behavior.
Assert.Equal(
new[] { service1, service2, service3 },
provider.GetKeyedServices<IService>(KeyedService.AnyKey));
}

[Fact]
Expand Down Expand Up @@ -250,7 +393,7 @@ public void ResolveKeyedServicesSingletonInstanceWithAnyKey()
var provider = CreateServiceProvider(serviceCollection);

var services = provider.GetKeyedServices<IFakeOpenGenericService<PocoClass>>("some-key").ToList();
Assert.Equal(new[] { service1, service2 }, services);
Assert.Equal(new[] { service2 }, services);
}

[Fact]
Expand Down Expand Up @@ -504,6 +647,9 @@ public void ResolveKeyedSingletonFromScopeServiceProvider()
Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

Assert.Throws<InvalidOperationException>(() => scopeA.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
Assert.Throws<InvalidOperationException>(() => scopeB.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

Expand All @@ -528,6 +674,9 @@ public void ResolveKeyedScopedFromScopeServiceProvider()
Assert.Null(scopeA.ServiceProvider.GetService<IService>());
Assert.Null(scopeB.ServiceProvider.GetService<IService>());

Assert.Throws<InvalidOperationException>(() => scopeA.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));
Assert.Throws<InvalidOperationException>(() => scopeB.ServiceProvider.GetKeyedService<IService>(KeyedService.AnyKey));

var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,7 @@
<data name="InvalidServiceKeyType" xml:space="preserve">
<value>The type of the key used for lookup doesn't match the type in the constructor parameter with the ServiceKey attribute.</value>
</data>
<data name="KeyedServiceAnyKeyUsedToResolveService" xml:space="preserve">
<value>KeyedService.AnyKey cannot be used to resolve a single service.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,13 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
CallSiteResultCacheLocation cacheLocation = CallSiteResultCacheLocation.Root;
ServiceCallSite[] callSites;

var isAnyKeyLookup = serviceIdentifier.ServiceKey == KeyedService.AnyKey;

// If item type is not generic we can safely use descriptor cache
// Special case for KeyedService.AnyKey, we don't want to check the cache because a KeyedService.AnyKey registration
// will "hide" all the other service registration
if (!itemType.IsConstructedGenericType &&
!KeyedService.AnyKey.Equals(cacheKey.ServiceKey) &&
!isAnyKeyLookup &&
_descriptorLookup.TryGetValue(cacheKey, out ServiceDescriptorCacheItem descriptors))
{
callSites = new ServiceCallSite[descriptors.Count];
Expand Down Expand Up @@ -317,19 +319,25 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
int slot = 0;
for (int i = _descriptors.Length - 1; i >= 0; i--)
{
if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey))
if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey))
{
if (TryCreateExact(_descriptors[i], cacheKey, callSiteChain, slot) is { } callSite)
// Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type,
// so we need to ask creation with the original identity of the descriptor
var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey;
if (TryCreateExact(_descriptors[i], registrationKey, callSiteChain, slot) is { } callSite)
{
AddCallSite(callSite, i);
}
}
}
for (int i = _descriptors.Length - 1; i >= 0; i--)
{
if (KeysMatch(_descriptors[i].ServiceKey, cacheKey.ServiceKey))
if (KeysMatch(cacheKey.ServiceKey, _descriptors[i].ServiceKey))
{
if (TryCreateOpenGeneric(_descriptors[i], cacheKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite)
// Special case for AnyKey: we don't want to add in cache a mapping AnyKey -> specific type,
// so we need to ask creation with the original identity of the descriptor
var registrationKey = isAnyKeyLookup ? ServiceIdentifier.FromDescriptor(_descriptors[i]) : cacheKey;
if (TryCreateOpenGeneric(_descriptors[i], registrationKey, callSiteChain, slot, throwOnConstraintViolation: false) is { } callSite)
{
AddCallSite(callSite, i);
}
Expand Down Expand Up @@ -360,6 +368,32 @@ void AddCallSite(ServiceCallSite callSite, int index)
{
callSiteChain.Remove(serviceIdentifier);
}

static bool KeysMatch(object? lookupKey, object? descriptorKey)
{
if (lookupKey == null && descriptorKey == null)
{
// Both are non keyed services
return true;
}

if (lookupKey != null && descriptorKey != null)
{
// Both are keyed services

// We don't want to return AnyKey registration, so ignore it
if (descriptorKey.Equals(KeyedService.AnyKey))
return false;

// Check if both keys are equal, or if the lookup key
// should matches all keys (except AnyKey)
return lookupKey.Equals(descriptorKey)
|| lookupKey.Equals(KeyedService.AnyKey);
}

// One is a keyed service, one is not
return false;
}
}

private static CallSiteResultCacheLocation GetCommonCacheLocation(CallSiteResultCacheLocation locationA, CallSiteResultCacheLocation locationB)
Expand Down Expand Up @@ -693,24 +727,6 @@ internal bool IsService(ServiceIdentifier serviceIdentifier)
serviceType == typeof(IServiceProviderIsKeyedService);
}

/// <summary>
/// Returns true if both keys are null or equals, or if key1 is KeyedService.AnyKey and key2 is not null
/// </summary>
private static bool KeysMatch(object? key1, object? key2)
{
if (key1 == null && key2 == null)
return true;

if (key1 != null && key2 != null)
{
return key1.Equals(key2)
|| key1.Equals(KeyedService.AnyKey)
|| key2.Equals(KeyedService.AnyKey);
}

return false;
}

private struct ServiceDescriptorCacheItem
{
[DisallowNull]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,12 @@ internal static void ThrowObjectDisposedException()
{
throw new ObjectDisposedException(nameof(IServiceProvider));
}

[DoesNotReturn]
[MethodImpl(MethodImplOptions.NoInlining)]
internal static void ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService()
{
throw new InvalidOperationException(SR.Format(SR.KeyedServiceAnyKeyUsedToResolveService, nameof(IServiceProvider), nameof(IServiceScopeFactory)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,46 @@ internal ServiceProvider(ICollection<ServiceDescriptor> serviceDescriptors, Serv
/// <param name="serviceType">The type of the service to get.</param>
/// <param name="serviceKey">The key of the service to get.</param>
/// <returns>The keyed service.</returns>
/// <exception cref="InvalidOperationException">The <see cref="KeyedService.AnyKey"/> value is used for <paramref name="serviceKey"/>
/// when <paramref name="serviceType"/> is not an enumerable based on <see cref="IEnumerable{T}"/>.
/// </exception>
public object? GetKeyedService(Type serviceType, object? serviceKey)
=> GetKeyedService(serviceType, serviceKey, Root);

internal object? GetKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
=> GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);
{
if (serviceKey == KeyedService.AnyKey)
{
if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
{
ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService();
}
}

return GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);
}

/// <summary>
/// Gets the service object of the specified type.
/// </summary>
/// <param name="serviceType">The type of the service to get.</param>
/// <param name="serviceKey">The key of the service to get.</param>
/// <returns>The keyed service.</returns>
/// <exception cref="InvalidOperationException">The service wasn't found.</exception>
/// <exception cref="InvalidOperationException">The service wasn't found or the <see cref="KeyedService.AnyKey"/> value is used
/// for <paramref name="serviceKey"/> when <paramref name="serviceType"/> is not an enumerable based on <see cref="IEnumerable{T}"/>.
/// </exception>
public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
=> GetRequiredKeyedService(serviceType, serviceKey, Root);
{
if (serviceKey == KeyedService.AnyKey)
{
if (!serviceType.IsGenericType || serviceType.GetGenericTypeDefinition() != typeof(IEnumerable<>))
{
ThrowHelper.ThrowInvalidOperationException_KeyedServiceAnyKeyUsedToResolveService();
}
}

return GetRequiredKeyedService(serviceType, serviceKey, Root);
}

internal object GetRequiredKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
{
Expand Down
Loading