Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeGen] Always specify grain extension interface for grain extension calls #9009

Merged
merged 4 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/Orleans.CodeGenerator/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -674,19 +674,26 @@ private ProxyInterfaceDescription GetInvokableInterfaceDescription(INamedTypeSym
return description;
}

internal ProxyMethodDescription GetProxyMethodDescription(INamedTypeSymbol interfaceType, IMethodSymbol method, bool hasCollision)
internal ProxyMethodDescription GetProxyMethodDescription(INamedTypeSymbol interfaceType, IMethodSymbol method)
{
var originalMethod = method.OriginalDefinition;
var proxyBaseInfo = GetProxyBase(interfaceType);
var invokableId = new InvokableMethodId(proxyBaseInfo, originalMethod);

// For extensions, we want to ensure that the containing type is always the extension.
// This ensures that we will always know which 'component' to get in our SetTarget method.
// If the type is not an extension, use the original method definition's containing type.
// This is the interface where the type was originally defined.
var containingType = proxyBaseInfo.IsExtension ? interfaceType : originalMethod.ContainingType;

var invokableId = new InvokableMethodId(proxyBaseInfo, containingType, originalMethod);
var interfaceDescription = GetInvokableInterfaceDescription(invokableId.ProxyBase.ProxyBaseType, interfaceType);

// Get or generate an invokable for the original method definition.
if (!MetadataModel.GeneratedInvokables.TryGetValue(invokableId, out var generatedInvokable))
{
if (!_invokableMethodDescriptions.TryGetValue(invokableId, out var methodDescription))
{
methodDescription = _invokableMethodDescriptions[invokableId] = InvokableMethodDescription.Create(invokableId);
methodDescription = _invokableMethodDescriptions[invokableId] = InvokableMethodDescription.Create(invokableId, containingType);
}

generatedInvokable = MetadataModel.GeneratedInvokables[invokableId] = InvokableGenerator.Generate(methodDescription);
Expand All @@ -706,12 +713,12 @@ internal ProxyMethodDescription GetProxyMethodDescription(INamedTypeSymbol inter
}
}

var proxyMethodDescription = ProxyMethodDescription.Create(interfaceDescription, generatedInvokable, method, hasCollision);
var proxyMethodDescription = ProxyMethodDescription.Create(interfaceDescription, generatedInvokable, method);

// For backwards compatibility, generate invokers for the specific implementation types as well, where they differ.
if (Options.GenerateCompatibilityInvokers && !SymbolEqualityComparer.Default.Equals(method.OriginalDefinition.ContainingType, interfaceType))
{
var compatInvokableId = new InvokableMethodId(proxyBaseInfo, method);
var compatInvokableId = new InvokableMethodId(proxyBaseInfo, interfaceType, method);
var compatMethodDescription = InvokableMethodDescription.Create(compatInvokableId, interfaceType);
var compatInvokable = InvokableGenerator.Generate(compatMethodDescription);
AddMember(compatInvokable.GeneratedNamespace, compatInvokable.ClassDeclarationSyntax);
Expand Down
16 changes: 13 additions & 3 deletions src/Orleans.CodeGenerator/InvokableGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,26 @@ public static CompoundTypeAliasComponent[] GetCompoundTypeAliasComponents(
INamedTypeSymbol containingInterface,
string methodId)
{
var proxyBaseComponents = invokableId.ProxyBase.CompositeAliasComponents;
var alias = new CompoundTypeAliasComponent[1 + proxyBaseComponents.Length + 2];
var proxyBase = invokableId.ProxyBase;
var proxyBaseComponents = proxyBase.CompositeAliasComponents;
var extensionArgCount = proxyBase.IsExtension ? 1 : 0;
var alias = new CompoundTypeAliasComponent[1 + proxyBaseComponents.Length + extensionArgCount + 2];
alias[0] = new("inv");
for (var i = 0; i < proxyBaseComponents.Length; i++)
{
alias[i + 1] = proxyBaseComponents[i];
}

alias[1 + proxyBaseComponents.Length] = new(containingInterface);
alias[1 + proxyBaseComponents.Length + 1] = new(methodId);

// For grain extensions, also explicitly include the method's containing type.
// This is to distinguish between different extension methods with the same id (eg, alias) but different containing types.
if (proxyBase.IsExtension)
{
alias[1 + proxyBaseComponents.Length + 1] = new(invokableId.Method.ContainingType);
}

alias[1 + proxyBaseComponents.Length + extensionArgCount + 1] = new(methodId);
return alias;
}

Expand Down
6 changes: 3 additions & 3 deletions src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ namespace Orleans.CodeGenerator
/// </summary>
internal sealed class InvokableMethodDescription : IEquatable<InvokableMethodDescription>
{
public static InvokableMethodDescription Create(InvokableMethodId method, INamedTypeSymbol containingType = null) => new(method, containingType);
public static InvokableMethodDescription Create(InvokableMethodId method, INamedTypeSymbol containingType) => new(method, containingType);

private InvokableMethodDescription(InvokableMethodId invokableId, INamedTypeSymbol containingType)
{
Key = invokableId;
ContainingInterface = containingType ?? invokableId.Method.ContainingType;
ContainingInterface = containingType;
GeneratedMethodId = CodeGenerator.CreateHashedMethodId(Method);
MethodId = CodeGenerator.GetId(Method)?.ToString(CultureInfo.InvariantCulture) ?? CodeGenerator.GetAlias(Method) ?? GeneratedMethodId;

Expand Down Expand Up @@ -209,6 +209,6 @@ static bool TryGetNamedArgument(ImmutableArray<KeyValuePair<string, TypedConstan
public bool Equals(InvokableMethodDescription other) => Key.Equals(other.Key);
public override bool Equals(object obj) => obj is InvokableMethodDescription imd && Equals(imd);
public override int GetHashCode() => Key.GetHashCode();
public override string ToString() => $"{ProxyBase}/{Method.ContainingType.Name}/{Method.Name}";
public override string ToString() => $"{ProxyBase}/{ContainingInterface.Name}/{Method.Name}";
}
}
33 changes: 21 additions & 12 deletions src/Orleans.CodeGenerator/Model/InvokableMethodId.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,39 @@ namespace Orleans.CodeGenerator
/// <summary>
/// Identifies an invokable method.
/// </summary>
internal readonly struct InvokableMethodId : IEquatable<InvokableMethodId>
internal readonly struct InvokableMethodId(InvokableMethodProxyBase proxyBaseInfo, INamedTypeSymbol interfaceType, IMethodSymbol method) : IEquatable<InvokableMethodId>
{
public InvokableMethodId(InvokableMethodProxyBase proxyBaseInfo, IMethodSymbol method)
{
ProxyBase = proxyBaseInfo;
Method = method;
}

/// <summary>
/// Gets the proxy base information for the method (eg, GrainReference, whether it is an extension).
/// </summary>
public InvokableMethodProxyBase ProxyBase { get; }
public InvokableMethodProxyBase ProxyBase { get; } = proxyBaseInfo;

/// <summary>
/// Gets the method symbol.
/// </summary>
public IMethodSymbol Method { get; }
public IMethodSymbol Method { get; } = method;

/// <summary>
/// Gets the containing interface symbol.
/// </summary>
public INamedTypeSymbol InterfaceType { get; } = interfaceType;

public bool Equals(InvokableMethodId other) =>
ProxyBase.Equals(other.ProxyBase)
&& SymbolEqualityComparer.Default.Equals(Method, other.Method);
&& SymbolEqualityComparer.Default.Equals(Method, other.Method)
&& SymbolEqualityComparer.Default.Equals(InterfaceType, other.InterfaceType);

public override bool Equals(object obj) => obj is InvokableMethodId imd && Equals(imd);
public override int GetHashCode() => ProxyBase.GetHashCode() * 17 ^ SymbolEqualityComparer.Default.GetHashCode(Method);
public override string ToString() => $"{ProxyBase}/{Method.ContainingType.Name}/{Method.Name}";
public override int GetHashCode()
{
unchecked
{
return ProxyBase.GetHashCode()
* 17 ^ SymbolEqualityComparer.Default.GetHashCode(Method)
* 17 ^ SymbolEqualityComparer.Default.GetHashCode(InterfaceType);
}
}

public override string ToString() => $"{ProxyBase}/{InterfaceType.Name}/{Method.Name}";
}
}
28 changes: 7 additions & 21 deletions src/Orleans.CodeGenerator/Model/ProxyInterfaceDescription.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,33 +72,19 @@ static string GetTypeParameterName(HashSet<string> names, ITypeParameterSymbol t

public CodeGenerator CodeGenerator { get; }

private List<ProxyMethodDescription> GetMethods(INamedTypeSymbol symbol)
private List<ProxyMethodDescription> GetMethods()
{
#pragma warning disable RS1024 // Symbols should be compared for equality
var methods = new Dictionary<IMethodSymbol, bool>(MethodSignatureComparer.Default);
#pragma warning restore RS1024 // Symbols should be compared for equality
foreach (var iface in GetAllInterfaces(symbol))
var result = new List<ProxyMethodDescription>();
foreach (var iface in GetAllInterfaces(InterfaceType))
{
foreach (var method in iface.GetDeclaredInstanceMembers<IMethodSymbol>())
{
if (methods.TryGetValue(method, out _))
{
methods[method] = true;
continue;
}

methods.Add(method, false);
var methodDescription = CodeGenerator.GetProxyMethodDescription(InterfaceType, method: method);
result.Add(methodDescription);
}
}

var res = new List<ProxyMethodDescription>();
foreach (var pair in methods)
{
var methodDescription = CodeGenerator.GetProxyMethodDescription(symbol, method: pair.Key, hasCollision: pair.Value);
res.Add(methodDescription);
}

return res;
return result;

static IEnumerable<INamedTypeSymbol> GetAllInterfaces(INamedTypeSymbol s)
{
Expand All @@ -117,7 +103,7 @@ static IEnumerable<INamedTypeSymbol> GetAllInterfaces(INamedTypeSymbol s)

public string Name { get; }
public INamedTypeSymbol InterfaceType { get; }
public List<ProxyMethodDescription> Methods => _methods ??= GetMethods(InterfaceType);
public List<ProxyMethodDescription> Methods => _methods ??= GetMethods();
public SemanticModel SemanticModel { get; }
public string GeneratedNamespace { get; }
public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; }
Expand Down
8 changes: 3 additions & 5 deletions src/Orleans.CodeGenerator/Model/ProxyMethodDescription.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@ internal class ProxyMethodDescription : IEquatable<ProxyMethodDescription>
public static ProxyMethodDescription Create(
ProxyInterfaceDescription proxyInterface,
GeneratedInvokableDescription generatedInvokable,
IMethodSymbol method,
bool hasCollision) => new(proxyInterface, generatedInvokable, method, hasCollision);
IMethodSymbol method)
=> new(proxyInterface, generatedInvokable, method);

private ProxyMethodDescription(ProxyInterfaceDescription proxyInterface, GeneratedInvokableDescription generatedInvokable, IMethodSymbol method, bool hasCollision)
private ProxyMethodDescription(ProxyInterfaceDescription proxyInterface, GeneratedInvokableDescription generatedInvokable, IMethodSymbol method)
{
_originalInvokable = generatedInvokable;
Method = method;
ProxyInterface = proxyInterface;
HasCollision = hasCollision;

TypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>();
MethodTypeParameters = new List<(string Name, ITypeParameterSymbol Parameter)>();
Expand Down Expand Up @@ -80,7 +79,6 @@ static string GetTypeParameterName(HashSet<string> names, ITypeParameterSymbol t
public ConstructedGeneratedInvokableDescription GeneratedInvokable { get; }
public ProxyInterfaceDescription ProxyInterface { get; }

public bool HasCollision { get; }
public IMethodSymbol Method { get; }
public InvokableMethodId InvokableId { get; }
public List<(string Name, ITypeParameterSymbol Parameter)> TypeParameters { get; }
Expand Down
22 changes: 2 additions & 20 deletions src/Orleans.CodeGenerator/ProxyGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,8 @@ MethodDeclarationSyntax CreateProxyMethod(ProxyMethodDescription methodDescripti
declaration = declaration.WithModifiers(TokenList(Token(SyntaxKind.AsyncKeyword)));
}

if (methodDescription.HasCollision)
{
declaration = declaration.WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword)));

// Type parameter constrains are not valid on explicit interface definitions
var typeParameters = SyntaxFactoryUtility.GetTypeParameterConstraints(methodDescription.MethodTypeParameters);
foreach (var (name, constraints) in typeParameters)
{
if (constraints.Count > 0)
{
declaration = declaration.AddConstraintClauses(
TypeParameterConstraintClause(name).AddConstraints(constraints.ToArray()));
}
}
}
else
{
var explicitInterfaceSpecifier = ExplicitInterfaceSpecifier(methodDescription.Method.ContainingType.ToNameSyntax());
declaration = declaration.WithExplicitInterfaceSpecifier(explicitInterfaceSpecifier);
}
var explicitInterfaceSpecifier = ExplicitInterfaceSpecifier(methodDescription.Method.ContainingType.ToNameSyntax());
declaration = declaration.WithExplicitInterfaceSpecifier(explicitInterfaceSpecifier);

if (methodDescription.MethodTypeParameters.Count > 0)
{
Expand Down
7 changes: 4 additions & 3 deletions src/Orleans.Core.Abstractions/Core/IGrainCallContext.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#nullable enable
using System.Reflection;
using System.Threading.Tasks;
using Orleans.Runtime;
Expand Down Expand Up @@ -74,12 +75,12 @@ public interface IGrainCallContext
/// <summary>
/// Gets or sets the result.
/// </summary>
object Result { get; set; }
object? Result { get; set; }

/// <summary>
/// Gets or sets the response.
/// </summary>
Response Response { get; set; }
Response? Response { get; set; }

/// <summary>
/// Invokes the request.
Expand Down Expand Up @@ -114,6 +115,6 @@ public interface IOutgoingGrainCallContext : IGrainCallContext
/// <summary>
/// Gets the grain context of the sender.
/// </summary>
public IGrainContext SourceContext { get; }
public IGrainContext? SourceContext { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Orleans.Runtime;
/// Following the <u>power of k-choices</u> algorithm, K silos are picked as potential targets, where K is equal to the square root of the number of silos.
/// Out of those K silos, the one with the lowest score is chosen for placing the activation. Normalization ensures that each property contributes proportionally
/// to the overall score. You can adjust the weights based on your specific requirements and priorities for load balancing.
/// In addition to normalization, an <u>online adaptiv</u> algorithm provides a smoothing effect (filters out high frequency components) and avoids rapid signal
/// In addition to normalization, an <u>online adaptive</u> algorithm provides a smoothing effect (filters out high frequency components) and avoids rapid signal
/// drops by transforming it into a polynomial-like decay process. This contributes to avoiding resource saturation on the silos and especially newly joined silos.</para>
/// <para>Silos which are overloaded by definition of the load shedding mechanism are not considered as candidates for new placements.</para>
/// <para><i>This placement strategy is configured by adding the <see cref="Placement.ResourceOptimizedPlacementAttribute"/> attribute to a grain.</i></para>
Expand Down
1 change: 0 additions & 1 deletion test/DefaultCluster.Tests/GrainInterfaceHierarchyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ public async Task DoSomethingValidateSingleGrainTest()
Assert.Equal(11, await doSomethingCombinedGrain.GetA());
Assert.Equal(11, await doSomethingCombinedGrain.GetB());
Assert.Equal(11, await doSomethingCombinedGrain.GetC());

}
}
}
14 changes: 14 additions & 0 deletions test/DefaultCluster.Tests/PolymorphicInterfaceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ public async Task Polymorphic_ServiceType()
Assert.Equal("B3", await serviceRef.B3Method());
}

