Skip to content

[Blazor] Support persisting component state on enhanced navigation #62824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .azure/pipelines/components-e2e-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,17 @@ jobs:
exit 1
fi
displayName: Run E2E tests
env:
DOTNET_EnableAVX512: 0
- script: .dotnet/dotnet test ./src/Components/test/E2ETest -c $(BuildConfiguration) --no-build --filter 'Quarantined=true' -p:RunQuarantinedTests=true
-p:VsTestUseMSBuildOutput=false
--logger:"trx%3BLogFileName=Microsoft.AspNetCore.Components.E2ETests.trx"
--logger:"html%3BLogFileName=Microsoft.AspNetCore.Components.E2ETests.html"
--results-directory $(Build.SourcesDirectory)/artifacts/TestResults/$(BuildConfiguration)/Quarantined
displayName: Run Quarantined E2E tests
continueOnError: true
env:
DOTNET_EnableAVX512: 0
- task: PublishTestResults@2
displayName: Publish E2E Test Results
inputs:
Expand Down
27 changes: 27 additions & 0 deletions src/Components/Components/src/ComponentSubscriptionKey.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using Microsoft.AspNetCore.Components.Rendering;

namespace Microsoft.AspNetCore.Components.Infrastructure;

[DebuggerDisplay("{GetDebuggerDisplay(),nq}")]
internal readonly struct ComponentSubscriptionKey(ComponentState subscriber, string propertyName) : IEquatable<ComponentSubscriptionKey>
{
public ComponentState Subscriber { get; } = subscriber;

public string PropertyName { get; } = propertyName;

public bool Equals(ComponentSubscriptionKey other)
=> Subscriber == other.Subscriber && PropertyName == other.PropertyName;

public override bool Equals(object? obj)
=> obj is ComponentSubscriptionKey other && Equals(other);

public override int GetHashCode()
=> HashCode.Combine(Subscriber, PropertyName);

private string GetDebuggerDisplay()
=> $"{Subscriber.Component.GetType().Name}.{PropertyName}";
}
50 changes: 47 additions & 3 deletions src/Components/Components/src/PersistentComponentState.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using static Microsoft.AspNetCore.Internal.LinkerFlags;
Expand All @@ -16,24 +17,30 @@ public class PersistentComponentState
private readonly IDictionary<string, byte[]> _currentState;

private readonly List<PersistComponentStateRegistration> _registeredCallbacks;
private readonly List<RestoreComponentStateRegistration> _registeredRestoringCallbacks;

internal PersistentComponentState(
IDictionary<string , byte[]> currentState,
List<PersistComponentStateRegistration> pauseCallbacks)
IDictionary<string, byte[]> currentState,
List<PersistComponentStateRegistration> pauseCallbacks,
List<RestoreComponentStateRegistration> restoringCallbacks)
{
_currentState = currentState;
_registeredCallbacks = pauseCallbacks;
_registeredRestoringCallbacks = restoringCallbacks;
}

internal bool PersistingState { get; set; }

internal void InitializeExistingState(IDictionary<string, byte[]> existingState)
internal RestoreContext CurrentContext { get; private set; } = RestoreContext.InitialValue;

internal void InitializeExistingState(IDictionary<string, byte[]> existingState, RestoreContext context)
{
if (_existingState != null)
{
throw new InvalidOperationException("PersistentComponentState already initialized.");
}
_existingState = existingState ?? throw new ArgumentNullException(nameof(existingState));
CurrentContext = context;
}

/// <summary>
Expand Down Expand Up @@ -68,6 +75,30 @@ public PersistingComponentStateSubscription RegisterOnPersisting(Func<Task> call
return new PersistingComponentStateSubscription(_registeredCallbacks, persistenceCallback);
}

