Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Davoud Eshtehari committed May 6, 2021
1 parent ac69f38 commit fd6c475
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -392,16 +392,15 @@ void Cancel()
Socket availableSocket = null;
try
{
int n = 0; // Socket index

// We go through the IP list twice.
// In the first traversal, we only try to connect with the preferedIPFamilies[0].
// In the second traversal, we only try to connect with the preferedIPFamilies[1].
// For UsePlatformDefault preference, we do traversal once.
for (int i = 0; i < preferedIPFamilies.Length; ++i)
{
foreach (IPAddress ipAddress in ipAddresses)
for (int n = 0; n < ipAddresses.Length; n++)
{
IPAddress ipAddress = ipAddresses[n];
try
{
if (ipAddress != null)
Expand Down Expand Up @@ -443,7 +442,6 @@ void Cancel()
sockets[n] = null;
}
}
n++;
}
}
catch (Exception e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ internal static class ConnectionHelper
private static Type s_dbConnectionInternal = s_MicrosoftDotData.GetType("Microsoft.Data.ProviderBase.DbConnectionInternal");
private static Type s_tdsParser = s_MicrosoftDotData.GetType("Microsoft.Data.SqlClient.TdsParser");
private static Type s_tdsParserStateObject = s_MicrosoftDotData.GetType("Microsoft.Data.SqlClient.TdsParserStateObject");
private static Type s_SQLDNSInfo = s_MicrosoftDotData.GetType("Microsoft.Data.SqlClient.SQLDNSInfo");
private static PropertyInfo s_sqlConnectionInternalConnection = s_sqlConnection.GetProperty("InnerConnection", BindingFlags.Instance | BindingFlags.NonPublic);
private static PropertyInfo s_dbConnectionInternalPool = s_dbConnectionInternal.GetProperty("Pool", BindingFlags.Instance | BindingFlags.NonPublic);
private static MethodInfo s_dbConnectionInternalIsConnectionAlive = s_dbConnectionInternal.GetMethod("IsConnectionAlive", BindingFlags.Instance | BindingFlags.NonPublic);
Expand All @@ -26,6 +27,11 @@ internal static class ConnectionHelper
private static FieldInfo s_tdsParserStateObjectProperty = s_tdsParser.GetField("_physicalStateObj", BindingFlags.Instance | BindingFlags.NonPublic);
private static FieldInfo s_enforceTimeoutDelayProperty = s_tdsParserStateObject.GetField("_enforceTimeoutDelay", BindingFlags.Instance | BindingFlags.NonPublic);
private static FieldInfo s_enforcedTimeoutDelayInMilliSeconds = s_tdsParserStateObject.GetField("_enforcedTimeoutDelayInMilliSeconds", BindingFlags.Instance | BindingFlags.NonPublic);
private static FieldInfo s_pendingSQLDNSObject = s_sqlInternalConnectionTds.GetField("pendingSQLDNSObject", BindingFlags.Instance | BindingFlags.NonPublic);
private static PropertyInfo s_pendingSQLDNS_FQDN = s_SQLDNSInfo.GetProperty("FQDN", BindingFlags.Instance | BindingFlags.Public);
private static PropertyInfo s_pendingSQLDNS_AddrIPv4 = s_SQLDNSInfo.GetProperty("AddrIPv4", BindingFlags.Instance | BindingFlags.Public);
private static PropertyInfo s_pendingSQLDNS_AddrIPv6 = s_SQLDNSInfo.GetProperty("AddrIPv6", BindingFlags.Instance | BindingFlags.Public);
private static PropertyInfo s_pendingSQLDNS_Port = s_SQLDNSInfo.GetProperty("Port", BindingFlags.Instance | BindingFlags.Public);

public static object GetConnectionPool(object internalConnection)
{
Expand Down Expand Up @@ -79,5 +85,17 @@ public static void SetEnforcedTimeout(this SqlConnection connection, bool enforc
s_enforceTimeoutDelayProperty.SetValue(stateObj, enforce);
s_enforcedTimeoutDelayInMilliSeconds.SetValue(stateObj, timeout);
}

public static Tuple<string, string, string, string> GetSQLDNSInfo(this SqlConnection connection)
{
object internalConnection = GetInternalConnection(connection);
VerifyObjectIsInternalConnection(internalConnection);
object pendingSQLDNSInfo = s_pendingSQLDNSObject.GetValue(internalConnection);
string fqdn = s_pendingSQLDNS_FQDN.GetValue(pendingSQLDNSInfo) as string;
string ipv4 = s_pendingSQLDNS_AddrIPv4.GetValue(pendingSQLDNSInfo) as string;
string ipv6 = s_pendingSQLDNS_AddrIPv6.GetValue(pendingSQLDNSInfo) as string;
string port = s_pendingSQLDNS_Port.GetValue(pendingSQLDNSInfo) as string;
return new Tuple<string, string, string, string>(fqdn, ipv4, ipv6, port);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using Microsoft.Data.SqlClient.ManualTesting.Tests.SystemDataInternals;
using Xunit;

using static Microsoft.Data.SqlClient.ManualTesting.Tests.DataTestUtility;
Expand All @@ -16,19 +22,68 @@ public class ConfigurableIpPreferenceTest
private const string CnnPrefIPv6 = ";IPAddressPreference=IPv6First";
private const string CnnPrefIPv4 = ";IPAddressPreference=IPv4First";

private static bool IsTCPConnectionStringSetup() => !string.IsNullOrEmpty(TCPConnectionString);
private static bool IsValidDataSource()
{
SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(TCPConnectionString);
int startIdx = builder.DataSource.IndexOf(':') + 1;
int endIdx = builder.DataSource.IndexOf(',');
string serverName;
if (endIdx == -1)
{
serverName = builder.DataSource.Substring(startIdx);
}
else
{
serverName = builder.DataSource.Substring(startIdx, endIdx - startIdx);
}

List<IPAddress> ipAddresses = Dns.GetHostAddresses(serverName).ToList();
return ipAddresses.Exists(ip => ip.AddressFamily == AddressFamily.InterNetwork) &&
ipAddresses.Exists(ip => ip.AddressFamily == AddressFamily.InterNetworkV6);
}

[ConditionalTheory(nameof(IsTCPConnectionStringSetup), nameof(IsValidDataSource))]
[InlineData(CnnPrefIPv6)]
[InlineData(CnnPrefIPv4)]
[InlineData(";IPAddressPreference=UsePlatformDefault")]
public void ConfigurableIpPreference(string ipPreference)
{
using (SqlConnection connection = new SqlConnection(TCPConnectionString + ipPreference))
{
connection.Open();
Assert.Equal(ConnectionState.Open, connection.State);
Tuple<string, string, string, string> DNSInfo = connection.GetSQLDNSInfo();
if(ipPreference == CnnPrefIPv4)
{
Assert.NotNull(DNSInfo.Item2); //IPv4
Assert.Null(DNSInfo.Item3); //IPv6
}
else if(ipPreference == CnnPrefIPv6)
{
Assert.Null(DNSInfo.Item2);
Assert.NotNull(DNSInfo.Item3);
}
else
{
Assert.True((DNSInfo.Item2 != null && DNSInfo.Item3 == null) || (DNSInfo.Item2 == null && DNSInfo.Item3 != null));
}
}
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DoesHostAddressContainBothIPv4AndIPv6))]
[InlineData(CnnPrefIPv6)]
[InlineData(CnnPrefIPv4)]
public void ConfigurableIpPreferenceManagedSni(string ipPreference)
{
AppContext.SetSwitch("Switch.Microsoft.Data.SqlClient.UseManagedNetworkingOnWindows", true);
TestConfigurableIpPreference(ipPreference);
TestCachedConfigurableIpPreference(ipPreference, DNSCachingConnString);
AppContext.SetSwitch("Switch.Microsoft.Data.SqlClient.UseManagedNetworkingOnWindows", false);
}

private void TestConfigurableIpPreference(string ipPreference)
private void TestCachedConfigurableIpPreference(string ipPreference, string cnnString)
{
using (SqlConnection connection = new SqlConnection(DNSCachingConnString + ipPreference))
using (SqlConnection connection = new SqlConnection(cnnString + ipPreference))
{
// each successful connection updates the dns cache entry for the data source
connection.Open();
Expand All @@ -43,6 +98,7 @@ private void TestConfigurableIpPreference(string ipPreference)
const string AddrIPv6Property = "AddrIPv6";
const string FQDNProperty = "FQDN";

Assert.NotNull(dnsCacheEntry);
Assert.Equal(connection.DataSource, GetPropertyValueFromCacheEntry(FQDNProperty, dnsCacheEntry));

if (ipPreference == CnnPrefIPv4)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using Xunit;

Expand Down

0 comments on commit fd6c475

Please sign in to comment.