[Fact, TestCategory("BVT"), TestCategory("Cast")]
public async Task Polymorphic_InheritedMethodAmbiguity()
{
// Tests interface inheritance hierarchies which involve duplicate method names, requiring casting to resolve the ambiguity.
var grainFullName = typeof(ServiceType).FullName;
var serviceRef = this.GrainFactory.GetGrain<IServiceType>(GetRandomGrainId(), grainFullName);
var ia = (IA)serviceRef;
var ib = (IB)serviceRef;
var ic = (IC)serviceRef;
Assert.Equal("IA", await ia.CommonMethod());
Assert.Equal("IB", await ib.CommonMethod());
Assert.Equal("IC", await ic.CommonMethod());
}

/// <summary>
/// This unit test should consolidate all the use cases we are trying to cover with regard to polymorphic grain references
/// </summary>
Expand Down
5 changes: 4 additions & 1 deletion test/Grains/TestGrainInterfaces/UnitTestGrainInterfaces.cs
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
namespace UnitTests.GrainInterfaces
namespace UnitTests.GrainInterfaces
{
public interface IA : IGrainWithIntegerKey
{
Task<string> CommonMethod();
Task<string> A1Method();
Task<string> A2Method();
Task<string> A3Method();
}

public interface IB : IGrainWithIntegerKey
{
Task<string> CommonMethod();
Task<string> B1Method();
Task<string> B2Method();
Task<string> B3Method();
}

public interface IC : IA, IB
{
new Task<string> CommonMethod();
Task<string> C1Method();
Task<string> C2Method();
Task<string> C3Method();
Expand Down
Loading
Loading