From f1baba6b6041340f4d0d0e8f5387714f4b56a851 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 4 Aug 2025 21:47:16 +0000 Subject: [PATCH 1/3] Initial plan From d0bef452c7b94e3f8f5b83d12279518d1f5e56cf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 4 Aug 2025 22:09:17 +0000 Subject: [PATCH 2/3] Fix FromKeyedServicesAttribute and FromServicesAttribute to support derived types - Replace OfType() with IsAssignableFrom checks in BindingInfo.cs - Replace OfType() with IsAssignableFrom checks in RequestDelegateFactory.cs - Update SignalR HubMethodDescriptor pattern matching to handle derived FromKeyedServicesAttribute types - Add test for derived FromKeyedServicesAttribute detection Co-authored-by: captainsafia <1857993+captainsafia@users.noreply.github.com> --- .../src/RequestDelegateFactory.cs | 4 +-- .../src/ModelBinding/BindingInfo.cs | 2 +- .../test/ModelBinding/BindingInfoTest.cs | 27 +++++++++++++++++++ .../Core/src/Internal/HubMethodDescriptor.cs | 2 +- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index aaebe4179806..5dc5104401b3 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -803,7 +803,7 @@ private static Expression CreateArgument(ParameterInfo parameter, RequestDelegat } else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType))) { - if (parameterCustomAttributes.OfType().FirstOrDefault() is not null) + if (parameterCustomAttributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is not null) { throw new NotSupportedException( $"The {nameof(FromKeyedServicesAttribute)} is not supported on parameters that are also annotated with {nameof(IFromServiceMetadata)}."); @@ -811,7 +811,7 @@ private static Expression CreateArgument(ParameterInfo parameter, RequestDelegat factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.ServiceAttribute); return BindParameterFromService(parameter, factoryContext); } - else if (parameterCustomAttributes.OfType().FirstOrDefault() is { } keyedServicesAttribute) + else if (parameterCustomAttributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is FromKeyedServicesAttribute keyedServicesAttribute) { if (factoryContext.ServiceProviderIsService is not IServiceProviderIsKeyedService) { diff --git a/src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs b/src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs index 1e4f8f3ca836..4792f5f3468a 100644 --- a/src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs +++ b/src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs @@ -177,7 +177,7 @@ public Type? BinderType } // Keyed services - if (attributes.OfType().FirstOrDefault() is { } fromKeyedServicesAttribute) + if (attributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is FromKeyedServicesAttribute fromKeyedServicesAttribute) { if (bindingInfo.BindingSource != null) { diff --git a/src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs b/src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs index 00d42c3b12c4..04d474e106e2 100644 --- a/src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs +++ b/src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs @@ -326,4 +326,31 @@ public void GetBindingInfo_ThrowsWhenWithFromKeyedServicesAttributeAndIFromServi // Act and Assert Assert.Throws(() => BindingInfo.GetBindingInfo(attributes, modelMetadata)); } + + [Fact] + public void GetBindingInfo_WithDerivedFromKeyedServicesAttribute() + { + // Arrange + var key = new object(); + var attributes = new object[] + { + new CustomFromKeyedServicesAttribute(key), + }; + var modelType = typeof(Guid); + var provider = new TestModelMetadataProvider(); + var modelMetadata = provider.GetMetadataForType(modelType); + + // Act + var bindingInfo = BindingInfo.GetBindingInfo(attributes, modelMetadata); + + // Assert + Assert.NotNull(bindingInfo); + Assert.Same(BindingSource.Services, bindingInfo.BindingSource); + Assert.Same(key, bindingInfo.ServiceKey); + } + + private class CustomFromKeyedServicesAttribute : FromKeyedServicesAttribute + { + public CustomFromKeyedServicesAttribute(object key) : base(key) { } + } } diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index f32e5a9afc90..f93652c71f3f 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -88,7 +88,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider markedParameter = true; MarkServiceParameter(index); } - else if (attribute is FromKeyedServicesAttribute keyedServicesAttribute) + else if (typeof(FromKeyedServicesAttribute).IsAssignableFrom(attribute.GetType()) && attribute is FromKeyedServicesAttribute keyedServicesAttribute) { ThrowIfMarked(markedParameter); markedParameter = true; From ec72a1632639c46fe3e2fb95ea70b10a454efafa Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 5 Aug 2025 00:40:32 +0000 Subject: [PATCH 3/3] Update FromKeyedServicesAttribute support for derived types in generators and validation - Add InheritsFrom extension method to check class inheritance in SymbolExtensions - Add TryGetAttributeInheritingFrom methods to detect attributes that inherit from base types - Update IsServiceParameter in ITypeSymbolExtensions to use inheritance checking - Update EndpointParameter.cs to use TryGetAttributeInheritingFrom for FromKeyedServicesAttribute - Add CustomFromKeyedServicesAttribute test type and SupportsDerivedFromKeyedServicesAttribute test - All builds and tests pass successfully Co-authored-by: captainsafia <1857993+captainsafia@users.noreply.github.com> --- .../EndpointParameter.cs | 4 +- ...equestDelegateCreationTests.KeyServices.cs | 17 ++++++ .../RequestDelegateGenerator/SharedTypes.cs | 5 ++ src/Shared/RoslynUtils/SymbolExtensions.cs | 52 +++++++++++++++++++ .../gen/Extensions/ITypeSymbolExtensions.cs | 3 +- 5 files changed, 78 insertions(+), 3 deletions(-) diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs index 301e1d524d71..80ee927a64f4 100644 --- a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.RequestDelegateGenerator/StaticRouteHandlerModel/EndpointParameter.cs @@ -147,13 +147,13 @@ private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, I else if (attributes.HasAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromServiceMetadata))) { Source = EndpointParameterSource.Service; - if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) + if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) { var location = endpoint.Operation.Syntax.GetLocation(); endpoint.Diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.KeyedAndNotKeyedServiceAttributesNotSupported, location)); } } - else if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) + else if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute)) { Source = EndpointParameterSource.KeyedService; var constructorArgument = keyedServicesAttribute.ConstructorArguments.FirstOrDefault(); diff --git a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs index 54844d00f5bc..43cf5bbab3f4 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.KeyServices.cs @@ -252,6 +252,23 @@ public async Task RequestDelegateGeneratesCompilableCodeForKeyedServiceInNamespa await VerifyResponseBodyAsync(httpContext, "To be or not to be…"); } + [Fact] + public async Task SupportsDerivedFromKeyedServicesAttribute() + { + var source = """ +app.MapGet("/", (HttpContext context, [CustomFromKeyedServices("customKey")] TestService arg) => context.Items["arg"] = arg); +"""; + var (_, compilation) = await RunGeneratorAsync(source); + var myOriginalService = new TestService(); + var serviceProvider = CreateServiceProvider((serviceCollection) => serviceCollection.AddKeyedSingleton("customKey", myOriginalService)); + var endpoint = GetEndpointFromCompilation(compilation, serviceProvider: serviceProvider); + + var httpContext = CreateHttpContext(serviceProvider); + await endpoint.RequestDelegate(httpContext); + + Assert.Same(myOriginalService, httpContext.Items["arg"]); + } + private class MockServiceProvider : IServiceProvider, ISupportRequiredService { public object GetService(Type serviceType) diff --git a/src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs b/src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs index adc9de0999af..fa866fb4c13f 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateGenerator/SharedTypes.cs @@ -110,6 +110,11 @@ public class CustomFromBodyAttribute : Attribute, IFromBodyMetadata public bool AllowEmpty { get; set; } } +public class CustomFromKeyedServicesAttribute : FromKeyedServicesAttribute +{ + public CustomFromKeyedServicesAttribute(object key) : base(key) { } +} + public enum TodoStatus { Trap, // A trap for Enum.TryParse! diff --git a/src/Shared/RoslynUtils/SymbolExtensions.cs b/src/Shared/RoslynUtils/SymbolExtensions.cs index cb41458638fc..63bd638d4216 100644 --- a/src/Shared/RoslynUtils/SymbolExtensions.cs +++ b/src/Shared/RoslynUtils/SymbolExtensions.cs @@ -122,6 +122,46 @@ public static bool TryGetAttributeImplementingInterface(this ImmutableArray attributes, INamedTypeSymbol baseType) + { + return attributes.TryGetAttributeInheritingFrom(baseType, out var _); + } + + public static bool TryGetAttributeInheritingFrom(this ImmutableArray attributes, INamedTypeSymbol baseType, [NotNullWhen(true)] out AttributeData? matchedAttribute) + { + foreach (var attributeData in attributes) + { + if (attributeData.AttributeClass is not null && attributeData.AttributeClass.InheritsFrom(baseType)) + { + matchedAttribute = attributeData; + return true; + } + } + + matchedAttribute = null; + return false; + } + public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType) { foreach (var t in type.AllInterfaces) @@ -134,6 +174,18 @@ public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType) return false; } + public static bool InheritsFrom(this ITypeSymbol type, ITypeSymbol baseType) + { + foreach (var t in type.GetThisAndBaseTypes()) + { + if (SymbolEqualityComparer.Default.Equals(t, baseType)) + { + return true; + } + } + return false; + } + public static bool IsType(this INamedTypeSymbol type, string typeName, SemanticModel semanticModel) => SymbolEqualityComparer.Default.Equals(type, semanticModel.Compilation.GetTypeByMetadataName(typeName)); diff --git a/src/Validation/gen/Extensions/ITypeSymbolExtensions.cs b/src/Validation/gen/Extensions/ITypeSymbolExtensions.cs index 408ec7defb89..9aab8c028013 100644 --- a/src/Validation/gen/Extensions/ITypeSymbolExtensions.cs +++ b/src/Validation/gen/Extensions/ITypeSymbolExtensions.cs @@ -3,6 +3,7 @@ using System.Collections.Immutable; using System.Linq; +using Microsoft.AspNetCore.Analyzers.RouteEmbeddedLanguage.Infrastructure; using Microsoft.AspNetCore.App.Analyzers.Infrastructure; using Microsoft.CodeAnalysis; @@ -136,7 +137,7 @@ internal static bool IsServiceParameter(this IParameterSymbol parameter, INamedT return parameter.GetAttributes().Any(attr => attr.AttributeClass is not null && (attr.AttributeClass.ImplementsInterface(fromServiceMetadataSymbol) || - SymbolEqualityComparer.Default.Equals(attr.AttributeClass, fromKeyedServiceAttributeSymbol))); + attr.AttributeClass.InheritsFrom(fromKeyedServiceAttributeSymbol))); } ///