Skip to content

Avoid dictionary lookup for singleton services #52035

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,22 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using Microsoft.Extensions.DependencyInjection.ServiceLookup;

namespace Microsoft.Extensions.DependencyInjection
{
internal class ScopeState
{
public IDictionary<ServiceCacheKey, object> ResolvedServices { get; }
public Dictionary<ServiceCacheKey, object> ResolvedServices { get; }
public List<object> Disposables { get; set; }

public int DisposableServicesCount => Disposables?.Count ?? 0;
public int ResolvedServicesCount => ResolvedServices.Count;

public ScopeState(bool isRoot)
public ScopeState()
{
// When isRoot is true to reduce lock contention for singletons upon resolve we use a concurrent dictionary.
ResolvedServices = isRoot ? new ConcurrentDictionary<ServiceCacheKey, object>() : new Dictionary<ServiceCacheKey, object>();
ResolvedServices = new Dictionary<ServiceCacheKey, object>();
}

public void Track(ServiceProviderEngine engine)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ internal sealed class CallSiteFactory
{
private const int DefaultSlot = 0;
private readonly ServiceDescriptor[] _descriptors;
private readonly ConcurrentDictionary<Type, ServiceCallSite> _callSiteCache = new ConcurrentDictionary<Type, ServiceCallSite>();
private readonly ConcurrentDictionary<ServiceCacheKey, ServiceCallSite> _callSiteCache = new ConcurrentDictionary<ServiceCacheKey, ServiceCallSite>();
private readonly Dictionary<Type, ServiceDescriptorCacheItem> _descriptorLookup = new Dictionary<Type, ServiceDescriptorCacheItem>();

private readonly StackGuard _stackGuard;
Expand Down Expand Up @@ -77,7 +77,7 @@ private void Populate()
}

internal ServiceCallSite GetCallSite(Type serviceType, CallSiteChain callSiteChain) =>
_callSiteCache.TryGetValue(serviceType, out ServiceCallSite site) ? site :
_callSiteCache.TryGetValue(new ServiceCacheKey(serviceType, DefaultSlot), out ServiceCallSite site) ? site :
CreateCallSite(serviceType, callSiteChain);

internal ServiceCallSite GetCallSite(ServiceDescriptor serviceDescriptor, CallSiteChain callSiteChain)
Expand All @@ -104,8 +104,6 @@ private ServiceCallSite CreateCallSite(Type serviceType, CallSiteChain callSiteC
TryCreateOpenGeneric(serviceType, callSiteChain) ??
TryCreateEnumerable(serviceType, callSiteChain);

_callSiteCache[serviceType] = callSite;

return callSite;
}

Expand All @@ -132,6 +130,12 @@ private ServiceCallSite TryCreateOpenGeneric(Type serviceType, CallSiteChain cal

private ServiceCallSite TryCreateEnumerable(Type serviceType, CallSiteChain callSiteChain)
{
ServiceCacheKey callSiteKey = new ServiceCacheKey(serviceType, DefaultSlot);
if (_callSiteCache.TryGetValue(callSiteKey, out ServiceCallSite serviceCallSite))
{
return serviceCallSite;
}

try
{
callSiteChain.Add(serviceType);
Expand Down Expand Up @@ -188,10 +192,10 @@ private ServiceCallSite TryCreateEnumerable(Type serviceType, CallSiteChain call
ResultCache resultCache = ResultCache.None;
if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
{
resultCache = new ResultCache(cacheLocation, new ServiceCacheKey(serviceType, DefaultSlot));
resultCache = new ResultCache(cacheLocation, callSiteKey);
}

return new IEnumerableCallSite(resultCache, itemType, callSites.ToArray());
return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites.ToArray());
}

return null;
Expand All @@ -211,6 +215,12 @@ private ServiceCallSite TryCreateExact(ServiceDescriptor descriptor, Type servic
{
if (serviceType == descriptor.ServiceType)
{
ServiceCacheKey callSiteKey = new ServiceCacheKey(serviceType, slot);
if (_callSiteCache.TryGetValue(callSiteKey, out ServiceCallSite serviceCallSite))
{
return serviceCallSite;
}

ServiceCallSite callSite;
var lifetime = new ResultCache(descriptor.Lifetime, serviceType, slot);
if (descriptor.ImplementationInstance != null)
Expand All @@ -230,7 +240,7 @@ private ServiceCallSite TryCreateExact(ServiceDescriptor descriptor, Type servic
throw new InvalidOperationException(SR.InvalidServiceDescriptor);
}

return callSite;
return _callSiteCache[callSiteKey] = callSite;
}

return null;
Expand All @@ -241,6 +251,12 @@ private ServiceCallSite TryCreateOpenGeneric(ServiceDescriptor descriptor, Type
if (serviceType.IsConstructedGenericType &&
serviceType.GetGenericTypeDefinition() == descriptor.ServiceType)
{
ServiceCacheKey callSiteKey = new ServiceCacheKey(serviceType, slot);
if (_callSiteCache.TryGetValue(callSiteKey, out ServiceCallSite serviceCallSite))
{
return serviceCallSite;
}

Debug.Assert(descriptor.ImplementationType != null, "descriptor.ImplementationType != null");
var lifetime = new ResultCache(descriptor.Lifetime, serviceType, slot);
Type closedType;
Expand All @@ -258,7 +274,7 @@ private ServiceCallSite TryCreateOpenGeneric(ServiceDescriptor descriptor, Type
return null;
}

return CreateConstructorCallSite(lifetime, serviceType, closedType, callSiteChain);
return _callSiteCache[callSiteKey] = CreateConstructorCallSite(lifetime, serviceType, closedType, callSiteChain);
}

return null;
Expand Down Expand Up @@ -406,7 +422,7 @@ private ServiceCallSite[] CreateArgumentCallSites(

public void Add(Type type, ServiceCallSite serviceCallSite)
{
_callSiteCache[type] = serviceCallSite;
_callSiteCache[new ServiceCacheKey(type, DefaultSlot)] = serviceCallSite;
}

private struct ServiceDescriptorCacheItem
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Reflection;
using System.Runtime.ExceptionServices;
using System.Threading;
Expand Down Expand Up @@ -60,21 +59,31 @@ protected override object VisitConstructor(ConstructorCallSite constructorCallSi

protected override object VisitRootCache(ServiceCallSite callSite, RuntimeResolverContext context)
{
var lockType = RuntimeResolverLock.Root;
bool lockTaken = false;

// using more granular locking (per singleton) for the root
Monitor.Enter(callSite, ref lockTaken);
try
if (callSite.Value is object value)
{
return ResolveService(callSite, context, lockType, serviceProviderEngine: context.Scope.Engine.Root);
// Value already calculated, return it directly
return value;
}
finally

var lockType = RuntimeResolverLock.Root;
ServiceProviderEngineScope serviceProviderEngine = context.Scope.Engine.Root;

lock (callSite)
{
if (lockTaken)
// Lock the callsite and check if another thread already cached the value
if (callSite.Value is object resolved)
{
Monitor.Exit(callSite);
return resolved;
}

resolved = VisitCallSiteMain(callSite, new RuntimeResolverContext
{
Scope = serviceProviderEngine,
AcquiredLocks = context.AcquiredLocks | lockType
});
serviceProviderEngine.CaptureDisposable(resolved);
callSite.Value = resolved;
return resolved;
}
}

Expand All @@ -91,7 +100,7 @@ private object VisitCache(ServiceCallSite callSite, RuntimeResolverContext conte
{
bool lockTaken = false;
object sync = serviceProviderEngine.Sync;

Dictionary<ServiceCacheKey, object> resolvedServices = serviceProviderEngine.ResolvedServices;
// Taking locks only once allows us to fork resolution process
// on another thread without causing the deadlock because we
// always know that we are going to wait the other thread to finish before
Expand All @@ -103,7 +112,21 @@ private object VisitCache(ServiceCallSite callSite, RuntimeResolverContext conte

try
{
return ResolveService(callSite, context, lockType, serviceProviderEngine);
// Note: This method has already taken lock by the caller for resolution and access synchronization.
// For scoped: takes a dictionary as both a resolution lock and a dictionary access lock.
if (resolvedServices.TryGetValue(callSite.Cache.Key, out object resolved))
{
return resolved;
}

resolved = VisitCallSiteMain(callSite, new RuntimeResolverContext
{
Scope = serviceProviderEngine,
AcquiredLocks = context.AcquiredLocks | lockType
});
serviceProviderEngine.CaptureDisposable(resolved);
resolvedServices.Add(callSite.Cache.Key, resolved);
return resolved;
}
finally
{
Expand All @@ -114,32 +137,6 @@ private object VisitCache(ServiceCallSite callSite, RuntimeResolverContext conte
}
}

private object ResolveService(ServiceCallSite callSite, RuntimeResolverContext context, RuntimeResolverLock lockType, ServiceProviderEngineScope serviceProviderEngine)
{
IDictionary<ServiceCacheKey, object> resolvedServices = serviceProviderEngine.ResolvedServices;

// Note: This method has already taken lock by the caller for resolution and access synchronization.
// For root: uses a concurrent dictionary and takes a per singleton lock for resolution.
// For scoped: takes a dictionary as both a resolution lock and a dictionary access lock.
Debug.Assert(
(lockType == RuntimeResolverLock.Root && resolvedServices is ConcurrentDictionary<ServiceCacheKey, object>) ||
(lockType == RuntimeResolverLock.Scope && Monitor.IsEntered(serviceProviderEngine.Sync)));

if (resolvedServices.TryGetValue(callSite.Cache.Key, out object resolved))
{
return resolved;
}

resolved = VisitCallSiteMain(callSite, new RuntimeResolverContext
{
Scope = serviceProviderEngine,
AcquiredLocks = context.AcquiredLocks | lockType
});
serviceProviderEngine.CaptureDisposable(resolved);
resolvedServices.Add(callSite.Cache.Key, resolved);
return resolved;
}

protected override object VisitConstant(ConstantCallSite constantCallSite, RuntimeResolverContext context)
{
return constantCallSite.DefaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
internal sealed class ConstantCallSite : ServiceCallSite
{
private readonly Type _serviceType;
internal object DefaultValue { get; }
internal object DefaultValue => Value;

public ConstantCallSite(Type serviceType, object defaultValue): base(ResultCache.None)
{
Expand All @@ -18,7 +18,7 @@ public ConstantCallSite(Type serviceType, object defaultValue): base(ResultCache
throw new ArgumentException(SR.Format(SR.ConstantCantBeConvertedToServiceType, defaultValue.GetType(), serviceType));
}

DefaultValue = defaultValue;
Value = defaultValue;
}

public override Type ServiceType => DefaultValue?.GetType() ?? _serviceType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
{
internal struct ServiceCacheKey: IEquatable<ServiceCacheKey>
internal readonly struct ServiceCacheKey : IEquatable<ServiceCacheKey>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not directly related to your changes, but the HashCode calculation will throw for the static Empty, the Type is null then, and GetHashCode will throw a null ref exception.
I suggest changing it to HashCode.Combine(Type, Slot) which takes care of the null

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do that in a different PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created #52115 to fix that nullref exception.

{
public static ServiceCacheKey Empty { get; } = new ServiceCacheKey(null, 0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ protected ServiceCallSite(ResultCache cache)
public abstract Type ImplementationType { get; }
public abstract CallSiteKind Kind { get; }
public ResultCache Cache { get; }
public object Value { get; set; }

public bool CaptureDisposable =>
ImplementationType == null ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ internal abstract class ServiceProviderEngine : IServiceProviderEngine, IService
protected ServiceProviderEngine(IEnumerable<ServiceDescriptor> serviceDescriptors)
{
_createServiceAccessor = CreateServiceAccessor;
Root = new ServiceProviderEngineScope(this, isRoot: true);
Root = new ServiceProviderEngineScope(this);
RuntimeResolver = new CallSiteRuntimeResolver();
CallSiteFactory = new CallSiteFactory(serviceDescriptors);
CallSiteFactory.Add(typeof(IServiceProvider), new ServiceProviderCallSite());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ internal sealed class ServiceProviderEngineScope : IServiceScope, IServiceProvid
private bool _disposed;
private readonly ScopeState _state;

public ServiceProviderEngineScope(ServiceProviderEngine engine, bool isRoot = false)
public ServiceProviderEngineScope(ServiceProviderEngine engine)
{
Engine = engine;
_state = new ScopeState(isRoot);
_state = new ScopeState();
}

internal IDictionary<ServiceCacheKey, object> ResolvedServices => _state.ResolvedServices;
internal Dictionary<ServiceCacheKey, object> ResolvedServices => _state.ResolvedServices;

// This lock protects state on the scope, in particular, for the root scope, it protects
// the list of disposable entries only, since ResolvedServices is a concurrent dictionary.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ public void BuiltExpressionWillReturnResolvedServiceWhenAppropriate(
var compiledCallSite = CompileCallSite(callSite, provider);
var compiledCollectionCallSite = CompileCallSite(collectionCallSite, provider);

var service1 = Invoke(callSite, provider);
var service2 = compiledCallSite(provider.Root);
var serviceEnumerator = ((IEnumerable)compiledCollectionCallSite(provider.Root)).GetEnumerator();
using var scope = (ServiceProviderEngineScope)provider.CreateScope();

var service1 = Invoke(callSite, scope);
var service2 = compiledCallSite(scope);
var serviceEnumerator = ((IEnumerable)compiledCollectionCallSite(scope)).GetEnumerator();

Assert.NotNull(service1);
Assert.True(compare(service1, service2));
Expand All @@ -114,10 +116,12 @@ public void BuiltExpressionCanResolveNestedScopedService()
var callSite = provider.CallSiteFactory.GetCallSite(typeof(ServiceC), new CallSiteChain());
var compiledCallSite = CompileCallSite(callSite, provider);

var serviceC = (ServiceC)compiledCallSite(provider.Root);
using var scope = (ServiceProviderEngineScope)provider.CreateScope();

var serviceC = (ServiceC)compiledCallSite(scope);

Assert.NotNull(serviceC.ServiceB.ServiceA);
Assert.Equal(serviceC, Invoke(callSite, provider));
Assert.Equal(serviceC, Invoke(callSite, scope));
}

[Theory]
Expand Down Expand Up @@ -371,9 +375,9 @@ public void Dispose()
}
}

private static object Invoke(ServiceCallSite callSite, ServiceProviderEngine provider)
private static object Invoke(ServiceCallSite callSite, ServiceProviderEngineScope scope)
{
return CallSiteRuntimeResolver.Resolve(callSite, provider.Root);
return CallSiteRuntimeResolver.Resolve(callSite, scope);
}

private static Func<ServiceProviderEngineScope, object> CompileCallSite(ServiceCallSite callSite, ServiceProviderEngine engine)
Expand Down