Skip to content

Commit c06d77a

Browse files
authored
Fix validation on build (#87354)
1 parent 3d6b1ff commit c06d77a

File tree

9 files changed

+86
-32
lines changed

9 files changed

+86
-32
lines changed

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,9 @@ void AddCallSite(ServiceCallSite callSite, int index)
316316
callSitesByIndex.Add(new(index, callSite));
317317
}
318318
}
319-
320-
ResultCache resultCache = ResultCache.None;
321-
if (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
322-
{
323-
resultCache = new ResultCache(cacheLocation, callSiteKey);
324-
}
325-
319+
ResultCache resultCache = (cacheLocation == CallSiteResultCacheLocation.Scope || cacheLocation == CallSiteResultCacheLocation.Root)
320+
? new ResultCache(cacheLocation, callSiteKey)
321+
: new ResultCache(CallSiteResultCacheLocation.None, callSiteKey);
326322
return _callSiteCache[callSiteKey] = new IEnumerableCallSite(resultCache, itemType, callSites);
327323
}
328324
finally

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteValidator.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,23 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
1010
internal sealed class CallSiteValidator: CallSiteVisitor<CallSiteValidator.CallSiteValidatorState, Type?>
1111
{
1212
// Keys are services being resolved via GetService, values - first scoped service in their call site tree
13-
private readonly ConcurrentDictionary<Type, Type> _scopedServices = new ConcurrentDictionary<Type, Type>();
13+
private readonly ConcurrentDictionary<ServiceCacheKey, Type> _scopedServices = new ConcurrentDictionary<ServiceCacheKey, Type>();
1414

1515
public void ValidateCallSite(ServiceCallSite callSite)
1616
{
1717
Type? scoped = VisitCallSite(callSite, default);
1818
if (scoped != null)
1919
{
20-
_scopedServices[callSite.ServiceType] = scoped;
20+
_scopedServices[callSite.Cache.Key] = scoped;
2121
}
2222
}
2323

24-
public void ValidateResolution(Type serviceType, IServiceScope scope, IServiceScope rootScope)
24+
public void ValidateResolution(ServiceCallSite callSite, IServiceScope scope, IServiceScope rootScope)
2525
{
2626
if (ReferenceEquals(scope, rootScope)
27-
&& _scopedServices.TryGetValue(serviceType, out Type? scopedService))
27+
&& _scopedServices.TryGetValue(callSite.Cache.Key, out Type? scopedService))
2828
{
29+
Type serviceType = callSite.ServiceType;
2930
if (serviceType == scopedService)
3031
{
3132
throw new InvalidOperationException(

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ConstantCallSite.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ internal sealed class ConstantCallSite : ServiceCallSite
1010
private readonly Type _serviceType;
1111
internal object? DefaultValue => Value;
1212

13-
public ConstantCallSite(Type serviceType, object? defaultValue): base(ResultCache.None)
13+
public ConstantCallSite(Type serviceType, object? defaultValue) : base(ResultCache.None(serviceType))
1414
{
1515
_serviceType = serviceType ?? throw new ArgumentNullException(nameof(serviceType));
1616
if (defaultValue != null && !serviceType.IsInstanceOfType(defaultValue))

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ResultCache.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
88
{
99
internal struct ResultCache
1010
{
11-
public static ResultCache None { get; } = new ResultCache(CallSiteResultCacheLocation.None, ServiceCacheKey.Empty);
11+
public static ResultCache None(Type serviceType)
12+
{
13+
var cacheKey = new ServiceCacheKey(serviceType, 0);
14+
return new ResultCache(CallSiteResultCacheLocation.None, cacheKey);
15+
}
1216

1317
internal ResultCache(CallSiteResultCacheLocation lifetime, ServiceCacheKey cacheKey)
1418
{

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceCacheKey.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
88
{
99
internal readonly struct ServiceCacheKey : IEquatable<ServiceCacheKey>
1010
{
11-
public static ServiceCacheKey Empty { get; } = new ServiceCacheKey(null, 0);
12-
1311
/// <summary>
1412
/// Type of service being cached
1513
/// </summary>

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderCallSite.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
77
{
88
internal sealed class ServiceProviderCallSite : ServiceCallSite
99
{
10-
public ServiceProviderCallSite() : base(ResultCache.None)
10+
public ServiceProviderCallSite() : base(ResultCache.None(typeof(IServiceProvider)))
1111
{
1212
}
1313

src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ private void OnCreate(ServiceCallSite callSite)
120120
_callSiteValidator?.ValidateCallSite(callSite);
121121
}
122122

123-
private void OnResolve(Type serviceType, IServiceScope scope)
123+
private void OnResolve(ServiceCallSite callSite, IServiceScope scope)
124124
{
125-
_callSiteValidator?.ValidateResolution(serviceType, scope, Root);
125+
_callSiteValidator?.ValidateResolution(callSite, scope, Root);
126126
}
127127

128128
internal object? GetService(Type serviceType, ServiceProviderEngineScope serviceProviderEngineScope)
@@ -133,8 +133,6 @@ private void OnResolve(Type serviceType, IServiceScope scope)
133133
}
134134

135135
Func<ServiceProviderEngineScope, object?> realizedService = _realizedServices.GetOrAdd(serviceType, _createServiceAccessor);
136-
OnResolve(serviceType, serviceProviderEngineScope);
137-
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType);
138136
var result = realizedService.Invoke(serviceProviderEngineScope);
139137
System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceType));
140138
return result;
@@ -173,10 +171,20 @@ private void ValidateService(ServiceDescriptor descriptor)
173171
if (callSite.Cache.Location == CallSiteResultCacheLocation.Root)
174172
{
175173
object? value = CallSiteRuntimeResolver.Instance.Resolve(callSite, Root);
176-
return scope => value;
174+
return scope =>
175+
{
176+
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType);
177+
return value;
178+
};
177179
}
178180

179-
return _engine.RealizeService(callSite);
181+
Func<ServiceProviderEngineScope, object?> realizedService = _engine.RealizeService(callSite);
182+
return scope =>
183+
{
184+
OnResolve(callSite, scope);
185+
DependencyInjectionEventSource.Log.ServiceResolved(this, serviceType);
186+
return realizedService(scope);
187+
};
180188
}
181189

182190
return _ => null;

src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -786,17 +786,10 @@ public void CreateCallSite_EnumberableCachedAtLowestLevel(ServiceDescriptor[] de
786786
var callSite = factory(typeof(IEnumerable<FakeService>));
787787

788788
var expectedLocation = (CallSiteResultCacheLocation)expectedCacheLocation;
789-
Assert.Equal(expectedLocation, callSite.Cache.Location);
790789

791-
if (expectedLocation != CallSiteResultCacheLocation.None)
792-
{
793-
Assert.Equal(0, callSite.Cache.Key.Slot);
794-
Assert.Equal(typeof(IEnumerable<FakeService>), callSite.Cache.Key.Type);
795-
}
796-
else
797-
{
798-
Assert.Equal(ResultCache.None, callSite.Cache);
799-
}
790+
Assert.Equal(expectedLocation, callSite.Cache.Location);
791+
Assert.Equal(0, callSite.Cache.Key.Slot);
792+
Assert.Equal(typeof(IEnumerable<FakeService>), callSite.Cache.Key.Type);
800793
}
801794

802795
[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]

src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
57
using Microsoft.Extensions.DependencyInjection.Specification.Fakes;
68
using Xunit;
79

@@ -97,6 +99,49 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTra
9799
Assert.Equal($"Cannot resolve '{typeof(IFoo)}' from root provider because it requires scoped service '{typeof(IBar)}'.", exception.Message);
98100
}
99101

102+
[Theory]
103+
[InlineData(true)]
104+
[InlineData(false)]
105+
public void GetService_DoesNotThrow_WhenGetServiceForPolymorphicServiceIsCalledOnRoot_AndTheLastOneIsNotScoped(bool validateOnBuild)
106+
{
107+
// Arrange
108+
var serviceCollection = new ServiceCollection();
109+
serviceCollection.AddScoped<IBar, Bar>();
110+
serviceCollection.AddTransient<IBar, Bar3>();
111+
using var serviceProvider = serviceCollection.BuildServiceProvider(new ServiceProviderOptions
112+
{
113+
ValidateScopes = true,
114+
ValidateOnBuild = validateOnBuild
115+
});
116+
117+
// Act
118+
var actual = serviceProvider.GetService<IBar>();
119+
120+
// Assert
121+
Assert.IsType<Bar3>(actual);
122+
}
123+
124+
[Fact]
125+
public void ScopeValidation_ShouldBeAbleToDistingushGenericCollections_WhenGetServiceIsCalledOnRoot()
126+
{
127+
// Arrange
128+
var serviceCollection = new ServiceCollection();
129+
serviceCollection.AddTransient<IBar, Bar>();
130+
serviceCollection.AddScoped<IBar, Bar3>();
131+
132+
serviceCollection.AddTransient<IBaz, Baz>();
133+
serviceCollection.AddTransient<IBaz, Baz2>();
134+
135+
// Act
136+
using var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true);
137+
Assert.Throws<InvalidOperationException>(() => serviceProvider.GetService<IEnumerable<IBar>>());
138+
var actual = serviceProvider.GetService<IEnumerable<IBaz>>();
139+
140+
// Assert
141+
Assert.IsType<Baz>(actual.First());
142+
Assert.IsType<Baz2>(actual.Last());
143+
}
144+
100145
[Fact]
101146
public void GetService_DoesNotThrow_WhenScopeFactoryIsInjectedIntoSingleton()
102147
{
@@ -206,13 +251,18 @@ private class Bar : IBar
206251
{
207252
}
208253

254+
209255
private class Bar2 : IBar
210256
{
211257
public Bar2(IBaz baz)
212258
{
213259
}
214260
}
215261

262+
private class Bar3 : IBar
263+
{
264+
}
265+
216266
private interface IBaz
217267
{
218268
}
@@ -221,6 +271,10 @@ private class Baz : IBaz
221271
{
222272
}
223273

274+
private class Baz2 : IBaz
275+
{
276+
}
277+
224278
private class BazRecursive : IBaz
225279
{
226280
public BazRecursive(IBaz baz)

0 commit comments

Comments
 (0)