Skip to content

Commit

Permalink
Estimated row count hint (DapperLib#31)
Browse files Browse the repository at this point in the history
* stab at estimate row count

* intermediate

* nearly there

* fix member-name output

* fixup build output
  • Loading branch information
mgravell authored Aug 14, 2023
1 parent 9070118 commit d563d38
Show file tree
Hide file tree
Showing 38 changed files with 845 additions and 143 deletions.
4 changes: 4 additions & 0 deletions src/Dapper.AOT.Analyzers/AnalyzerReleases.Unshipped.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ DAP025 | Sql | Warning | Diagnostics
DAP026 | Sql | Error | Diagnostics
DAP027 | Performance | Warning | Diagnostics
DAP028 | Performance | Warning | Diagnostics
DAP029 | Library | Info | Diagnostics
DAP030 | Library | Error | Diagnostics
DAP031 | Library | Error | Diagnostics
DAP032 | Library | Error | Diagnostics
DAP100 | Library | Error | Diagnostics
DAP101 | Library | Error | Diagnostics
DAP102 | Library | Error | Diagnostics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ void WriteMultiExecExpression(ITypeSymbol elementType, string castType)
// return Command<type>(...).ExecuteAsync((cast)param, ...);
bool isAsync = HasAny(flags, OperationFlags.Async);
sb.Append("Execute").Append(isAsync ? "Async" : "").Append("(");
sb.Append("(").Append(castType).Append(")param!").Append(");");
sb.Append("(").Append(castType).Append(")param!");
if (isAsync && HasParam(methodParameters, "cancellationToken"))
{
sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken"));
}
sb.Append(");");
sb.NewLine().Outdent().NewLine().NewLine();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ void WriteSingleImplementation(
OperationFlags commandTypeMode,
ITypeSymbol? parameterType,
string map, bool cache,
ImmutableArray<IParameterSymbol> methodParameters,
CommandFactoryState factories,
RowReaderState readers,
string? fixedSql)
in ImmutableArray<IParameterSymbol> methodParameters,
in CommandFactoryState factories,
in RowReaderState readers,
string? fixedSql,
in EstimatedRowCountState estimatedRowCount)
{
sb.Append("return ");
if (HasAll(flags, OperationFlags.Async | OperationFlags.Query | OperationFlags.Buffered))
Expand Down Expand Up @@ -188,9 +189,20 @@ static bool IsInbuilt(ITypeSymbol? type, out string? helper)
{
sb.NewLine().Append("#error not supported: ").Append(method.Name).NewLine();
}
if (isAsync)
if (HasAny(flags, OperationFlags.Query) && estimatedRowCount.HasValue)
{
sb.Append(", ").Append(Forward(methodParameters, "cancellationToken"));
if (estimatedRowCount.MemberName is null)
{
sb.Append(", rowCountHint: ").Append(estimatedRowCount.Count);
}
else if (parameterType is not null && !parameterType.IsAnonymousType)
{
sb.Append(", rowCountHint: ((").Append(parameterType).Append(")param!).").Append(estimatedRowCount.MemberName);
}
}
if (isAsync && HasParam(methodParameters, "cancellationToken"))
{
sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken"));
}
if (HasAll(flags, OperationFlags.Async | OperationFlags.Query | OperationFlags.Buffered))
{
Expand Down
89 changes: 63 additions & 26 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,10 @@ private bool PreFilter(SyntaxNode node, CancellationToken cancellationToken)
if (!canBeCached) flags &= ~OperationFlags.CacheCommand;
}

var estimatedRowCount = ReadEstimatedRowCount(ctx, op, paramType, ref diagnostics, cancellationToken);
CheckCallValidity(op, flags, ref diagnostics);

return new SourceState(loc, op.TargetMethod, flags, sql, resultType, paramType, parameterMap, diagnostics);
return new SourceState(loc, op.TargetMethod, flags, sql, resultType, paramType, parameterMap, estimatedRowCount, diagnostics);

//static bool HasDiagnostic(object? diagnostics, DiagnosticDescriptor diagnostic)
//{
Expand All @@ -361,6 +362,46 @@ private bool PreFilter(SyntaxNode node, CancellationToken cancellationToken)
static bool TryGetConstantValue<T>(IArgumentOperation op, out T? value)
=> TryGetConstantValueWithSyntax<T>(op, out value, out _);

static EstimatedRowCountState ReadEstimatedRowCount(GeneratorSyntaxContext ctx, IOperation op, ITypeSymbol? paramType, ref object? diagnostics, CancellationToken cancellationToken)
{
string? estimatedRowCountMember = null;
if (paramType is not null)
{
foreach (var member in Inspection.GetMembers(paramType))
{
if (member.IsEstimatedRowCount)
{
if (estimatedRowCountMember is not null)
{
Diagnostics.Add(ref diagnostics, Diagnostic.Create(Diagnostics.MemberRowCountHintDuplicated, member.GetLocation()));
}
estimatedRowCountMember = member.Member.Name;
}
}
}
var attrib = Inspection.GetClosestDapperAttribute(in ctx, op, Types.EstimatedRowCountAttribute, out var attribLoc, cancellationToken);

if (estimatedRowCountMember is not null)
{
if (attrib is not null)
{
Diagnostics.Add(ref diagnostics, Diagnostic.Create(Diagnostics.MethodRowCountHintRedundant, attribLoc, estimatedRowCountMember));
}
return new EstimatedRowCountState(estimatedRowCountMember);
}

if (attrib is not null)
{
if (attrib.ConstructorArguments.Length == 1 && attrib.ConstructorArguments[0].Value is int i
&& i > 0)
{
return new EstimatedRowCountState(i);
}
Diagnostics.Add(ref diagnostics, Diagnostic.Create(Diagnostics.MethodRowCountHintInvalid, attribLoc));
}
return default;
}

static bool TryGetConstantValueWithSyntax<T>(IArgumentOperation op, out T? value, out SyntaxNode? syntax)
{
try
Expand Down Expand Up @@ -510,6 +551,9 @@ select var.Name.StartsWith("@") ? var.Name.Substring(1) : var.Name
{
Diagnostics.Add(ref diagnostics, Diagnostic.Create(Diagnostics.RowCountDbValue, loc, member.CodeName));
}
}
if (member.Kind != Inspection.ElementMemberKind.None)
{
continue; // not treated as parameters for naming etc purposes
}
var dbName = member.DbName;
Expand Down Expand Up @@ -885,24 +929,8 @@ private void Generate(SourceProductionContext ctx, (Compilation Compilation, Imm
foreach (var grp in state.Nodes.Where(x => !HasAny(x.Flags, OperationFlags.DoNotGenerate)).GroupBy(x => x.Group(), CommonComparer.Instance))
{
// first, try to resolve the helper method that we're going to use for this
var (flags, method, parameterType, parameterMap, _) = grp.Key;
int arity = HasAny(flags, OperationFlags.TypedResult) ? 2 : 1, argCount = 8;
var (flags, method, parameterType, parameterMap, _, estimatedRowCount) = grp.Key;
const bool useUnsafe = false;
var helperName = method.Name;
if (helperName == "Query")
{
if (HasAny(flags, OperationFlags.Buffered | OperationFlags.Unbuffered))
{
//dedicated mode
helperName = HasAny(flags, OperationFlags.Buffered) ? "QueryBuffered" : "QueryUnbuffered";
}
else
{
// fallback mode, needs an extra arg to pass in "buffered"
argCount++;
}
}

int usageCount = 0;

foreach (var op in grp.OrderBy(row => row.Location, CommonComparer.Instance))
Expand Down Expand Up @@ -988,7 +1016,7 @@ private void Generate(SourceProductionContext ctx, (Compilation Compilation, Imm

if (!TryWriteMultiExecImplementation(sb, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, fixedSql))
{
WriteSingleImplementation(sb, method, resultType, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, readers, fixedSql);
WriteSingleImplementation(sb, method, resultType, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, readers, fixedSql, estimatedRowCount);
}
}

Expand Down Expand Up @@ -1630,8 +1658,10 @@ sealed class SourceState
public IMethodSymbol Method { get; }
public ITypeSymbol? ResultType { get; }
public ITypeSymbol? ParameterType { get; }
public EstimatedRowCountState EstimatedRowCount { get; }
public SourceState(Location location, IMethodSymbol method, OperationFlags flags, string? sql,
ITypeSymbol? resultType, ITypeSymbol? parameterType, string parameterMap, object? diagnostics = null)
ITypeSymbol? resultType, ITypeSymbol? parameterType, string parameterMap,
EstimatedRowCountState estimatedRowCount, object? diagnostics = null)
{
Location = location;
Flags = flags;
Expand All @@ -1640,6 +1670,7 @@ public SourceState(Location location, IMethodSymbol method, OperationFlags flags
ParameterType = parameterType;
Method = method;
ParameterMap = parameterMap;
EstimatedRowCount = estimatedRowCount;
this.diagnostics = diagnostics;
}

Expand All @@ -1658,21 +1689,25 @@ public SourceState(Location location, IMethodSymbol method, OperationFlags flags
_ => throw new IndexOutOfRangeException(nameof(index)),
};

public (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation) Group()
=> new(Flags, Method, ParameterType, ParameterMap, (Flags & (OperationFlags.CacheCommand | OperationFlags.IncludeLocation)) == 0 ? null : Location);
public (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount) Group()
=> new(Flags, Method, ParameterType, ParameterMap, (Flags & (OperationFlags.CacheCommand | OperationFlags.IncludeLocation)) == 0 ? null : Location, EstimatedRowCount);
}
private sealed class CommonComparer : LocationComparer, IEqualityComparer<(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation)>
private sealed class CommonComparer : LocationComparer, IEqualityComparer<(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount)>
{
public static readonly CommonComparer Instance = new();
private CommonComparer() { }

public bool Equals((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation) x, (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation) y) => x.Flags == y.Flags
public bool Equals(

(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount) x,
(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount) y) => x.Flags == y.Flags
&& x.ParameterMap == y.ParameterMap
&& SymbolEqualityComparer.Default.Equals(x.Method, y.Method)
&& SymbolEqualityComparer.Default.Equals(x.ParameterType, y.ParameterType)
&& x.UniqueLocation == y.UniqueLocation;
&& x.UniqueLocation == y.UniqueLocation
&& x.EstimatedRowCount.Equals(y.EstimatedRowCount);

public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation) obj)
public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, EstimatedRowCountState EstimatedRowCount) obj)
{
var hash = (int)obj.Flags;
hash *= -47;
Expand All @@ -1689,6 +1724,8 @@ public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol?
{
hash += obj.UniqueLocation.GetHashCode();
}
hash *= -47;
hash += obj.EstimatedRowCount.GetHashCode();
return hash;
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/Dapper.AOT.Analyzers/CodeAnalysis/Diagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ internal static readonly DiagnosticDescriptor
"Use {0}() instead of Query(...).{1}()", Category.Performance, DiagnosticSeverity.Warning, true),
UseQueryAsList = new("DAP028", "Use AsList instead of ToList",
"Use Query(...).AsList() instead of Query(...).ToList()", Category.Performance, DiagnosticSeverity.Warning, true),
MethodRowCountHintRedundant = new("DAP029", "Method-level row-count hint redundant",
"The [EstimatedRowCount] will be ignored due to parameter member '{0}'", Category.Library, DiagnosticSeverity.Info, true),
MethodRowCountHintInvalid = new("DAP030", "Method-level row-count hint invalid",
"The [EstimatedRowCount] parameters are invalid; a positive integer must be supplied", Category.Library, DiagnosticSeverity.Error, true),
MemberRowCountHintInvalid = new("DAP031", "Member-level row-count hint invalid",
"The [EstimatedRowCount] parameters are invalid; no parameter should be supplied", Category.Library, DiagnosticSeverity.Error, true),
MemberRowCountHintDuplicated = new("DAP032", "Member-level row-count hint duplicated",
"Only a single member should be marked [EstimatedRowCount]", Category.Library, DiagnosticSeverity.Error, true),

// TypeAccessor
TypeAccessorCollectionTypeNotAllowed = new("DAP100", "TypeAccessors does not allow collection types",
Expand Down
1 change: 1 addition & 0 deletions src/Dapper.AOT.Analyzers/Internal/CommandFactoryState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace Dapper.Internal;

internal readonly struct CommandFactoryState : IEnumerable<(ITypeSymbol Type, string Map, int Index, int CacheCount)>
{

public CommandFactoryState(Compilation compilation) => systemObject = compilation.GetSpecialType(SpecialType.System_Object);
private readonly ITypeSymbol systemObject;
private readonly Dictionary<(ITypeSymbol Type, string Map, bool Cached), (int Index, int CacheCount)> parameterTypes = new(ParameterTypeMapComparer.Instance);
Expand Down
33 changes: 33 additions & 0 deletions src/Dapper.AOT.Analyzers/Internal/EstimatedRowCountState.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using System;

namespace Dapper.Internal;

internal readonly struct EstimatedRowCountState : IEquatable<EstimatedRowCountState>
{
public readonly int Count;
public readonly string? MemberName;
public bool HasValue => Count > 0 || MemberName is not null;

public EstimatedRowCountState(string memberName)
{
Count = 0;
MemberName = memberName;
}

public EstimatedRowCountState(int count)
{
Count = count;
MemberName = null;
}

public override bool Equals(object obj) => obj is EstimatedRowCountState other && Equals(other);

public bool Equals(EstimatedRowCountState other)
=> Count == other.Count && MemberName == other.MemberName;

public override int GetHashCode()
=> Count + (MemberName is null ? 0 : MemberName.GetHashCode());

public override string ToString() => MemberName is null
? Count.ToString() : MemberName;
}
Loading

0 comments on commit d563d38

Please sign in to comment.