Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 45 additions & 34 deletions Source/Csla/Reflection/ServiceProviderMethodCaller.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,18 @@ public bool TryFindDataPortalMethod<T>([DynamicallyAccessedMembers(DynamicallyAc

var typeOfOperation = typeof(T);

var cacheKey = GetCacheKeyName(targetType, typeOfOperation, criteria, useLegacyMethods);
// Resolve the factory type (if any) up front so it can participate in the cache key.
// In production a business type maps to exactly one factory, but a custom
// IObjectFactoryLoader can resolve the same FactoryTypeName to different types. The
// method cache is process-wide and keyed by business type, so the resolved factory
// type must be part of the key to avoid invoking a delegate compiled for a different
// factory.
var factoryInfo = ObjectFactoryAttribute.GetObjectFactoryAttribute(targetType);
Type? factoryType = null;
if (factoryInfo != null && !TryGetFactoryType(factoryInfo, _applicationContext, throwOnError, out factoryType))
return null;

var cacheKey = GetCacheKeyName(targetType, typeOfOperation, criteria, useLegacyMethods, factoryType);

#if NET8_0_OR_GREATER
if (_methodCache.TryGetValue(cacheKey, out var unloadableCachedMethodInfo))
Expand All @@ -139,32 +150,27 @@ public bool TryFindDataPortalMethod<T>([DynamicallyAccessedMembers(DynamicallyAc
}

var candidates = new List<ScoredMethodInfo>();
var factoryInfo = ObjectFactoryAttribute.GetObjectFactoryAttribute(targetType);
if (factoryInfo != null)
{
if (!TryGetFactoryType(factoryInfo, _applicationContext, throwOnError, out var factoryType))
{
return null;
}

var factoryWalkType = factoryType;
var ftList = new List<System.Reflection.MethodInfo>();
var level = 0;
while (factoryType != null)
while (factoryWalkType != null)
{
ftList.Clear();
if (typeOfOperation == typeof(CreateAttribute))
ftList.AddRange(factoryType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.CreateMethodName));
ftList.AddRange(factoryWalkType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.CreateMethodName));
else if (typeOfOperation == typeof(FetchAttribute))
ftList.AddRange(factoryType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.FetchMethodName));
ftList.AddRange(factoryWalkType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.FetchMethodName));
else if (typeOfOperation == typeof(DeleteAttribute))
ftList.AddRange(factoryType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.DeleteMethodName));
ftList.AddRange(factoryWalkType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.DeleteMethodName));
else if (typeOfOperation == typeof(ExecuteAttribute))
ftList.AddRange(factoryType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.ExecuteMethodName));
ftList.AddRange(factoryWalkType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.ExecuteMethodName));
else if (typeOfOperation == typeof(CreateChildAttribute))
ftList.AddRange(factoryType.GetMethods(_factoryBindingAttr).Where(m => m.Name == "Child_Create"));
ftList.AddRange(factoryWalkType.GetMethods(_factoryBindingAttr).Where(m => m.Name == "Child_Create"));
else
ftList.AddRange(factoryType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.UpdateMethodName));
factoryType = factoryType.BaseType;
ftList.AddRange(factoryWalkType.GetMethods(_factoryBindingAttr).Where(m => m.Name == factoryInfo.UpdateMethodName));
factoryWalkType = factoryWalkType.BaseType;
candidates.AddRange(ftList.Select(r => new ScoredMethodInfo { MethodInfo = r, Score = level }));
level--;
}
Expand All @@ -173,22 +179,6 @@ public bool TryFindDataPortalMethod<T>([DynamicallyAccessedMembers(DynamicallyAc
var ftlist = targetType.GetMethods(_bindingAttr).Where(m => m.Name == "Child_Create");
candidates.AddRange(ftlist.Select(r => new ScoredMethodInfo { MethodInfo = r, Score = 0 }));
}

static bool TryGetFactoryType(ObjectFactoryAttribute factoryAttribute, ApplicationContext context, bool throwOnError, [NotNullWhen(true)] out Type? factoryType)
{
try
{
var factoryLoader = context.CurrentServiceProvider.GetRequiredService<IObjectFactoryLoader>();
factoryType = factoryLoader.GetFactoryType(factoryAttribute.FactoryTypeName);
}
catch when (!throwOnError)
{
factoryType = null;
return false;
}

return factoryType is not null;
}
}
else // not using factory types
{
Expand Down Expand Up @@ -225,6 +215,22 @@ static bool TryGetFactoryType(ObjectFactoryAttribute factoryAttribute, Applicati
}
}

static bool TryGetFactoryType(ObjectFactoryAttribute factoryAttribute, ApplicationContext context, bool throwOnError, [NotNullWhen(true)] out Type? factoryType)
{
try
{
var factoryLoader = context.CurrentServiceProvider.GetRequiredService<IObjectFactoryLoader>();
factoryType = factoryLoader.GetFactoryType(factoryAttribute.FactoryTypeName);
}
catch when (!throwOnError)
{
factoryType = null;
return false;
}

return factoryType is not null;
}

ScoredMethodInfo? result = null;

if (candidates.Any())
Expand Down Expand Up @@ -431,10 +437,11 @@ private static int CalculateParameterScore(ParameterInfo methodParam, object? c)
return 0;
}

