Skip to content

Commit

Permalink
Add a check to TryAddProviderSpecificServices
Browse files Browse the repository at this point in the history
  • Loading branch information
AndriySvyryd authored Aug 24, 2021
1 parent 8d30c80 commit cc538cd
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,11 @@ public EntityFrameworkRelationalDesignServicesBuilder(IServiceCollection service
/// Gets the <see cref="ServiceCharacteristics" /> for the given service type.
/// </summary>
/// <param name="serviceType"> The type that defines the service API. </param>
/// <returns> The <see cref="ServiceCharacteristics" /> for the type. </returns>
/// <exception cref="InvalidOperationException"> when the type is not an EF service. </exception>
protected override ServiceCharacteristics GetServiceCharacteristics(Type serviceType)
/// <returns> The <see cref="ServiceCharacteristics" /> for the type or <see langword="null"/> if it's not an EF service. </returns>
protected override ServiceCharacteristics? TryGetServiceCharacteristics(Type serviceType)
=> RelationalServices.TryGetValue(serviceType, out var characteristics)
? characteristics
: base.GetServiceCharacteristics(serviceType);
: base.TryGetServiceCharacteristics(serviceType);

/// <summary>
/// Registers default implementations of all services, including relational services, not already
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,11 @@ public EntityFrameworkRelationalServicesBuilder(IServiceCollection serviceCollec
/// Gets the <see cref="ServiceCharacteristics" /> for the given service type.
/// </summary>
/// <param name="serviceType"> The type that defines the service API. </param>
/// <returns> The <see cref="ServiceCharacteristics" /> for the type. </returns>
/// <exception cref="InvalidOperationException"> when the type is not an EF service. </exception>
protected override ServiceCharacteristics GetServiceCharacteristics(Type serviceType)
/// <returns> The <see cref="ServiceCharacteristics" /> for the type or <see langword="null"/> if it's not an EF service. </returns>
protected override ServiceCharacteristics? TryGetServiceCharacteristics(Type serviceType)
=> RelationalServices.TryGetValue(serviceType, out var characteristics)
? characteristics
: base.GetServiceCharacteristics(serviceType);
: base.TryGetServiceCharacteristics(serviceType);

/// <summary>
/// Registers default implementations of all services, including relational services, not already
Expand Down
7 changes: 3 additions & 4 deletions src/EFCore/Design/EntityFrameworkDesignServicesBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,11 @@ public EntityFrameworkDesignServicesBuilder(IServiceCollection serviceCollection
/// Gets the <see cref="ServiceCharacteristics" /> for the given service type.
/// </summary>
/// <param name="serviceType"> The type that defines the service API. </param>
/// <returns> The <see cref="ServiceCharacteristics" /> for the type. </returns>
/// <exception cref="InvalidOperationException"> when the type is not an EF service. </exception>
protected override ServiceCharacteristics GetServiceCharacteristics(Type serviceType)
/// <returns> The <see cref="ServiceCharacteristics" /> for the type or <see langword="null"/> if it's not an EF service. </returns>
protected override ServiceCharacteristics? TryGetServiceCharacteristics(Type serviceType)
=> Services.TryGetValue(serviceType, out var characteristics)
? characteristics
: base.GetServiceCharacteristics(serviceType);
: base.TryGetServiceCharacteristics(serviceType);

/// <summary>
/// Registers default implementations of all services, including relational services, not already
Expand Down
30 changes: 24 additions & 6 deletions src/EFCore/Infrastructure/EntityFrameworkServicesBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,22 @@ public EntityFrameworkServicesBuilder(IServiceCollection serviceCollection)
/// <exception cref="InvalidOperationException"> when the type is not an EF service. </exception>
protected virtual ServiceCharacteristics GetServiceCharacteristics(Type serviceType)
{
if (!CoreServices.TryGetValue(serviceType, out var characteristics))
{
throw new InvalidOperationException(CoreStrings.NotAnEFService(serviceType.Name));
}

return characteristics;
var characteristics = TryGetServiceCharacteristics(serviceType);
return characteristics == null
? throw new InvalidOperationException(CoreStrings.NotAnEFService(serviceType.Name))
: characteristics.Value;
}

/// <summary>
/// Gets the <see cref="ServiceCharacteristics" /> for the given service type.
/// </summary>
/// <param name="serviceType"> The type that defines the service API. </param>
/// <returns> The <see cref="ServiceCharacteristics" /> for the type or <see langword="null"/> if it's not an EF service. </returns>
protected virtual ServiceCharacteristics? TryGetServiceCharacteristics(Type serviceType)
=> !CoreServices.TryGetValue(serviceType, out var characteristics)
? null
: characteristics;

/// <summary>
/// Database providers should call this method for access to the underlying
/// <see cref="ServiceCollectionMap" /> such that provider-specific services can be registered.
Expand All @@ -196,8 +204,18 @@ public virtual EntityFrameworkServicesBuilder TryAddProviderSpecificServices(Act
{
Check.NotNull(serviceMap, nameof(serviceMap));

ServiceCollectionMap.Validate = serviceType =>
{
if (TryGetServiceCharacteristics(serviceType) != null)
{
throw new InvalidOperationException(CoreStrings.NotAProviderService(serviceType.Name));
}
};

serviceMap(ServiceCollectionMap);

ServiceCollectionMap.Validate = null;

return this;
}

Expand Down
14 changes: 14 additions & 0 deletions src/EFCore/Infrastructure/ServiceCollectionMap.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public ServiceCollectionMap(IServiceCollection serviceCollection)
public virtual IServiceCollection ServiceCollection
=> _map.ServiceCollection;

internal Action<Type>? Validate { get; set; }

/// <summary>
/// Adds a <see cref="ServiceLifetime.Transient" /> service implemented by the given concrete
/// type if no service for the given service type has already been registered.
Expand Down Expand Up @@ -124,6 +126,8 @@ public virtual ServiceCollectionMap TryAdd(
Check.NotNull(serviceType, nameof(serviceType));
Check.NotNull(implementationType, nameof(implementationType));

Validate?.Invoke(serviceType);

var indexes = _map.GetOrCreateDescriptorIndexes(serviceType);
if (indexes.Count == 0)
{
Expand Down Expand Up @@ -254,6 +258,8 @@ public virtual ServiceCollectionMap TryAdd(
Check.NotNull(serviceType, nameof(serviceType));
Check.NotNull(factory, nameof(factory));

Validate?.Invoke(serviceType);

var indexes = _map.GetOrCreateDescriptorIndexes(serviceType);
if (indexes.Count == 0)
{
Expand Down Expand Up @@ -285,6 +291,8 @@ public virtual ServiceCollectionMap TryAddSingleton(Type serviceType, object imp
{
Check.NotNull(serviceType, nameof(serviceType));

Validate?.Invoke(serviceType);

var indexes = _map.GetOrCreateDescriptorIndexes(serviceType);
if (indexes.Count == 0)
{
Expand Down Expand Up @@ -383,6 +391,8 @@ public virtual ServiceCollectionMap TryAddEnumerable(
Check.NotNull(serviceType, nameof(serviceType));
Check.NotNull(implementationType, nameof(implementationType));

Validate?.Invoke(serviceType);

var indexes = _map.GetOrCreateDescriptorIndexes(serviceType);
if (indexes.All(i => TryGetImplementationType(ServiceCollection[i]) != implementationType))
{
Expand Down Expand Up @@ -457,6 +467,8 @@ public virtual ServiceCollectionMap TryAddEnumerable(
Check.NotNull(implementationType, nameof(implementationType));
Check.NotNull(factory, nameof(factory));

Validate?.Invoke(serviceType);

var indexes = _map.GetOrCreateDescriptorIndexes(serviceType);
if (indexes.All(i => TryGetImplementationType(ServiceCollection[i]) != implementationType))
{
Expand Down Expand Up @@ -491,6 +503,8 @@ public virtual ServiceCollectionMap TryAddSingletonEnumerable(Type serviceType,
Check.NotNull(serviceType, nameof(serviceType));
Check.NotNull(implementation, nameof(implementation));

Validate?.Invoke(serviceType);

var implementationType = implementation.GetType();

var indexes = _map.GetOrCreateDescriptorIndexes(serviceType);
Expand Down
8 changes: 8 additions & 0 deletions src/EFCore/Properties/CoreStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/EFCore/Properties/CoreStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,9 @@
<data name="NotAnEFService" xml:space="preserve">
<value>The database provider attempted to register an implementation of the '{service}' service. This is not a service defined by Entity Framework and as such must be registered as a provider-specific service using the 'TryAddProviderSpecificServices' method.</value>
</data>
<data name="NotAProviderService" xml:space="preserve">
<value>The database provider attempted to register an implementation of the '{service}' service. This is a service defined by Entity Framework and as such must not be registered using the 'TryAddProviderSpecificServices' method.</value>
</data>
<data name="NotAssignableClrBaseType" xml:space="preserve">
<value>The entity type '{entityType}' cannot inherit from '{baseEntityType}' because '{clrType}' is not a descendant of '{baseClrType}'.</value>
</data>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ public void Throws_when_adding_non_EF_service()
Assert.Throws<InvalidOperationException>(() => builder.TryAdd<Random, Random>()).Message);
}

[ConditionalFact]
public void Throws_when_adding_EF_service()
{
var serviceCollection = new ServiceCollection();
var builder = new EntityFrameworkServicesBuilder(serviceCollection);

Assert.Equal(
CoreStrings.NotAProviderService("IConcurrencyDetector"),
Assert.Throws<InvalidOperationException>(() => builder.TryAddProviderSpecificServices(
s => s.TryAddScoped<IConcurrencyDetector, FakeConcurrencyDetector>())).Message);
}

[ConditionalFact]
public void Can_register_scoped_service_with_concrete_implementation()
{
Expand Down

0 comments on commit cc538cd

Please sign in to comment.