/// <summary>
/// Register a callback to restore the state when the application state is being restored.
/// </summary>
/// <param name="callback"> The callback to invoke when the application state is being restored.</param>
/// <param name="options">Options that control the restoration behavior.</param>
/// <returns>A subscription that can be used to unregister the callback when disposed.</returns>
public RestoringComponentStateSubscription RegisterOnRestoring(Action callback, RestoreOptions options)
{
Debug.Assert(CurrentContext != null);
if (CurrentContext.ShouldRestore(options))
{
callback();
}

if (options.AllowUpdates)
{
var registration = new RestoreComponentStateRegistration(callback);
_registeredRestoringCallbacks.Add(registration);
return new RestoringComponentStateSubscription(_registeredRestoringCallbacks, registration);
}

return default;
}

/// <summary>
/// Serializes <paramref name="instance"/> as JSON and persists it under the given <paramref name="key"/>.
/// </summary>
Expand Down Expand Up @@ -214,4 +245,17 @@ private bool TryTake(string key, out byte[]? value)
return false;
}
}

internal void UpdateExistingState(IDictionary<string, byte[]> state, RestoreContext context)
{
ArgumentNullException.ThrowIfNull(state);

if (_existingState == null || _existingState.Count > 0)
{
throw new InvalidOperationException("Cannot update existing state: previous state has not been cleared or state is not initialized.");
}

_existingState = state;
CurrentContext = context;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ namespace Microsoft.AspNetCore.Components.Infrastructure;
public class ComponentStatePersistenceManager
{
private readonly List<PersistComponentStateRegistration> _registeredCallbacks = new();
private readonly List<RestoreComponentStateRegistration> _registeredRestoringCallbacks = new();
private readonly ILogger<ComponentStatePersistenceManager> _logger;

private bool _stateIsPersisted;
private bool _stateIsInitialized;
private readonly PersistentServicesRegistry? _servicesRegistry;
private readonly Dictionary<string, byte[]> _currentState = new(StringComparer.Ordinal);

Expand All @@ -24,7 +26,7 @@ public class ComponentStatePersistenceManager
/// <param name="logger"></param>
public ComponentStatePersistenceManager(ILogger<ComponentStatePersistenceManager> logger)
{
State = new PersistentComponentState(_currentState, _registeredCallbacks);
State = new PersistentComponentState(_currentState, _registeredCallbacks, _registeredRestoringCallbacks);
_logger = logger;
}

Expand Down Expand Up @@ -55,10 +57,38 @@ public ComponentStatePersistenceManager(ILogger<ComponentStatePersistenceManager
/// <param name="store">The <see cref="IPersistentComponentStateStore"/> to restore the application state from.</param>
/// <returns>A <see cref="Task"/> that will complete when the state has been restored.</returns>
public async Task RestoreStateAsync(IPersistentComponentStateStore store)
{
await RestoreStateAsync(store, RestoreContext.InitialValue);
}

/// <summary>
/// Restores the application state.
/// </summary>
/// <param name="store"> The <see cref="IPersistentComponentStateStore"/> to restore the application state from.</param>
/// <param name="context">The <see cref="RestoreContext"/> that provides additional context for the restoration.</param>
/// <returns>A <see cref="Task"/> that will complete when the state has been restored.</returns>
public async Task RestoreStateAsync(IPersistentComponentStateStore store, RestoreContext context)
{
var data = await store.GetPersistedStateAsync();
State.InitializeExistingState(data);
_servicesRegistry?.Restore(State);

if (_stateIsInitialized)
{
if (context != RestoreContext.ValueUpdate)
{
throw new InvalidOperationException("State already initialized.");
}
State.UpdateExistingState(data, context);
foreach (var registration in _registeredRestoringCallbacks)
{
registration.Callback();
}
}
else
{
State.InitializeExistingState(data, context);
_servicesRegistry?.RegisterForPersistence(State);
_stateIsInitialized = true;
}
}

/// <summary>
Expand All @@ -78,9 +108,6 @@ public Task PersistStateAsync(IPersistentComponentStateStore store, Renderer ren

async Task PauseAndPersistState()
{
// Ensure that we register the services before we start persisting the state.
_servicesRegistry?.RegisterForPersistence(State);

State.PersistingState = true;

if (store is IEnumerable<IPersistentComponentStateStore> compositeStore)
Expand Down Expand Up @@ -271,4 +298,5 @@ static async Task<bool> AnyTaskFailed(List<Task<bool>> pendingCallbackTasks)
return true;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ namespace Microsoft.AspNetCore.Components.Infrastructure;
internal sealed class PersistentServicesRegistry
{
private static readonly string _registryKey = typeof(PersistentServicesRegistry).FullName!;
private static readonly RootTypeCache _persistentServiceTypeCache = new RootTypeCache();
private static readonly RootTypeCache _persistentServiceTypeCache = new();

private readonly IServiceProvider _serviceProvider;
private IPersistentServiceRegistration[] _registrations;
private List<PersistingComponentStateSubscription> _subscriptions = [];
private List<(PersistingComponentStateSubscription, RestoringComponentStateSubscription)> _subscriptions = [];
private static readonly ConcurrentDictionary<Type, PropertiesAccessor> _cachedAccessorsByType = new();

static PersistentServicesRegistry()
Expand Down Expand Up @@ -54,7 +54,9 @@ internal void RegisterForPersistence(PersistentComponentState state)
return;
}

var subscriptions = new List<PersistingComponentStateSubscription>(_registrations.Length + 1);
UpdateRegistrations(state);
var subscriptions = new List<(PersistingComponentStateSubscription, RestoringComponentStateSubscription)>(
_registrations.Length + 1);
for (var i = 0; i < _registrations.Length; i++)
{
var registration = _registrations[i];
Expand All @@ -67,20 +69,29 @@ internal void RegisterForPersistence(PersistentComponentState state)
var renderMode = registration.GetRenderModeOrDefault();

var instance = _serviceProvider.GetRequiredService(type);
subscriptions.Add(state.RegisterOnPersisting(() =>
{
PersistInstanceState(instance, type, state);
return Task.CompletedTask;
}, renderMode));
subscriptions.Add((
state.RegisterOnPersisting(() =>
{
PersistInstanceState(instance, type, state);
return Task.CompletedTask;
}, renderMode),
// In order to avoid registering one callback per property, we register a single callback with the most
// permissive options and perform the filtering inside of it.
state.RegisterOnRestoring(() =>
{
RestoreInstanceState(instance, type, state);
}, new RestoreOptions { AllowUpdates = true })));
}

if (RenderMode != null)
{
subscriptions.Add(state.RegisterOnPersisting(() =>
{
state.PersistAsJson(_registryKey, _registrations);
return Task.CompletedTask;
}, RenderMode));
subscriptions.Add((
state.RegisterOnPersisting(() =>
{
state.PersistAsJson(_registryKey, _registrations);
return Task.CompletedTask;
}, RenderMode),
default));
}

_subscriptions = subscriptions;
Expand All @@ -92,7 +103,7 @@ private static void PersistInstanceState(object instance, Type type, PersistentC
var accessors = _cachedAccessorsByType.GetOrAdd(instance.GetType(), static (runtimeType, declaredType) => new PropertiesAccessor(runtimeType, declaredType), type);
foreach (var (key, propertyType) in accessors.KeyTypePairs)
{
var (setter, getter) = accessors.GetAccessor(key);
var (setter, getter, options) = accessors.GetAccessor(key);
var value = getter.GetValue(instance);
if (value != null)
{
Expand All @@ -105,33 +116,12 @@ private static void PersistInstanceState(object instance, Type type, PersistentC
"IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code",
Justification = "Types registered for persistence are preserved in the API call to register them and typically live in assemblies that aren't trimmed.")]
[DynamicDependency(LinkerFlags.JsonSerialized, typeof(PersistentServiceRegistration))]
internal void Restore(PersistentComponentState state)
private void UpdateRegistrations(PersistentComponentState state)
{
if (state.TryTakeFromJson<PersistentServiceRegistration[]>(_registryKey, out var registry) && registry != null)
{
_registrations = ResolveRegistrations(_registrations.Concat(registry));
}

RestoreRegistrationsIfAvailable(state);
}

[UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "Types registered for persistence are preserved in the API call to register them and typically live in assemblies that aren't trimmed.")]
private void RestoreRegistrationsIfAvailable(PersistentComponentState state)
{
foreach (var registration in _registrations)
{
var type = ResolveType(registration);
if (type == null)
{
continue;
}

var instance = _serviceProvider.GetService(type);
if (instance != null)
{
RestoreInstanceState(instance, type, state);
}
}
}

[RequiresUnreferencedCode("Calls Microsoft.AspNetCore.Components.PersistentComponentState.TryTakeFromJson(String, Type, out Object)")]
Expand All @@ -140,9 +130,13 @@ private static void RestoreInstanceState(object instance, Type type, PersistentC
var accessors = _cachedAccessorsByType.GetOrAdd(instance.GetType(), static (runtimeType, declaredType) => new PropertiesAccessor(runtimeType, declaredType), type);
foreach (var (key, propertyType) in accessors.KeyTypePairs)
{
var (setter, getter, options) = accessors.GetAccessor(key);
if (!state.CurrentContext.ShouldRestore(options))
{
continue;
}
if (state.TryTakeFromJson(key, propertyType, out var result))
{
var (setter, getter) = accessors.GetAccessor(key);
setter.SetValue(instance, result!);
}
}
Expand All @@ -165,12 +159,12 @@ private sealed class PropertiesAccessor
{
internal const BindingFlags BindablePropertyFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.IgnoreCase;

private readonly Dictionary<string, (PropertySetter, PropertyGetter)> _underlyingAccessors;
private readonly Dictionary<string, (PropertySetter, PropertyGetter, RestoreOptions)> _underlyingAccessors;
private readonly (string, Type)[] _cachedKeysForService;

public PropertiesAccessor([DynamicallyAccessedMembers(LinkerFlags.Component)] Type targetType, Type keyType)
{
_underlyingAccessors = new Dictionary<string, (PropertySetter, PropertyGetter)>(StringComparer.OrdinalIgnoreCase);
_underlyingAccessors = new Dictionary<string, (PropertySetter, PropertyGetter, RestoreOptions)>(StringComparer.OrdinalIgnoreCase);

var keys = new List<(string, Type)>();
foreach (var propertyInfo in GetCandidateBindableProperties(targetType))
Expand Down Expand Up @@ -204,10 +198,16 @@ public PropertiesAccessor([DynamicallyAccessedMembers(LinkerFlags.Component)] Ty
$"The type '{targetType.FullName}' declares a property matching the name '{propertyName}' that is not public. Persistent service properties must be public.");
}

var restoreOptions = new RestoreOptions
{
RestoreBehavior = parameterAttribute.RestoreBehavior,
AllowUpdates = parameterAttribute.AllowUpdates,
};

var propertySetter = new PropertySetter(targetType, propertyInfo);
var propertyGetter = new PropertyGetter(targetType, propertyInfo);

_underlyingAccessors.Add(key, (propertySetter, propertyGetter));
_underlyingAccessors.Add(key, (propertySetter, propertyGetter, restoreOptions));
}

_cachedKeysForService = [.. keys];
Expand Down Expand Up @@ -236,7 +236,7 @@ internal static IEnumerable<PropertyInfo> GetCandidateBindableProperties(
[DynamicallyAccessedMembers(LinkerFlags.Component)] Type targetType)
=> MemberAssignment.GetPropertiesIncludingInherited(targetType, BindablePropertyFlags);

internal (PropertySetter setter, PropertyGetter getter) GetAccessor(string key) =>
internal (PropertySetter setter, PropertyGetter getter, RestoreOptions options) GetAccessor(string key) =>
_underlyingAccessors.TryGetValue(key, out var result) ? result : default;
}

Expand Down
Loading
Loading