Skip to content

Commit 0e3d342

Browse files
Fix FromKeyedServicesAttribute and FromServicesAttribute to support derived types across all generators (#63114)
* Initial plan * Fix FromKeyedServicesAttribute and FromServicesAttribute to support derived types - Replace OfType<FromKeyedServicesAttribute>() with IsAssignableFrom checks in BindingInfo.cs - Replace OfType<FromKeyedServicesAttribute>() with IsAssignableFrom checks in RequestDelegateFactory.cs - Update SignalR HubMethodDescriptor pattern matching to handle derived FromKeyedServicesAttribute types - Add test for derived FromKeyedServicesAttribute detection Co-authored-by: captainsafia <1857993+captainsafia@users.noreply.github.com> * Update FromKeyedServicesAttribute support for derived types in generators and validation - Add InheritsFrom extension method to check class inheritance in SymbolExtensions - Add TryGetAttributeInheritingFrom methods to detect attributes that inherit from base types - Update IsServiceParameter in ITypeSymbolExtensions to use inheritance checking - Update EndpointParameter.cs to use TryGetAttributeInheritingFrom for FromKeyedServicesAttribute - Add CustomFromKeyedServicesAttribute test type and SupportsDerivedFromKeyedServicesAttribute test - All builds and tests pass successfully Co-authored-by: captainsafia <1857993+captainsafia@users.noreply.github.com> * Feedback --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: captainsafia <1857993+captainsafia@users.noreply.github.com> Co-authored-by: Safia Abdalla <safia@microsoft.com>
1 parent 8c054f4 commit 0e3d342

File tree

8 files changed

+97
-6
lines changed

8 files changed

+97
-6
lines changed

src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,13 @@ private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, I
147147
else if (attributes.HasAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromServiceMetadata)))
148148
{
149149
Source = EndpointParameterSource.Service;
150-
if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
150+
if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
151151
{
152152
var location = endpoint.Operation.Syntax.GetLocation();
153153
endpoint.Diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.KeyedAndNotKeyedServiceAttributesNotSupported, location));
154154
}
155155
}
156-
else if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
156+
else if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
157157
{
158158
Source = EndpointParameterSource.KeyedService;
159159
var constructorArgument = keyedServicesAttribute.ConstructorArguments.FirstOrDefault();

src/Http/Http.Extensions/src/RequestDelegateFactory.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -804,15 +804,15 @@ private static Expression CreateArgument(ParameterInfo parameter, RequestDelegat
804804
}
805805
else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)))
806806
{
807-
if (parameterCustomAttributes.OfType<FromKeyedServicesAttribute>().FirstOrDefault() is not null)
807+
if (parameterCustomAttributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is not null)
808808
{
809809
throw new NotSupportedException(
810810
$"The {nameof(FromKeyedServicesAttribute)} is not supported on parameters that are also annotated with {nameof(IFromServiceMetadata)}.");
811811
}
812812
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.ServiceAttribute);
813813
return BindParameterFromService(parameter, factoryContext);
814814
}
815-
else if (parameterCustomAttributes.OfType<FromKeyedServicesAttribute>().FirstOrDefault() is { } keyedServicesAttribute)
815+
else if (parameterCustomAttributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is FromKeyedServicesAttribute keyedServicesAttribute)
816816
{
817817
if (factoryContext.ServiceProviderIsService is not IServiceProviderIsKeyedService)
818818
{

src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,23 @@ public async Task RequestDelegateGeneratesCompilableCodeForKeyedServiceInNamespa
252252
await VerifyResponseBodyAsync(httpContext, "To be or not to be…");
253253
}
254254

255+
[Fact]
256+
public async Task SupportsDerivedFromKeyedServicesAttribute()
257+
{
258+
var source = """
259+
app.MapGet("/", (HttpContext context, [CustomFromKeyedServices("customKey")] TestService arg) => context.Items["arg"] = arg);
260+
""";
261+
var (_, compilation) = await RunGeneratorAsync(source);
262+
var myOriginalService = new TestService();
263+
var serviceProvider = CreateServiceProvider((serviceCollection) => serviceCollection.AddKeyedSingleton("customKey", myOriginalService));
264+
var endpoint = GetEndpointFromCompilation(compilation, serviceProvider: serviceProvider);
265+
266+
var httpContext = CreateHttpContext(serviceProvider);
267+
await endpoint.RequestDelegate(httpContext);
268+
269+
Assert.Same(myOriginalService, httpContext.Items["arg"]);
270+
}
271+
255272
private class MockServiceProvider : IServiceProvider, ISupportRequiredService
256273
{
257274
public object GetService(Type serviceType)

src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ public class CustomFromBodyAttribute : Attribute, IFromBodyMetadata
110110
public bool AllowEmpty { get; set; }
111111
}
112112

113+
public class CustomFromKeyedServicesAttribute : FromKeyedServicesAttribute
114+
{
115+
public CustomFromKeyedServicesAttribute(object key) : base(key) { }
116+
}
117+
113118
public enum TodoStatus
114119
{
115120
Trap, // A trap for Enum.TryParse<T>!

src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ public Type? BinderType
177177
}
178178

179179
// Keyed services
180-
if (attributes.OfType<FromKeyedServicesAttribute>().FirstOrDefault() is { } fromKeyedServicesAttribute)
180+
if (attributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is FromKeyedServicesAttribute fromKeyedServicesAttribute)
181181
{
182182
if (bindingInfo.BindingSource != null)
183183
{

src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,4 +326,31 @@ public void GetBindingInfo_ThrowsWhenWithFromKeyedServicesAttributeAndIFromServi
326326
// Act and Assert
327327
Assert.Throws<NotSupportedException>(() => BindingInfo.GetBindingInfo(attributes, modelMetadata));
328328
}
329+
330+
[Fact]
331+
public void GetBindingInfo_WithDerivedFromKeyedServicesAttribute()
332+
{
333+
// Arrange
334+
var key = new object();
335+
var attributes = new object[]
336+
{
337+
new CustomFromKeyedServicesAttribute(key),
338+
};
339+
var modelType = typeof(Guid);
340+
var provider = new TestModelMetadataProvider();
341+
var modelMetadata = provider.GetMetadataForType(modelType);
342+
343+
// Act
344+
var bindingInfo = BindingInfo.GetBindingInfo(attributes, modelMetadata);
345+
346+
// Assert
347+
Assert.NotNull(bindingInfo);
348+
Assert.Same(BindingSource.Services, bindingInfo.BindingSource);
349+
Assert.Same(key, bindingInfo.ServiceKey);
350+
}
351+
352+
private class CustomFromKeyedServicesAttribute : FromKeyedServicesAttribute
353+
{
354+
public CustomFromKeyedServicesAttribute(object key) : base(key) { }
355+
}
329356
}

src/Shared/RoslynUtils/SymbolExtensions.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,36 @@ public static bool TryGetAttributeImplementingInterface(this ImmutableArray<Attr
141141
return false;
142142
}
143143

144+
public static bool HasAttributeInheritingFrom(this ISymbol symbol, INamedTypeSymbol baseType)
145+
{
146+
return symbol.TryGetAttributeInheritingFrom(baseType, out var _);
147+
}
148+
149+
public static bool TryGetAttributeInheritingFrom(this ISymbol symbol, INamedTypeSymbol baseType, [NotNullWhen(true)] out AttributeData? matchedAttribute)
150+
{
151+
return symbol.GetAttributes().TryGetAttributeInheritingFrom(baseType, out matchedAttribute);
152+
}
153+
154+
public static bool HasAttributeInheritingFrom(this ImmutableArray<AttributeData> attributes, INamedTypeSymbol baseType)
155+
{
156+
return attributes.TryGetAttributeInheritingFrom(baseType, out var _);
157+
}
158+
159+
public static bool TryGetAttributeInheritingFrom(this ImmutableArray<AttributeData> attributes, INamedTypeSymbol baseType, [NotNullWhen(true)] out AttributeData? matchedAttribute)
160+
{
161+
foreach (var attributeData in attributes)
162+
{
163+
if (attributeData.AttributeClass is not null && attributeData.AttributeClass.InheritsFrom(baseType))
164+
{
165+
matchedAttribute = attributeData;
166+
return true;
167+
}
168+
}
169+
170+
matchedAttribute = null;
171+
return false;
172+
}
173+
144174
public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType)
145175
{
146176
foreach (var t in type.AllInterfaces)
@@ -153,6 +183,18 @@ public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType)
153183
return false;
154184
}
155185

186+
public static bool InheritsFrom(this ITypeSymbol type, ITypeSymbol baseType)
187+
{
188+
foreach (var t in type.GetThisAndBaseTypes())
189+
{
190+
if (SymbolEqualityComparer.Default.Equals(t, baseType))
191+
{
192+
return true;
193+
}
194+
}
195+
return false;
196+
}
197+
156198
public static bool IsType(this INamedTypeSymbol type, string typeName, SemanticModel semanticModel)
157199
=> SymbolEqualityComparer.Default.Equals(type, semanticModel.Compilation.GetTypeByMetadataName(typeName));
158200

src/Validation/gen/Extensions/ITypeSymbolExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ internal static bool IsServiceParameter(this IParameterSymbol parameter, INamedT
137137
return parameter.GetAttributes().Any(attr =>
138138
attr.AttributeClass is not null &&
139139
(attr.AttributeClass.ImplementsInterface(fromServiceMetadataSymbol) ||
140-
SymbolEqualityComparer.Default.Equals(attr.AttributeClass, fromKeyedServiceAttributeSymbol)));
140+
attr.AttributeClass.InheritsFrom(fromKeyedServiceAttributeSymbol)));
141141
}
142142

143143
/// <summary>

0 commit comments

Comments
 (0)