private static string GetCacheKeyName(Type targetType, Type operationType, object?[]? criteria, bool useLegacyMethods)
private static string GetCacheKeyName(Type targetType, Type operationType, object?[]? criteria, bool useLegacyMethods, Type? factoryType = null)
{
var legacy = useLegacyMethods ? "" : "|nolegacy";
return $"{targetType.FullName}.[{operationType.Name.Replace("Attribute", "")}]{GetCriteriaTypeNames(criteria)}{legacy}";
var factory = factoryType is null ? "" : $"|{factoryType.FullName}";
return $"{targetType.FullName}.[{operationType.Name.Replace("Attribute", "")}]{GetCriteriaTypeNames(criteria)}{legacy}{factory}";
}

private static string GetCriteriaTypeNames(object?[]? criteria)
Expand Down Expand Up @@ -619,7 +626,11 @@ private static ParameterInfo[] GetDIParameters(System.Reflection.MethodInfo meth
}
else if (method.IsAsyncTaskObject)
{
return await ((Task<object>)method.DynamicMethod!(obj, plist)).ConfigureAwait(false);
// The method returns Task<T>; Task<T> is invariant so it cannot be cast to
// Task<object>. Use the conversion helper to await it and extract the result.
var returnValue = method.DynamicMethod!(obj, plist);
var convertedTask = (Task<object?>)method.ConvertToTaskObjectMethod!.Invoke(null, [returnValue])!;
return await convertedTask.ConfigureAwait(false);
}
else
{
Expand Down
7 changes: 7 additions & 0 deletions Source/Csla/Reflection/ServiceProviderMethodInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ public class ServiceProviderMethodInfo
/// </summary>
public bool IsAsyncTaskObject { get; set; }
/// <summary>
/// Gets the helper method used to convert a Task of T
/// return value into a Task of object
/// </summary>
internal System.Reflection.MethodInfo? ConvertToTaskObjectMethod { get; private set; }
/// <summary>
/// Gets the DataPortalInfo for the method
/// </summary>
internal DataPortalMethodInfo? DataPortalMethodInfo { get; private set; }
Expand Down Expand Up @@ -118,6 +123,8 @@ public void PrepForInvocation()
}
IsAsyncTask = (MethodInfo.ReturnType == typeof(Task));
IsAsyncTaskObject = (MethodInfo.ReturnType.IsGenericType && (MethodInfo.ReturnType.GetGenericTypeDefinition() == typeof(Task<>)));
if (IsAsyncTaskObject)
ConvertToTaskObjectMethod = TaskConversionHelper.CreateTaskObjectConversionMethodInfo(MethodInfo.ReturnType.GetGenericArguments()[0]);
DataPortalMethodInfo = new DataPortalMethodInfo(MethodInfo);

Initialized = true;
Expand Down
96 changes: 40 additions & 56 deletions Source/Csla/Server/FactoryDataPortal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ public class FactoryDataPortal : IDataPortalServer
private readonly IObjectFactoryLoader _factoryLoader;
private readonly IDataPortalExceptionInspector _exceptionInspector;
private readonly DataPortalOptions _dataPortalOptions;
private Reflection.ServiceProviderMethodCaller? _serviceProviderMethodCaller;
private Reflection.ServiceProviderMethodCaller ServiceProviderMethodCaller =>
_serviceProviderMethodCaller ??= _applicationContext.CreateInstanceDI<Reflection.ServiceProviderMethodCaller>();

