Skip to content

Fix FromKeyedServicesAttribute and FromServicesAttribute to support derived types across all generators #63114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ private void ProcessEndpointParameterSource(Endpoint endpoint, ISymbol symbol, I
else if (attributes.HasAttributeImplementingInterface(wellKnownTypes.Get(WellKnownType.Microsoft_AspNetCore_Http_Metadata_IFromServiceMetadata)))
{
Source = EndpointParameterSource.Service;
if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
{
var ___location = endpoint.Operation.Syntax.GetLocation();
endpoint.Diagnostics.Add(Diagnostic.Create(DiagnosticDescriptors.KeyedAndNotKeyedServiceAttributesNotSupported, ___location));
}
}
else if (attributes.TryGetAttribute(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
else if (attributes.TryGetAttributeInheritingFrom(wellKnownTypes.Get(WellKnownType.Microsoft_Extensions_DependencyInjection_FromKeyedServicesAttribute), out var keyedServicesAttribute))
{
Source = EndpointParameterSource.KeyedService;
var constructorArgument = keyedServicesAttribute.ConstructorArguments.FirstOrDefault();
Expand Down
4 changes: 2 additions & 2 deletions src/Http/Http.Extensions/src/RequestDelegateFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -803,15 +803,15 @@ private static Expression CreateArgument(ParameterInfo parameter, RequestDelegat
}
else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)))
{
if (parameterCustomAttributes.OfType<FromKeyedServicesAttribute>().FirstOrDefault() is not null)
if (parameterCustomAttributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is not null)
{
throw new NotSupportedException(
$"The {nameof(FromKeyedServicesAttribute)} is not supported on parameters that are also annotated with {nameof(IFromServiceMetadata)}.");
}
factoryContext.TrackedParameters.Add(parameter.Name, RequestDelegateFactoryConstants.ServiceAttribute);
return BindParameterFromService(parameter, factoryContext);
}
else if (parameterCustomAttributes.OfType<FromKeyedServicesAttribute>().FirstOrDefault() is { } keyedServicesAttribute)
else if (parameterCustomAttributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is FromKeyedServicesAttribute keyedServicesAttribute)
{
if (factoryContext.ServiceProviderIsService is not IServiceProviderIsKeyedService)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,23 @@ public async Task RequestDelegateGeneratesCompilableCodeForKeyedServiceInNamespa
await VerifyResponseBodyAsync(httpContext, "To be or not to be…");
}

[Fact]
public async Task SupportsDerivedFromKeyedServicesAttribute()
{
var source = """
app.MapGet("/", (HttpContext context, [CustomFromKeyedServices("customKey")] TestService arg) => context.Items["arg"] = arg);
""";
var (_, compilation) = await RunGeneratorAsync(source);
var myOriginalService = new TestService();
var serviceProvider = CreateServiceProvider((serviceCollection) => serviceCollection.AddKeyedSingleton("customKey", myOriginalService));
var endpoint = GetEndpointFromCompilation(compilation, serviceProvider: serviceProvider);

var httpContext = CreateHttpContext(serviceProvider);
await endpoint.RequestDelegate(httpContext);

Assert.Same(myOriginalService, httpContext.Items["arg"]);
}

