diff --git a/src/EFCore.Relational/Design/EntityFrameworkRelationalDesignServicesBuilder.cs b/src/EFCore.Relational/Design/EntityFrameworkRelationalDesignServicesBuilder.cs index 6f8ac836426..5d4e8fefd44 100644 --- a/src/EFCore.Relational/Design/EntityFrameworkRelationalDesignServicesBuilder.cs +++ b/src/EFCore.Relational/Design/EntityFrameworkRelationalDesignServicesBuilder.cs @@ -62,12 +62,11 @@ public EntityFrameworkRelationalDesignServicesBuilder(IServiceCollection service /// Gets the for the given service type. /// /// The type that defines the service API. - /// The for the type. - /// when the type is not an EF service. - protected override ServiceCharacteristics GetServiceCharacteristics(Type serviceType) + /// The for the type or if it's not an EF service. + protected override ServiceCharacteristics? TryGetServiceCharacteristics(Type serviceType) => RelationalServices.TryGetValue(serviceType, out var characteristics) ? characteristics - : base.GetServiceCharacteristics(serviceType); + : base.TryGetServiceCharacteristics(serviceType); /// /// Registers default implementations of all services, including relational services, not already diff --git a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs index 237602fc1c5..1910bf13bce 100644 --- a/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs +++ b/src/EFCore.Relational/Infrastructure/EntityFrameworkRelationalServicesBuilder.cs @@ -117,12 +117,11 @@ public EntityFrameworkRelationalServicesBuilder(IServiceCollection serviceCollec /// Gets the for the given service type. /// /// The type that defines the service API. - /// The for the type. - /// when the type is not an EF service. - protected override ServiceCharacteristics GetServiceCharacteristics(Type serviceType) + /// The for the type or if it's not an EF service. + protected override ServiceCharacteristics? TryGetServiceCharacteristics(Type serviceType) => RelationalServices.TryGetValue(serviceType, out var characteristics) ? characteristics - : base.GetServiceCharacteristics(serviceType); + : base.TryGetServiceCharacteristics(serviceType); /// /// Registers default implementations of all services, including relational services, not already diff --git a/src/EFCore/Design/EntityFrameworkDesignServicesBuilder.cs b/src/EFCore/Design/EntityFrameworkDesignServicesBuilder.cs index eeb00a5344e..2aedd1ed829 100644 --- a/src/EFCore/Design/EntityFrameworkDesignServicesBuilder.cs +++ b/src/EFCore/Design/EntityFrameworkDesignServicesBuilder.cs @@ -65,12 +65,11 @@ public EntityFrameworkDesignServicesBuilder(IServiceCollection serviceCollection /// Gets the for the given service type. /// /// The type that defines the service API. - /// The for the type. - /// when the type is not an EF service. - protected override ServiceCharacteristics GetServiceCharacteristics(Type serviceType) + /// The for the type or if it's not an EF service. + protected override ServiceCharacteristics? TryGetServiceCharacteristics(Type serviceType) => Services.TryGetValue(serviceType, out var characteristics) ? characteristics - : base.GetServiceCharacteristics(serviceType); + : base.TryGetServiceCharacteristics(serviceType); /// /// Registers default implementations of all services, including relational services, not already diff --git a/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs b/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs index 7b7104cc5ec..0e60016bd2d 100644 --- a/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs +++ b/src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs @@ -176,14 +176,22 @@ public EntityFrameworkServicesBuilder(IServiceCollection serviceCollection) /// when the type is not an EF service. protected virtual ServiceCharacteristics GetServiceCharacteristics(Type serviceType) { - if (!CoreServices.TryGetValue(serviceType, out var characteristics)) - { - throw new InvalidOperationException(CoreStrings.NotAnEFService(serviceType.Name)); - } - - return characteristics; + var characteristics = TryGetServiceCharacteristics(serviceType); + return characteristics == null + ? throw new InvalidOperationException(CoreStrings.NotAnEFService(serviceType.Name)) + : characteristics.Value; } + /// + /// Gets the for the given service type. + /// + /// The type that defines the service API. + /// The for the type or if it's not an EF service. + protected virtual ServiceCharacteristics? TryGetServiceCharacteristics(Type serviceType) + => !CoreServices.TryGetValue(serviceType, out var characteristics) + ? null + : characteristics; + /// /// Database providers should call this method for access to the underlying /// such that provider-specific services can be registered. @@ -196,8 +204,18 @@ public virtual EntityFrameworkServicesBuilder TryAddProviderSpecificServices(Act { Check.NotNull(serviceMap, nameof(serviceMap)); + ServiceCollectionMap.Validate = serviceType => + { + if (TryGetServiceCharacteristics(serviceType) != null) + { + throw new InvalidOperationException(CoreStrings.NotAProviderService(serviceType.Name)); + } + }; + serviceMap(ServiceCollectionMap); + ServiceCollectionMap.Validate = null; + return this; } diff --git a/src/EFCore/Infrastructure/ServiceCollectionMap.cs b/src/EFCore/Infrastructure/ServiceCollectionMap.cs index c7abb6f0512..797436eb13c 100644 --- a/src/EFCore/Infrastructure/ServiceCollectionMap.cs +++ b/src/EFCore/Infrastructure/ServiceCollectionMap.cs @@ -42,6 +42,8 @@ public ServiceCollectionMap(IServiceCollection serviceCollection) public virtual IServiceCollection ServiceCollection => _map.ServiceCollection; + internal Action? Validate { get; set; } + /// /// Adds a service implemented by the given concrete /// type if no service for the given service type has already been registered. @@ -124,6 +126,8 @@ public virtual ServiceCollectionMap TryAdd( Check.NotNull(serviceType, nameof(serviceType)); Check.NotNull(implementationType, nameof(implementationType)); + Validate?.Invoke(serviceType); + var indexes = _map.GetOrCreateDescriptorIndexes(serviceType); if (indexes.Count == 0) { @@ -254,6 +258,8 @@ public virtual ServiceCollectionMap TryAdd( Check.NotNull(serviceType, nameof(serviceType)); Check.NotNull(factory, nameof(factory)); + Validate?.Invoke(serviceType); + var indexes = _map.GetOrCreateDescriptorIndexes(serviceType); if (indexes.Count == 0) { @@ -285,6 +291,8 @@ public virtual ServiceCollectionMap TryAddSingleton(Type serviceType, object imp { Check.NotNull(serviceType, nameof(serviceType)); + Validate?.Invoke(serviceType); + var indexes = _map.GetOrCreateDescriptorIndexes(serviceType); if (indexes.Count == 0) { @@ -383,6 +391,8 @@ public virtual ServiceCollectionMap TryAddEnumerable( Check.NotNull(serviceType, nameof(serviceType)); Check.NotNull(implementationType, nameof(implementationType)); + Validate?.Invoke(serviceType); + var indexes = _map.GetOrCreateDescriptorIndexes(serviceType); if (indexes.All(i => TryGetImplementationType(ServiceCollection[i]) != implementationType)) { @@ -457,6 +467,8 @@ public virtual ServiceCollectionMap TryAddEnumerable( Check.NotNull(implementationType, nameof(implementationType)); Check.NotNull(factory, nameof(factory)); + Validate?.Invoke(serviceType); + var indexes = _map.GetOrCreateDescriptorIndexes(serviceType); if (indexes.All(i => TryGetImplementationType(ServiceCollection[i]) != implementationType)) { @@ -491,6 +503,8 @@ public virtual ServiceCollectionMap TryAddSingletonEnumerable(Type serviceType, Check.NotNull(serviceType, nameof(serviceType)); Check.NotNull(implementation, nameof(implementation)); + Validate?.Invoke(serviceType); + var implementationType = implementation.GetType(); var indexes = _map.GetOrCreateDescriptorIndexes(serviceType); diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs index 3521e7988d6..7071010ff35 100644 --- a/src/EFCore/Properties/CoreStrings.Designer.cs +++ b/src/EFCore/Properties/CoreStrings.Designer.cs @@ -1998,6 +1998,14 @@ public static string NotAnEFService(object? service) GetString("NotAnEFService", nameof(service)), service); + /// + /// The database provider attempted to register an implementation of the '{service}' service. This is a service defined by Entity Framework and as such must not be registered using the 'TryAddProviderSpecificServices' method. + /// + public static string NotAProviderService(object? service) + => string.Format( + GetString("NotAProviderService", nameof(service)), + service); + /// /// The entity type '{entityType}' cannot inherit from '{baseEntityType}' because '{clrType}' is not a descendant of '{baseClrType}'. /// diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx index c2e9cef62d0..08cd65b8f47 100644 --- a/src/EFCore/Properties/CoreStrings.resx +++ b/src/EFCore/Properties/CoreStrings.resx @@ -1196,6 +1196,9 @@ The database provider attempted to register an implementation of the '{service}' service. This is not a service defined by Entity Framework and as such must be registered as a provider-specific service using the 'TryAddProviderSpecificServices' method. + + The database provider attempted to register an implementation of the '{service}' service. This is a service defined by Entity Framework and as such must not be registered using the 'TryAddProviderSpecificServices' method. + The entity type '{entityType}' cannot inherit from '{baseEntityType}' because '{clrType}' is not a descendant of '{baseClrType}'. diff --git a/test/EFCore.Tests/Infrastructure/EntityFrameworkServicesBuilderTest.cs b/test/EFCore.Tests/Infrastructure/EntityFrameworkServicesBuilderTest.cs index 6212af36be8..9c3ea1c8a3d 100644 --- a/test/EFCore.Tests/Infrastructure/EntityFrameworkServicesBuilderTest.cs +++ b/test/EFCore.Tests/Infrastructure/EntityFrameworkServicesBuilderTest.cs @@ -28,6 +28,18 @@ public void Throws_when_adding_non_EF_service() Assert.Throws(() => builder.TryAdd()).Message); } + [ConditionalFact] + public void Throws_when_adding_EF_service() + { + var serviceCollection = new ServiceCollection(); + var builder = new EntityFrameworkServicesBuilder(serviceCollection); + + Assert.Equal( + CoreStrings.NotAProviderService("IConcurrencyDetector"), + Assert.Throws(() => builder.TryAddProviderSpecificServices( + s => s.TryAddScoped())).Message); + } + [ConditionalFact] public void Can_register_scoped_service_with_concrete_implementation() {