Skip to content

Commit

Permalink
add APIs for working with the connection pools and client-id
Browse files Browse the repository at this point in the history
  • Loading branch information
mgravell committed Aug 30, 2019
1 parent 16ed3fd commit 3b8b0cd
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 1 deletion.
129 changes: 129 additions & 0 deletions Dapper.ProviderTools/DbConnectionExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
using System;
using System.Collections.Concurrent;
using System.Data.Common;
using System.Linq.Expressions;
using System.Reflection;

namespace Dapper.ProviderTools
{
/// <summary>
/// Helper utilties for working with database connections
/// </summary>
public static class DbConnectionExtensions
{
/// <summary>
/// Attempt to get the client connection id for a given connection
/// </summary>
public static bool TryGetClientConnectionId(this DbConnection connection, out Guid clientConnectionId)
{
clientConnectionId = default;
return connection != null && ByTypeHelpers.Get(connection.GetType()).TryGetClientConnectionId(
connection, out clientConnectionId);
}

/// <summary>
/// Clear all pools associated with the provided connection type
/// </summary>
public static bool TryClearAllPools(this DbConnection connection)
=> connection != null && ByTypeHelpers.Get(connection.GetType()).TryClearAllPools();

/// <summary>
/// Clear the pools associated with the provided connection
/// </summary>
public static bool TryClearPool(this DbConnection connection)
=> connection != null && ByTypeHelpers.Get(connection.GetType()).TryClearPool(connection);

private sealed class ByTypeHelpers
{
private static readonly ConcurrentDictionary<Type, ByTypeHelpers> s_byType
= new ConcurrentDictionary<Type, ByTypeHelpers>();
private readonly Func<DbConnection, Guid>? _getClientConnectionId;

private readonly Action<DbConnection>? _clearPool;
private readonly Action? _clearAllPools;

public bool TryGetClientConnectionId(DbConnection connection, out Guid clientConnectionId)
{
if (_getClientConnectionId == null)
{
clientConnectionId = default;
return false;
}
clientConnectionId = _getClientConnectionId(connection);
return true;
}

public bool TryClearPool(DbConnection connection)
{
if (_clearPool == null) return false;
_clearPool(connection);
return true;
}

public bool TryClearAllPools()
{
if (_clearAllPools == null) return false;
_clearAllPools();
return true;
}

public static ByTypeHelpers Get(Type type)
{
if (!s_byType.TryGetValue(type, out var value))
{
s_byType[type] = value = new ByTypeHelpers(type);
}
return value;
}

private ByTypeHelpers(Type type)
{
_getClientConnectionId = TryGetInstanceProperty<Guid>("ClientConnectionId", type);

try
{
var clearAllPools = type.GetMethod("ClearAllPools", BindingFlags.Public | BindingFlags.Static,
null, Type.EmptyTypes, null);
if (clearAllPools != null)
{
_clearAllPools = (Action)Delegate.CreateDelegate(typeof(Action), clearAllPools);
}
}
catch { }

try
{
var clearPool = type.GetMethod("ClearPool", BindingFlags.Public | BindingFlags.Static,
null, new[] { type }, null);
if (clearPool != null)
{
var p = Expression.Parameter(typeof(DbConnection), "connection");
var body = Expression.Call(clearPool, Expression.Convert(p, type));
var lambda = Expression.Lambda<Action<DbConnection>>(body, p);
_clearPool = lambda.Compile();
}
}
catch { }
}

private static Func<DbConnection, T>? TryGetInstanceProperty<T>(string name, Type type)
{
try
{
var prop = type.GetProperty(name, BindingFlags.Public | BindingFlags.Instance);
if (prop == null || !prop.CanRead) return null;
if (prop.PropertyType != typeof(T)) return null;

var p = Expression.Parameter(typeof(DbConnection), "connection");
var body = Expression.Property(Expression.Convert(p, type), prop);
var lambda = Expression.Lambda<Func<DbConnection, T>>(body, p);
return lambda.Compile();
}
catch
{
return null;
}
}
}
}
}
59 changes: 58 additions & 1 deletion Dapper.Tests/ProviderTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Data.Common;
using System;
using System.Data.Common;
using Dapper.ProviderTools;
using Xunit;

Expand All @@ -24,6 +25,62 @@ public void BulkCopy_MicrosoftDataSqlClient()
}
}

[Fact]
public void ClientId_SystemDataSqlClient()
=> TestClientId<SystemSqlClientProvider>();

[Fact]
public void ClientId_MicrosoftDataSqlClient()
=> TestClientId<MicrosoftSqlClientProvider>();


[Fact]
public void ClearPool_SystemDataSqlClient()
=> ClearPool<SystemSqlClientProvider>();

[Fact]
public void ClearPool_MicrosoftDataSqlClient()
=> ClearPool<MicrosoftSqlClientProvider>();

[Fact]
public void ClearAllPools_SystemDataSqlClient()
=> ClearAllPools<SystemSqlClientProvider>();

[Fact]
public void ClearAllPools_MicrosoftDataSqlClient()
=> ClearAllPools<MicrosoftSqlClientProvider>();

private static void TestClientId<T>()
where T : SqlServerDatabaseProvider, new()
{
var provider = new T();
using (var conn = provider.GetOpenConnection())
{
Assert.True(conn.TryGetClientConnectionId(out var id));
Assert.NotEqual(Guid.Empty, id);
}
}

private static void ClearPool<T>()
where T : SqlServerDatabaseProvider, new()
{
var provider = new T();
using (var conn = provider.GetOpenConnection())
{
Assert.True(conn.TryClearPool());
}
}

private static void ClearAllPools<T>()
where T : SqlServerDatabaseProvider, new()
{
var provider = new T();
using (var conn = provider.GetOpenConnection())
{
Assert.True(conn.TryClearAllPools());
}
}

private static void Test<T>(DbConnection connection)
{
using (var bcp = BulkCopy.TryCreate(connection))
Expand Down

0 comments on commit 3b8b0cd

Please sign in to comment.