private class MockServiceProvider : IServiceProvider, ISupportRequiredService
{
public object GetService(Type serviceType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ public class CustomFromBodyAttribute : Attribute, IFromBodyMetadata
public bool AllowEmpty { get; set; }
}

public class CustomFromKeyedServicesAttribute : FromKeyedServicesAttribute
{
public CustomFromKeyedServicesAttribute(object key) : base(key) { }
}

public enum TodoStatus
{
Trap, // A trap for Enum.TryParse<T>!
Expand Down
2 changes: 1 addition & 1 deletion src/Mvc/Mvc.Abstractions/src/ModelBinding/BindingInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public Type? BinderType
}

// Keyed services
if (attributes.OfType<FromKeyedServicesAttribute>().FirstOrDefault() is { } fromKeyedServicesAttribute)
if (attributes.FirstOrDefault(a => typeof(FromKeyedServicesAttribute).IsAssignableFrom(a.GetType())) is FromKeyedServicesAttribute fromKeyedServicesAttribute)
{
if (bindingInfo.BindingSource != null)
{
Expand Down
27 changes: 27 additions & 0 deletions src/Mvc/Mvc.Abstractions/test/ModelBinding/BindingInfoTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,31 @@ public void GetBindingInfo_ThrowsWhenWithFromKeyedServicesAttributeAndIFromServi
// Act and Assert
Assert.Throws<NotSupportedException>(() => BindingInfo.GetBindingInfo(attributes, modelMetadata));
}

[Fact]
public void GetBindingInfo_WithDerivedFromKeyedServicesAttribute()
{
// Arrange
var key = new object();
var attributes = new object[]
{
new CustomFromKeyedServicesAttribute(key),
};
var modelType = typeof(Guid);
var provider = new TestModelMetadataProvider();
var modelMetadata = provider.GetMetadataForType(modelType);

// Act
var bindingInfo = BindingInfo.GetBindingInfo(attributes, modelMetadata);

// Assert
Assert.NotNull(bindingInfo);
Assert.Same(BindingSource.Services, bindingInfo.BindingSource);
Assert.Same(key, bindingInfo.ServiceKey);
}

private class CustomFromKeyedServicesAttribute : FromKeyedServicesAttribute
{
public CustomFromKeyedServicesAttribute(object key) : base(key) { }
}
}
52 changes: 52 additions & 0 deletions src/Shared/RoslynUtils/SymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,46 @@ public static bool TryGetAttributeImplementingInterface(this ImmutableArray<Attr
return false;
}

public static bool HasAttributeInheritingFrom(this ISymbol symbol, INamedTypeSymbol baseType)
{
return symbol.TryGetAttributeInheritingFrom(baseType, out var _);
}

public static bool TryGetAttributeInheritingFrom(this ISymbol symbol, INamedTypeSymbol baseType, [NotNullWhen(true)] out AttributeData? matchedAttribute)
{
foreach (var attributeData in symbol.GetAttributes())
{
if (attributeData.AttributeClass is not null && attributeData.AttributeClass.InheritsFrom(baseType))
{
matchedAttribute = attributeData;
return true;
}
}

matchedAttribute = null;
return false;
}

public static bool HasAttributeInheritingFrom(this ImmutableArray<AttributeData> attributes, INamedTypeSymbol baseType)
{
return attributes.TryGetAttributeInheritingFrom(baseType, out var _);
}

public static bool TryGetAttributeInheritingFrom(this ImmutableArray<AttributeData> attributes, INamedTypeSymbol baseType, [NotNullWhen(true)] out AttributeData? matchedAttribute)
{
foreach (var attributeData in attributes)
{
if (attributeData.AttributeClass is not null && attributeData.AttributeClass.InheritsFrom(baseType))
{
matchedAttribute = attributeData;
return true;
}
}

matchedAttribute = null;
return false;
}

public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType)
{
foreach (var t in type.AllInterfaces)
Expand All @@ -134,6 +174,18 @@ public static bool Implements(this ITypeSymbol type, ITypeSymbol interfaceType)
return false;
}

public static bool InheritsFrom(this ITypeSymbol type, ITypeSymbol baseType)
{
foreach (var t in type.GetThisAndBaseTypes())
{
if (SymbolEqualityComparer.Default.Equals(t, baseType))
{
return true;
}
}
return false;
}

public static bool IsType(this INamedTypeSymbol type, string typeName, SemanticModel semanticModel)
=> SymbolEqualityComparer.Default.Equals(type, semanticModel.Compilation.GetTypeByMetadataName(typeName));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IServiceProvider
markedParameter = true;
MarkServiceParameter(index);
}
else if (attribute is FromKeyedServicesAttribute keyedServicesAttribute)
else if (typeof(FromKeyedServicesAttribute).IsAssignableFrom(attribute.GetType()) && attribute is FromKeyedServicesAttribute keyedServicesAttribute)
{
ThrowIfMarked(markedParameter);
markedParameter = true;
Expand Down
3 changes: 2 additions & 1 deletion src/Validation/gen/Extensions/ITypeSymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Collections.Immutable;
using System.Linq;
using Microsoft.AspNetCore.Analyzers.RouteEmbeddedLanguage.Infrastructure;
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
using Microsoft.CodeAnalysis;

Expand Down Expand Up @@ -136,7 +137,7 @@ internal static bool IsServiceParameter(this IParameterSymbol parameter, INamedT
return parameter.GetAttributes().Any(attr =>
attr.AttributeClass is not null &&
(attr.AttributeClass.ImplementsInterface(fromServiceMetadataSymbol) ||
SymbolEqualityComparer.Default.Equals(attr.AttributeClass, fromKeyedServiceAttributeSymbol)));
attr.AttributeClass.InheritsFrom(fromKeyedServiceAttributeSymbol)));
}

/// <summary>
Expand Down
Loading