Skip to content

Commit ec72a16

Browse files
Copilotcaptainsafia
andcommitted
Update FromKeyedServicesAttribute support for derived types in generators and validation
- Add InheritsFrom extension method to check class inheritance in SymbolExtensions - Add TryGetAttributeInheritingFrom methods to detect attributes that inherit from base types - Update IsServiceParameter in ITypeSymbolExtensions to use inheritance checking - Update EndpointParameter.cs to use TryGetAttributeInheritingFrom for FromKeyedServicesAttribute - Add CustomFromKeyedServicesAttribute test type and SupportsDerivedFromKeyedServicesAttribute test - All builds and tests pass successfully Co-authored-by: captainsafia <1857993+captainsafia@users.noreply.github.com>
1 parent d0bef45 commit ec72a16

File tree

5 files changed

+78
-3
lines changed

5 files changed

+78
-3
lines changed

src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,13 @@ private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, I
147147
else if (attributes.HasAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromServiceMetadata)))
148148
{
149149
Source = EndpointParameterSource.Service;
150-
if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
150+
if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
151151
{
152152
var location = endpoint.Operation.Syntax.GetLocation();
153153
endpoint.Diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.KeyedAndNotKeyedServiceAttributesNotSupported, location));
154154
}
155155
}
156-
else if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
156+
else if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
157157
{
158158
Source = EndpointParameterSource.KeyedService;
159159
var constructorArgument = keyedServicesAttribute.ConstructorArguments.FirstOrDefault();

src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,23 @@ public async Task RequestDelegateGeneratesCompilableCodeForKeyedServiceInNamespa
252252
await VerifyResponseBodyAsync(httpContext, "To be or not to be…");
253253
}
254254

255+
[Fact]
256+
public async Task SupportsDerivedFromKeyedServicesAttribute()
257+
{
258+
var source = """
259+
app.MapGet("/", (HttpContext context, [CustomFromKeyedServices("customKey")] TestService arg) => context.Items["arg"] = arg);
260+
""";
261+
var (_, compilation) = await RunGeneratorAsync(source);
262+
var myOriginalService = new TestService();
263+
var serviceProvider = CreateServiceProvider((serviceCollection) => serviceCollection.AddKeyedSingleton("customKey", myOriginalService));
264+
var endpoint = GetEndpointFromCompilation(compilation, serviceProvider: serviceProvider);
265+
266+
var httpContext = CreateHttpContext(serviceProvider);
267+
await endpoint.RequestDelegate(httpContext);
268+
269+
Assert.Same(myOriginalService, httpContext.Items["arg"]);
270+
}
271+
255272
private class MockServiceProvider : IServiceProvider, ISupportRequiredService
256273
{
257274
public object GetService(Type serviceType)

src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ public class CustomFromBodyAttribute : Attribute, IFromBodyMetadata
110110
public bool AllowEmpty { get; set; }
111111
}
112112

113+
public class CustomFromKeyedServicesAttribute : FromKeyedServicesAttribute
114+
{
115+
public CustomFromKeyedServicesAttribute(object key) : base(key) { }
116+
}
117+
113118
public enum TodoStatus
114119
{
115120
Trap, // A trap for Enum.TryParse<T>!

src/Shared/RoslynUtils/SymbolExtensions.cs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,46 @@ public static bool TryGetAttributeImplementingInterface(this ImmutableArray<Attr
122122
return false;
123123
}
124124

125+
public static bool HasAttributeInheritingFrom(this ISymbol symbol, INamedTypeSymbol baseType)
126+
{
127+
return symbol.TryGetAttributeInheritingFrom(baseType, out var _);
128+
}
129+
130+
public static bool TryGetAttributeInheritingFrom(this ISymbol symbol, INamedTypeSymbol baseType, [NotNullWhen(true)] out AttributeData? matchedAttribute)
131+
{
132+
foreach (var attributeData in symbol.GetAttributes())
133+
{
134+
if (attributeData.AttributeClass is not null && attributeData.AttributeClass.InheritsFrom(baseType))
135+
{
136+
matchedAttribute = attributeData;
137+
return true;
138+
}
139+
}
140+
141+
matchedAttribute = null;
142+
return false;
143+
}
144+
145+
public static bool HasAttributeInheritingFrom(this ImmutableArray<AttributeData> attributes, INamedTypeSymbol baseType)
146+
{
147+
return attributes.TryGetAttributeInheritingFrom(baseType, out var _);
148+
}
149+
150+
public static bool TryGetAttributeInheritingFrom(this ImmutableArray<AttributeData> attributes, INamedTypeSymbol baseType, [NotNullWhen(true)] out AttributeData? matchedAttribute)
151+
{
152+
foreach (var attributeData in attributes)
153+
{
154+
if (attributeData.AttributeClass is not null && attributeData.AttributeClass.InheritsFrom(baseType))
155+
{
156+
matchedAttribute = attributeData;
157+
return true;
158+
}
159+
}
160+
161+
matchedAttribute = null;
162+
return false;
163+
}
164+
125165
public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType)
126166
{
127167
foreach (var t in type.AllInterfaces)
@@ -134,6 +174,18 @@ public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType)
134174
return false;
135175
}
136176

177+
public static bool InheritsFrom(this ITypeSymbol type, ITypeSymbol baseType)
178+
{
179+
foreach (var t in type.GetThisAndBaseTypes())
180+
{
181+
if (SymbolEqualityComparer.Default.Equals(t, baseType))
182+
{
183+
return true;
184+
}
185+
}
186+
return false;
187+
}
188+
137189
public static bool IsType(this INamedTypeSymbol type, string typeName, SemanticModel semanticModel)
138190
=> SymbolEqualityComparer.Default.Equals(type, semanticModel.Compilation.GetTypeByMetadataName(typeName));
139191

src/Validation/gen/Extensions/ITypeSymbolExtensions.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System.Collections.Immutable;
55
using System.Linq;
6+
using Microsoft.AspNetCore.Analyzers.RouteEmbeddedLanguage.Infrastructure;
67
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
78
using Microsoft.CodeAnalysis;
89

@@ -136,7 +137,7 @@ internal static bool IsServiceParameter(this IParameterSymbol parameter, INamedT
136137
return parameter.GetAttributes().Any(attr =>
137138
attr.AttributeClass is not null &&
138139
(attr.AttributeClass.ImplementsInterface(fromServiceMetadataSymbol) ||
139-
SymbolEqualityComparer.Default.Equals(attr.AttributeClass, fromKeyedServiceAttributeSymbol)));
140+
attr.AttributeClass.InheritsFrom(fromKeyedServiceAttributeSymbol)));
140141
}
141142

142143
/// <summary>

0 commit comments

Comments
 (0)