diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs index 301e1d524d71..80ee927a64f4 100644 --- a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs @@ -147,13 +147,13 @@ private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, I else if (attributes.HasAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromServiceMetadata))) { Source = EndpointParameterSource.Service; - if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) + if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) { var location = endpoint.Operation.Syntax.GetLocation(); endpoint.Diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.KeyedAndNotKeyedServiceAttributesNotSupported, location)); } } - else if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) + else if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) { Source = EndpointParameterSource.KeyedService; var constructorArgument = keyedServicesAttribute.ConstructorArguments.FirstOrDefault(); diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index aaebe4179806..5dc5104401b3 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -803,7 +803,7 @@ private static Expression CreateArgument(ParameterInfo parameter, RequestDelegat } else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType))) { - if (parameterCustomAttributes.OfType().FirstOrDefault() is not null) + if (parameterCustomAttributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is not null) { throw new NotSupportedException( $"The {nameof(FromKeyedServicesAttribute)} is not supported on parameters that are also annotated with {nameof(IFromServiceMetadata)}."); @@ -811,7 +811,7 @@ private static Expression CreateArgument(ParameterInfo parameter, RequestDelegat factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.ServiceAttribute); return BindParameterFromService(parameter, factoryContext); } - else if (parameterCustomAttributes.OfType().FirstOrDefault() is { } keyedServicesAttribute) + else if (parameterCustomAttributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is FromKeyedServicesAttribute keyedServicesAttribute) { if (factoryContext.ServiceProviderIsService is not IServiceProviderIsKeyedService) { diff --git a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs index 54844d00f5bc..43cf5bbab3f4 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs @@ -252,6 +252,23 @@ public async Task RequestDelegateGeneratesCompilableCodeForKeyedServiceInNamespa await VerifyResponseBodyAsync(httpContext, "To be or not to be…"); } + [Fact] + public async Task SupportsDerivedFromKeyedServicesAttribute() + { + var source = """ +app.MapGet("/", (HttpContext context, [CustomFromKeyedServices("customKey")] TestService arg) => context.Items["arg"] = arg); +"""; + var (_, compilation) = await RunGeneratorAsync(source); + var myOriginalService = new TestService(); + var serviceProvider = CreateServiceProvider((serviceCollection) => serviceCollection.AddKeyedSingleton("customKey", myOriginalService)); + var endpoint = GetEndpointFromCompilation(compilation, serviceProvider: serviceProvider); + + var httpContext = CreateHttpContext(serviceProvider); + await endpoint.RequestDelegate(httpContext); + + Assert.Same(myOriginalService, httpContext.Items["arg"]); + } + private class MockServiceProvider : IServiceProvider, ISupportRequiredService { public object GetService(Type serviceType) diff --git a/src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs b/src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs index adc9de0999af..fa866fb4c13f 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs @@ -110,6 +110,11 @@ public class CustomFromBodyAttribute : Attribute, IFromBodyMetadata public bool AllowEmpty { get; set; } } +public class CustomFromKeyedServicesAttribute : FromKeyedServicesAttribute +{ + public CustomFromKeyedServicesAttribute(object key) : base(key) { } +} + public enum TodoStatus { Trap, // A trap for Enum.TryParse! diff --git a/src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs b/src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs index 1e4f8f3ca836..4792f5f3468a 100644 --- a/src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs +++ b/src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs @@ -177,7 +177,7 @@ public Type? BinderType } // Keyed services - if (attributes.OfType().FirstOrDefault() is { } fromKeyedServicesAttribute) + if (attributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is FromKeyedServicesAttribute fromKeyedServicesAttribute) { if (bindingInfo.BindingSource != null) { diff --git a/src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs b/src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs index 00d42c3b12c4..04d474e106e2 100644 --- a/src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs +++ b/src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs @@ -326,4 +326,31 @@ public void GetBindingInfo_ThrowsWhenWithFromKeyedServicesAttributeAndIFromServi // Act and Assert Assert.Throws(() => BindingInfo.GetBindingInfo(attributes, modelMetadata)); } + + [Fact] + public void GetBindingInfo_WithDerivedFromKeyedServicesAttribute() + { + // Arrange + var key = new object(); + var attributes = new object[] + { + new CustomFromKeyedServicesAttribute(key), + }; + var modelType = typeof(Guid); + var provider = new TestModelMetadataProvider(); + var modelMetadata = provider.GetMetadataForType(modelType); + + // Act + var bindingInfo = BindingInfo.GetBindingInfo(attributes, modelMetadata); + + // Assert + Assert.NotNull(bindingInfo); + Assert.Same(BindingSource.Services, bindingInfo.BindingSource); + Assert.Same(key, bindingInfo.ServiceKey); + } + + private class CustomFromKeyedServicesAttribute : FromKeyedServicesAttribute + { + public CustomFromKeyedServicesAttribute(object key) : base(key) { } + } } diff --git a/src/Shared/RoslynUtils/SymbolExtensions.cs b/src/Shared/RoslynUtils/SymbolExtensions.cs index cb41458638fc..63bd638d4216 100644 --- a/src/Shared/RoslynUtils/SymbolExtensions.cs +++ b/src/Shared/RoslynUtils/SymbolExtensions.cs @@ -122,6 +122,46 @@ public static bool TryGetAttributeImplementingInterface(this ImmutableArray attributes, INamedTypeSymbol baseType) + { + return attributes.TryGetAttributeInheritingFrom(baseType, out var _); + } + + public static bool TryGetAttributeInheritingFrom(this ImmutableArray attributes, INamedTypeSymbol baseType, [NotNullWhen(true)] out AttributeData? matchedAttribute) + { + foreach (var attributeData in attributes) + { + if (attributeData.AttributeClass is not null && attributeData.AttributeClass.InheritsFrom(baseType)) + { + matchedAttribute = attributeData; + return true; + } + } + + matchedAttribute = null; + return false; + } + public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType) { foreach (var t in type.AllInterfaces) @@ -134,6 +174,18 @@ public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType) return false; } + public static bool InheritsFrom(this ITypeSymbol type, ITypeSymbol baseType) + { + foreach (var t in type.GetThisAndBaseTypes()) + { + if (SymbolEqualityComparer.Default.Equals(t, baseType)) + { + return true; + } + } + return false; + } + public static bool IsType(this INamedTypeSymbol type, string typeName, SemanticModel semanticModel) => SymbolEqualityComparer.Default.Equals(type, semanticModel.Compilation.GetTypeByMetadataName(typeName)); diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index f32e5a9afc90..f93652c71f3f 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -88,7 +88,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider markedParameter = true; MarkServiceParameter(index); } - else if (attribute is FromKeyedServicesAttribute keyedServicesAttribute) + else if (typeof(FromKeyedServicesAttribute).IsAssignableFrom(attribute.GetType()) && attribute is FromKeyedServicesAttribute keyedServicesAttribute) { ThrowIfMarked(markedParameter); markedParameter = true; diff --git a/src/Validation/gen/Extensions/ITypeSymbolExtensions.cs b/src/Validation/gen/Extensions/ITypeSymbolExtensions.cs index 408ec7defb89..9aab8c028013 100644 --- a/src/Validation/gen/Extensions/ITypeSymbolExtensions.cs +++ b/src/Validation/gen/Extensions/ITypeSymbolExtensions.cs @@ -3,6 +3,7 @@ using System.Collections.Immutable; using System.Linq; +using Microsoft.AspNetCore.Analyzers.RouteEmbeddedLanguage.Infrastructure; using Microsoft.AspNetCore.App.Analyzers.Infrastructure; using Microsoft.CodeAnalysis; @@ -136,7 +137,7 @@ internal static bool IsServiceParameter(this IParameterSymbol parameter, INamedT return parameter.GetAttributes().Any(attr => attr.AttributeClass is not null && (attr.AttributeClass.ImplementsInterface(fromServiceMetadataSymbol) || - SymbolEqualityComparer.Default.Equals(attr.AttributeClass, fromKeyedServiceAttributeSymbol))); + attr.AttributeClass.InheritsFrom(fromKeyedServiceAttributeSymbol))); } ///