From 3b8b0cdc8ad0faf1f58f49585c1bd2a6d79bc32f Mon Sep 17 00:00:00 2001 From: mgravell Date: Fri, 30 Aug 2019 15:23:15 +0100 Subject: [PATCH] add APIs for working with the connection pools and client-id --- .../DbConnectionExtensions.cs | 129 ++++++++++++++++++ Dapper.Tests/ProviderTests.cs | 59 +++++++- 2 files changed, 187 insertions(+), 1 deletion(-) create mode 100644 Dapper.ProviderTools/DbConnectionExtensions.cs diff --git a/Dapper.ProviderTools/DbConnectionExtensions.cs b/Dapper.ProviderTools/DbConnectionExtensions.cs new file mode 100644 index 000000000..05d21d5a7 --- /dev/null +++ b/Dapper.ProviderTools/DbConnectionExtensions.cs @@ -0,0 +1,129 @@ +using System; +using System.Collections.Concurrent; +using System.Data.Common; +using System.Linq.Expressions; +using System.Reflection; + +namespace Dapper.ProviderTools +{ + /// + /// Helper utilties for working with database connections + /// + public static class DbConnectionExtensions + { + /// + /// Attempt to get the client connection id for a given connection + /// + public static bool TryGetClientConnectionId(this DbConnection connection, out Guid clientConnectionId) + { + clientConnectionId = default; + return connection != null && ByTypeHelpers.Get(connection.GetType()).TryGetClientConnectionId( + connection, out clientConnectionId); + } + + /// + /// Clear all pools associated with the provided connection type + /// + public static bool TryClearAllPools(this DbConnection connection) + => connection != null && ByTypeHelpers.Get(connection.GetType()).TryClearAllPools(); + + /// + /// Clear the pools associated with the provided connection + /// + public static bool TryClearPool(this DbConnection connection) + => connection != null && ByTypeHelpers.Get(connection.GetType()).TryClearPool(connection); + + private sealed class ByTypeHelpers + { + private static readonly ConcurrentDictionary s_byType + = new ConcurrentDictionary(); + private readonly Func? _getClientConnectionId; + + private readonly Action? _clearPool; + private readonly Action? _clearAllPools; + + public bool TryGetClientConnectionId(DbConnection connection, out Guid clientConnectionId) + { + if (_getClientConnectionId == null) + { + clientConnectionId = default; + return false; + } + clientConnectionId = _getClientConnectionId(connection); + return true; + } + + public bool TryClearPool(DbConnection connection) + { + if (_clearPool == null) return false; + _clearPool(connection); + return true; + } + + public bool TryClearAllPools() + { + if (_clearAllPools == null) return false; + _clearAllPools(); + return true; + } + + public static ByTypeHelpers Get(Type type) + { + if (!s_byType.TryGetValue(type, out var value)) + { + s_byType[type] = value = new ByTypeHelpers(type); + } + return value; + } + + private ByTypeHelpers(Type type) + { + _getClientConnectionId = TryGetInstanceProperty("ClientConnectionId", type); + + try + { + var clearAllPools = type.GetMethod("ClearAllPools", BindingFlags.Public | BindingFlags.Static, + null, Type.EmptyTypes, null); + if (clearAllPools != null) + { + _clearAllPools = (Action)Delegate.CreateDelegate(typeof(Action), clearAllPools); + } + } + catch { } + + try + { + var clearPool = type.GetMethod("ClearPool", BindingFlags.Public | BindingFlags.Static, + null, new[] { type }, null); + if (clearPool != null) + { + var p = Expression.Parameter(typeof(DbConnection), "connection"); + var body = Expression.Call(clearPool, Expression.Convert(p, type)); + var lambda = Expression.Lambda>(body, p); + _clearPool = lambda.Compile(); + } + } + catch { } + } + + private static Func? TryGetInstanceProperty(string name, Type type) + { + try + { + var prop = type.GetProperty(name, BindingFlags.Public | BindingFlags.Instance); + if (prop == null || !prop.CanRead) return null; + if (prop.PropertyType != typeof(T)) return null; + + var p = Expression.Parameter(typeof(DbConnection), "connection"); + var body = Expression.Property(Expression.Convert(p, type), prop); + var lambda = Expression.Lambda>(body, p); + return lambda.Compile(); + } + catch + { + return null; + } + } + } + } +} diff --git a/Dapper.Tests/ProviderTests.cs b/Dapper.Tests/ProviderTests.cs index bb3c0886d..b51ed2d2e 100644 --- a/Dapper.Tests/ProviderTests.cs +++ b/Dapper.Tests/ProviderTests.cs @@ -1,4 +1,5 @@ -using System.Data.Common; +using System; +using System.Data.Common; using Dapper.ProviderTools; using Xunit; @@ -24,6 +25,62 @@ public void BulkCopy_MicrosoftDataSqlClient() } } + [Fact] + public void ClientId_SystemDataSqlClient() + => TestClientId(); + + [Fact] + public void ClientId_MicrosoftDataSqlClient() + => TestClientId(); + + + [Fact] + public void ClearPool_SystemDataSqlClient() + => ClearPool(); + + [Fact] + public void ClearPool_MicrosoftDataSqlClient() + => ClearPool(); + + [Fact] + public void ClearAllPools_SystemDataSqlClient() + => ClearAllPools(); + + [Fact] + public void ClearAllPools_MicrosoftDataSqlClient() + => ClearAllPools(); + + private static void TestClientId() + where T : SqlServerDatabaseProvider, new() + { + var provider = new T(); + using (var conn = provider.GetOpenConnection()) + { + Assert.True(conn.TryGetClientConnectionId(out var id)); + Assert.NotEqual(Guid.Empty, id); + } + } + + private static void ClearPool() + where T : SqlServerDatabaseProvider, new() + { + var provider = new T(); + using (var conn = provider.GetOpenConnection()) + { + Assert.True(conn.TryClearPool()); + } + } + + private static void ClearAllPools() + where T : SqlServerDatabaseProvider, new() + { + var provider = new T(); + using (var conn = provider.GetOpenConnection()) + { + Assert.True(conn.TryClearAllPools()); + } + } + private static void Test(DbConnection connection) { using (var bcp = BulkCopy.TryCreate(connection))