Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Estimated row count hint #31

Merged
merged 5 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/Dapper.AOT.Analyzers/AnalyzerReleases.Unshipped.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ DAP025 | Sql | Warning | Diagnostics
DAP026 | Sql | Error | Diagnostics
DAP027 | Performance | Warning | Diagnostics
DAP028 | Performance | Warning | Diagnostics
DAP029 | Library | Info | Diagnostics
DAP030 | Library | Error | Diagnostics
DAP031 | Library | Error | Diagnostics
DAP032 | Library | Error | Diagnostics
DAP100 | Library | Error | Diagnostics
DAP101 | Library | Error | Diagnostics
DAP102 | Library | Error | Diagnostics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ void WriteMultiExecExpression(ITypeSymbol elementType, string castType)
// return Command<type>(...).ExecuteAsync((cast)param, ...);
bool isAsync = HasAny(flags, OperationFlags.Async);
sb.Append("Execute").Append(isAsync ? "Async" : "").Append("(");
sb.Append("(").Append(castType).Append(")param!").Append(");");
sb.Append("(").Append(castType).Append(")param!");
if (isAsync && HasParam(methodParameters, "cancellationToken"))
{
sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken"));
}
sb.Append(");");
sb.NewLine().Outdent().NewLine().NewLine();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ void WriteSingleImplementation(
OperationFlags commandTypeMode,
ITypeSymbol? parameterType,
string map, bool cache,
ImmutableArray<IParameterSymbol> methodParameters,
CommandFactoryState factories,
RowReaderState readers,
string? fixedSql)
in ImmutableArray<IParameterSymbol> methodParameters,
in CommandFactoryState factories,
in RowReaderState readers,
string? fixedSql,
in EstimatedRowCountState estimatedRowCount)
{
sb.Append("return ");
if (HasAll(flags, OperationFlags.Async | OperationFlags.Query | OperationFlags.Buffered))
Expand Down Expand Up @@ -188,9 +189,20 @@ static bool IsInbuilt(ITypeSymbol? type, out string? helper)
{
sb.NewLine().Append("#error not supported: ").Append(method.Name).NewLine();
}
if (isAsync)
if (HasAny(flags, OperationFlags.Query) && estimatedRowCount.HasValue)
{
sb.Append(", ").Append(Forward(methodParameters, "cancellationToken"));
if (estimatedRowCount.MemberName is null)
{
sb.Append(", rowCountHint: ").Append(estimatedRowCount.Count);
}
else if (parameterType is not null && !parameterType.IsAnonymousType)
{
sb.Append(", rowCountHint: ((").Append(parameterType).Append(")param!).").Append(estimatedRowCount.MemberName);
}
}
if (isAsync && HasParam(methodParameters, "cancellationToken"))
{
sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken"));
}
if (HasAll(flags, OperationFlags.Async | OperationFlags.Query | OperationFlags.Buffered))
{
Expand Down
89 changes: 63 additions & 26 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,10 @@ private bool PreFilter(SyntaxNode node, CancellationToken cancellationToken)
if (!canBeCached) flags &= ~OperationFlags.CacheCommand;
}

var estimatedRowCount = ReadEstimatedRowCount(ctx, op, paramType, ref diagnostics, cancellationToken);
CheckCallValidity(op, flags, ref diagnostics);

return new SourceState(loc, op.TargetMethod, flags, sql, resultType, paramType, parameterMap, diagnostics);
return new SourceState(loc, op.TargetMethod, flags, sql, resultType, paramType, parameterMap, estimatedRowCount, diagnostics);

//static bool HasDiagnostic(object? diagnostics, DiagnosticDescriptor diagnostic)
//{
Expand All @@ -361,6 +362,46 @@ private bool PreFilter(SyntaxNode node, CancellationToken cancellationToken)
static bool TryGetConstantValue<T>(IArgumentOperation op, out T? value)
=> TryGetConstantValueWithSyntax<T>(op, out value, out _);

static EstimatedRowCountState ReadEstimatedRowCount(GeneratorSyntaxContext ctx, IOperation op, ITypeSymbol? paramType, ref object? diagnostics, CancellationToken cancellationToken)
{
string? estimatedRowCountMember = null;
if (paramType is not null)
{
foreach (var member in Inspection.GetMembers(paramType))
{
if (member.IsEstimatedRowCount)
{
if (estimatedRowCountMember is not null)
{
Diagnostics.Add(ref diagnostics, Diagnostic.Create(Diagnostics.MemberRowCountHintDuplicated, member.GetLocation()));
}
estimatedRowCountMember = member.Member.Name;
}
}
}
var attrib = Inspection.GetClosestDapperAttribute(in ctx, op, Types.EstimatedRowCountAttribute, out var attribLoc, cancellationToken);

if (estimatedRowCountMember is not null)
{
if (attrib is not null)
{
Diagnostics.Add(ref diagnostics, Diagnostic.Create(Diagnostics.MethodRowCountHintRedundant, attribLoc, estimatedRowCountMember));
}
return new EstimatedRowCountState(estimatedRowCountMember);
}

if (attrib is not null)
{
if (attrib.ConstructorArguments.Length == 1 && attrib.ConstructorArguments[0].Value is int i
&& i > 0)
{
return new EstimatedRowCountState(i);
}
Diagnostics.Add(ref diagnostics, Diagnostic.Create(Diagnostics.MethodRowCountHintInvalid, attribLoc));
}
return default;
}

static bool TryGetConstantValueWithSyntax<T>(IArgumentOperation op, out T? value, out SyntaxNode? syntax)
{
try
Expand Down Expand Up @@ -510,6 +551,9 @@ select var.Name.StartsWith("@") ? var.Name.Substring(1) : var.Name
{
Diagnostics.Add(ref diagnostics, Diagnostic.Create(Diagnostics.RowCountDbValue, loc, member.CodeName));
}
}
if (member.Kind != Inspection.ElementMemberKind.None)
{
continue; // not treated as parameters for naming etc purposes
}
var dbName = member.DbName;
Expand Down Expand Up @@ -885,24 +929,8 @@ private void Generate(SourceProductionContext ctx, (Compilation Compilation, Imm
foreach (var grp in state.Nodes.Where(x => !HasAny(x.Flags, OperationFlags.DoNotGenerate)).GroupBy(x => x.Group(), CommonComparer.Instance))
{
// first, try to resolve the helper method that we're going to use for this
var (flags, method, parameterType, parameterMap, _) = grp.Key;
int arity = HasAny(flags, OperationFlags.TypedResult) ? 2 : 1, argCount = 8;
var (flags, method, parameterType, parameterMap, _, estimatedRowCount) = grp.Key;
const bool useUnsafe = false;
var helperName = method.Name;
if (helperName == "Query")
{
if (HasAny(flags, OperationFlags.Buffered | OperationFlags.Unbuffered))
{
//dedicated mode
helperName = HasAny(flags, OperationFlags.Buffered) ? "QueryBuffered" : "QueryUnbuffered";
}
else
{
// fallback mode, needs an extra arg to pass in "buffered"
argCount++;
}
}

int usageCount = 0;

foreach (var op in grp.OrderBy(row => row.Location, CommonComparer.Instance))
Expand Down Expand Up @@ -988,7 +1016,7 @@ private void Generate(SourceProductionContext ctx, (Compilation Compilation, Imm

if (!TryWriteMultiExecImplementation(sb, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, fixedSql))
{
WriteSingleImplementation(sb, method, resultType, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, readers, fixedSql);
WriteSingleImplementation(sb, method, resultType, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, readers, fixedSql, estimatedRowCount);
}
}

Expand Down Expand Up @@ -1630,8 +1658,10 @@ sealed class SourceState
public IMethodSymbol Method { get; }
public ITypeSymbol? ResultType { get; }
public ITypeSymbol? ParameterType { get; }
public EstimatedRowCountState EstimatedRowCount { get; }
public SourceState(Location location, IMethodSymbol method, OperationFlags flags, string? sql,
ITypeSymbol? resultType, ITypeSymbol? parameterType, string parameterMap, object? diagnostics = null)
ITypeSymbol? resultType, ITypeSymbol? parameterType, string parameterMap,
EstimatedRowCountState estimatedRowCount, object? diagnostics = null)
{
Location = location;
Flags = flags;
Expand All @@ -1640,6 +1670,7 @@ public SourceState(Location location, IMethodSymbol method, OperationFlags flags
ParameterType = parameterType;
Method = method;
ParameterMap = parameterMap;
EstimatedRowCount = estimatedRowCount;
this.diagnostics = diagnostics;
}

Expand All @@ -1658,21 +1689,25 @@ public SourceState(Location location, IMethodSymbol method, OperationFlags flags
_ => throw new IndexOutOfRangeException(nameof(index)),
};

public (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation) Group()
=> new(Flags, Method, ParameterType, ParameterMap, (Flags & (OperationFlags.CacheCommand | OperationFlags.IncludeLocation)) == 0 ? null : Location);
public (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount) Group()
=> new(Flags, Method, ParameterType, ParameterMap, (Flags & (OperationFlags.CacheCommand | OperationFlags.IncludeLocation)) == 0 ? null : Location, EstimatedRowCount);
}
private sealed class CommonComparer : LocationComparer, IEqualityComparer<(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation)>
private sealed class CommonComparer : LocationComparer, IEqualityComparer<(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount)>
{
public static readonly CommonComparer Instance = new();
private CommonComparer() { }

public bool Equals((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation) x, (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation) y) => x.Flags == y.Flags
public bool Equals(

(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount) x,
(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount) y) => x.Flags == y.Flags
&& x.ParameterMap == y.ParameterMap
&& SymbolEqualityComparer.Default.Equals(x.Method, y.Method)
&& SymbolEqualityComparer.Default.Equals(x.ParameterType, y.ParameterType)
&& x.UniqueLocation == y.UniqueLocation;
&& x.UniqueLocation == y.UniqueLocation
&& x.EstimatedRowCount.Equals(y.EstimatedRowCount);

public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation) obj)
public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount) obj)
{
var hash = (int)obj.Flags;
hash *= -47;
Expand All @@ -1689,6 +1724,8 @@ public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol?
{
hash += obj.UniqueLocation.GetHashCode();
}
hash *= -47;
hash += obj.EstimatedRowCount.GetHashCode();
return hash;
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/Diagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ internal static readonly DiagnosticDescriptor
"Use {0}() instead of Query(...).{1}()", Category.Performance, DiagnosticSeverity.Warning, true),
UseQueryAsList = new("DAP028", "Use AsList instead of ToList",
"Use Query(...).AsList() instead of Query(...).ToList()", Category.Performance, DiagnosticSeverity.Warning, true),
MethodRowCountHintRedundant = new("DAP029", "Method-level row-count hint redundant",
"The [EstimatedRowCount] will be ignored due to parameter member '{0}'", Category.Library, DiagnosticSeverity.Info, true),
MethodRowCountHintInvalid = new("DAP030", "Method-level row-count hint invalid",
"The [EstimatedRowCount] parameters are invalid; a positive integer must be supplied", Category.Library, DiagnosticSeverity.Error, true),
MemberRowCountHintInvalid = new("DAP031", "Member-level row-count hint invalid",
"The [EstimatedRowCount] parameters are invalid; no parameter should be supplied", Category.Library, DiagnosticSeverity.Error, true),
MemberRowCountHintDuplicated = new("DAP032", "Member-level row-count hint duplicated",
"Only a single member should be marked [EstimatedRowCount]", Category.Library, DiagnosticSeverity.Error, true),

// TypeAccessor
TypeAccessorCollectionTypeNotAllowed = new("DAP100", "TypeAccessors does not allow collection types",
Expand Down
1 change: 1 addition & 0 deletions src/Dapper.AOT.Analyzers/Internal/CommandFactoryState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace Dapper.Internal;

internal readonly struct CommandFactoryState : IEnumerable<(ITypeSymbol Type, string Map, int Index, int CacheCount)>
{

public CommandFactoryState(Compilation compilation) => systemObject = compilation.GetSpecialType(SpecialType.System_Object);
private readonly ITypeSymbol systemObject;
private readonly Dictionary<(ITypeSymbol Type, string Map, bool Cached), (int Index, int CacheCount)> parameterTypes = new(ParameterTypeMapComparer.Instance);
Expand Down
33 changes: 33 additions & 0 deletions src/Dapper.AOT.Analyzers/Internal/EstimatedRowCountState.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using System;

namespace Dapper.Internal;

internal readonly struct EstimatedRowCountState : IEquatable<EstimatedRowCountState>
{
public readonly int Count;
public readonly string? MemberName;
public bool HasValue => Count > 0 || MemberName is not null;

public EstimatedRowCountState(string memberName)
{
Count = 0;
MemberName = memberName;
}

public EstimatedRowCountState(int count)
{
Count = count;
MemberName = null;
}

public override bool Equals(object obj) => obj is EstimatedRowCountState other && Equals(other);

public bool Equals(EstimatedRowCountState other)
=> Count == other.Count && MemberName == other.MemberName;

public override int GetHashCode()
=> Count + (MemberName is null ? 0 : MemberName.GetHashCode());

public override string ToString() => MemberName is null
? Count.ToString() : MemberName;
}
Loading
Loading