/// <summary>
/// Creates an instance of the type.
Expand All @@ -43,18 +46,39 @@ public FactoryDataPortal(ApplicationContext applicationContext, IObjectFactoryLo

#region Method invokes

private async Task<DataPortalResult> InvokeMethod(string factoryTypeName, DataPortalOperations operation, string methodName, Type objectType, DataPortalContext context, bool isSync)
private async Task<DataPortalResult> InvokeMethod<T>(string factoryTypeName, DataPortalOperations operation, string methodName, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type objectType, object? e, DataPortalContext context, bool isSync)
where T : DataPortalOperationAttribute
{
object factory = _factoryLoader.GetFactory(factoryTypeName);
var eventArgs = new DataPortalEventArgs(context, objectType, null, operation);
// EmptyCriteria is an internal marker for "no criteria"; surface it to the
// factory lifecycle hooks as null, matching the historical behavior.
var eventArg = e is EmptyCriteria ? null : e;
var eventArgs = new DataPortalEventArgs(context, objectType, eventArg, operation);

Reflection.MethodCaller.CallMethodIfImplemented(factory, "Invoke", eventArgs);
object? result;
try
{
Utilities.ThrowIfAsyncMethodOnSyncClient(_applicationContext, isSync, factory, methodName);
var criteria = DataPortal.GetCriteriaArray(e);
if (ServiceProviderMethodCaller.TryFindDataPortalMethod<T>(objectType, criteria, out var method))
{
// DI-aware path: supports multiple criteria parameters and method-level [Inject]
Utilities.ThrowIfAsyncMethodOnSyncClient(_applicationContext, isSync, method.MethodInfo);
result = await ServiceProviderMethodCaller.CallMethodTryAsync(factory, method, criteria).ConfigureAwait(false);
}
else if (e is null or EmptyCriteria)
{
// legacy fallback: parameterless factory method
Utilities.ThrowIfAsyncMethodOnSyncClient(_applicationContext, isSync, factory, methodName);
result = await Reflection.MethodCaller.CallMethodTryAsync(factory, methodName).ConfigureAwait(false);
}
else
{
// legacy fallback: single criteria object
Utilities.ThrowIfAsyncMethodOnSyncClient(_applicationContext, isSync, factory, methodName, e);
result = await Reflection.MethodCaller.CallMethodTryAsync(factory, methodName, e).ConfigureAwait(false);
}

result = await Reflection.MethodCaller.CallMethodTryAsync(factory, methodName).ConfigureAwait(false);
if (result is Exception error)
throw error;

Expand All @@ -66,35 +90,7 @@ private async Task<DataPortalResult> InvokeMethod(string factoryTypeName, DataPo
catch (Exception ex)
{
Reflection.MethodCaller.CallMethodIfImplemented(
factory, "InvokeError", new DataPortalEventArgs(context, objectType, null, operation, ex));
throw;
}
return new DataPortalResult(_applicationContext, result, null);
}

private async Task<DataPortalResult> InvokeMethod(string factoryTypeName, DataPortalOperations operation, string methodName, Type objectType, object e, DataPortalContext context, bool isSync)
{
object factory = _factoryLoader.GetFactory(factoryTypeName);
var eventArgs = new DataPortalEventArgs(context, objectType, e, operation);

Reflection.MethodCaller.CallMethodIfImplemented(factory, "Invoke", eventArgs);
object? result;
try
{
Utilities.ThrowIfAsyncMethodOnSyncClient(_applicationContext, isSync, factory, methodName, e);

result = await Reflection.MethodCaller.CallMethodTryAsync(factory, methodName, e).ConfigureAwait(false);
if (result is Exception error)
throw error;

if (result is Core.ITrackStatus busy && busy.IsBusy)
throw new InvalidOperationException($"{objectType.Name}.IsBusy == true");

Reflection.MethodCaller.CallMethodIfImplemented(factory, "InvokeComplete", eventArgs);
}
catch (Exception ex)
{
Reflection.MethodCaller.CallMethodIfImplemented(factory, "InvokeError", new DataPortalEventArgs(context, objectType, e, operation, ex));
factory, "InvokeError", new DataPortalEventArgs(context, objectType, eventArg, operation, ex));
throw;
}
return new DataPortalResult(_applicationContext, result, null);
Expand All @@ -118,11 +114,7 @@ public async Task<DataPortalResult> Create([DynamicallyAccessedMembers(Dynamical

try
{
DataPortalResult result;
if (criteria is EmptyCriteria)
result = await InvokeMethod(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Create, context.FactoryInfo.CreateMethodName, objectType, context, isSync).ConfigureAwait(false);
else
result = await InvokeMethod(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Create, context.FactoryInfo.CreateMethodName, objectType, criteria, context, isSync).ConfigureAwait(false);
var result = await InvokeMethod<CreateAttribute>(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Create, context.FactoryInfo.CreateMethodName, objectType, criteria, context, isSync).ConfigureAwait(false);
return result;
}
catch (Exception ex)
Expand Down Expand Up @@ -153,11 +145,7 @@ public async Task<DataPortalResult> Fetch([DynamicallyAccessedMembers(Dynamicall

try
{
DataPortalResult result;
if (criteria is EmptyCriteria)
result = await InvokeMethod(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Fetch, context.FactoryInfo.FetchMethodName, objectType, context, isSync).ConfigureAwait(false);
else
result = await InvokeMethod(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Fetch, context.FactoryInfo.FetchMethodName, objectType, criteria, context, isSync).ConfigureAwait(false);
var result = await InvokeMethod<FetchAttribute>(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Fetch, context.FactoryInfo.FetchMethodName, objectType, criteria, context, isSync).ConfigureAwait(false);
return result;
}
catch (Exception ex)
Expand All @@ -169,15 +157,11 @@ public async Task<DataPortalResult> Fetch([DynamicallyAccessedMembers(Dynamicall
}
}

private async Task<DataPortalResult> Execute(Type objectType, object criteria, DataPortalContext context, bool isSync, ObjectFactoryAttribute factoryInfo)
private async Task<DataPortalResult> Execute([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type objectType, object criteria, DataPortalContext context, bool isSync, ObjectFactoryAttribute factoryInfo)
{
try
{
DataPortalResult result;
if (criteria is EmptyCriteria)
result = await InvokeMethod(factoryInfo.FactoryTypeName, DataPortalOperations.Execute, factoryInfo.ExecuteMethodName, objectType, context, isSync).ConfigureAwait(false);
else
result = await InvokeMethod(factoryInfo.FactoryTypeName, DataPortalOperations.Execute, factoryInfo.ExecuteMethodName, objectType, criteria, context, isSync).ConfigureAwait(false);
var result = await InvokeMethod<ExecuteAttribute>(factoryInfo.FactoryTypeName, DataPortalOperations.Execute, factoryInfo.ExecuteMethodName, objectType, criteria, context, isSync).ConfigureAwait(false);
return result;
}
catch (Exception ex)
Expand Down Expand Up @@ -205,11 +189,15 @@ public async Task<DataPortalResult> Update(ICslaObject obj, DataPortalContext co
{
DataPortalResult result;
if (obj is Core.ICommandObject)
{
methodName = context.FactoryInfo.ExecuteMethodName;
result = await InvokeMethod<ExecuteAttribute>(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Update, methodName, obj.GetType(), obj, context, isSync).ConfigureAwait(false);
}
else
{
methodName = context.FactoryInfo.UpdateMethodName;

result = await InvokeMethod(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Update, methodName, obj.GetType(), obj, context, isSync).ConfigureAwait(false);
result = await InvokeMethod<UpdateAttribute>(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Update, methodName, obj.GetType(), obj, context, isSync).ConfigureAwait(false);
}
return result;
}
catch (Exception ex)
Expand All @@ -236,11 +224,7 @@ public async Task<DataPortalResult> Delete([DynamicallyAccessedMembers(Dynamical

try
{
DataPortalResult result;
if (criteria is EmptyCriteria)
result = await InvokeMethod(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Delete, context.FactoryInfo.DeleteMethodName, objectType, context, isSync).ConfigureAwait(false);
else
result = await InvokeMethod(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Delete, context.FactoryInfo.DeleteMethodName, objectType, criteria, context, isSync).ConfigureAwait(false);
var result = await InvokeMethod<DeleteAttribute>(context.FactoryInfo.FactoryTypeName, DataPortalOperations.Delete, context.FactoryInfo.DeleteMethodName, objectType, criteria, context, isSync).ConfigureAwait(false);
return result;
}
catch (Exception ex)
Expand Down
47 changes: 47 additions & 0 deletions Source/tests/Csla.test/ObjectFactory/InjectCommandObject.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//-----------------------------------------------------------------------
// <copyright file="InjectCommandObject.cs" company="Marimer LLC">
// Copyright (c) Marimer LLC. All rights reserved.
// Website: https://cslanet.com
// </copyright>
// <summary>Command object whose factory Execute method uses [Inject]</summary>
//-----------------------------------------------------------------------

namespace Csla.Test.ObjectFactory
{
[Server.ObjectFactory("Csla.Test.ObjectFactory.InjectCommandObjectFactory, Csla.Tests")]
[Serializable]
public class InjectCommandObject : CommandBase<InjectCommandObject>
{
public static readonly PropertyInfo<string> ValueProperty = RegisterProperty<InjectCommandObject, string>(p => p.Value);
public string Value
{
get => ReadProperty(ValueProperty);
set => LoadProperty(ValueProperty, value);
}

public static InjectCommandObject Execute(IDataPortal<InjectCommandObject> dataPortal)
{
var cmd = dataPortal.Create();
return dataPortal.Execute(cmd);
}
}

public class InjectCommandObjectFactory : Csla.Server.ObjectFactory
{
public InjectCommandObjectFactory(ApplicationContext applicationContext) : base(applicationContext)
{
}

[RunLocal]
public object Create()
{
return ApplicationContext.CreateInstanceDI<InjectCommandObject>();
}

public object Execute(InjectCommandObject command, [Inject] IFactoryTestService service)
{
command.Value = service.GetValue();
return command;
}
}
}
Loading
Loading