Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update error code when certificate validation fails in managed SNI #1130

Merged
merged 3 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ internal enum SNISMUXFlags
internal class SNICommon
{
private const string s_className = nameof(SNICommon);

// Each error number maps to SNI_ERROR_* in String.resx
internal const int ConnTerminatedError = 2;
internal const int InvalidParameterError = 5;
Expand Down Expand Up @@ -220,11 +220,12 @@ internal static uint ReportSNIError(SNIProviders provider, uint nativeError, uin
/// <param name="provider">SNI provider</param>
/// <param name="sniError">SNI error code</param>
/// <param name="sniException">SNI Exception</param>
/// <param name="nativeErrorCode">Native SNI error code</param>
/// <returns></returns>
internal static uint ReportSNIError(SNIProviders provider, uint sniError, Exception sniException)
internal static uint ReportSNIError(SNIProviders provider, uint sniError, Exception sniException, uint nativeErrorCode = 0)
{
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.ERR, "Provider = {0}, SNI Error = {1}, Exception = {2}", args0: provider, args1: sniError, args2: sniException?.Message);
return ReportSNIError(new SNIError(provider, sniError, sniException));
return ReportSNIError(new SNIError(provider, sniError, sniException, nativeErrorCode));
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ namespace Microsoft.Data.SqlClient.SNI
/// </summary>
internal class SNIError
{
// Error numbers from native SNI implementation
internal const uint CertificateValidationErrorCode = 2148074277;

public readonly SNIProviders provider;
public readonly string errorMessage;
public readonly uint nativeError;
Expand All @@ -21,24 +24,24 @@ internal class SNIError

public SNIError(SNIProviders provider, uint nativeError, uint sniErrorCode, string errorMessage)
{
this.lineNumber = 0;
this.function = string.Empty;
lineNumber = 0;
function = string.Empty;
this.provider = provider;
this.nativeError = nativeError;
this.sniError = sniErrorCode;
sniError = sniErrorCode;
this.errorMessage = errorMessage;
this.exception = null;
exception = null;
}

public SNIError(SNIProviders provider, uint sniErrorCode, Exception sniException)
public SNIError(SNIProviders provider, uint sniErrorCode, Exception sniException, uint nativeErrorCode = 0)
{
this.lineNumber = 0;
this.function = string.Empty;
lineNumber = 0;
function = string.Empty;
this.provider = provider;
this.nativeError = 0;
this.sniError = sniErrorCode;
this.errorMessage = string.Empty;
this.exception = sniException;
nativeError = nativeErrorCode;
sniError = sniErrorCode;
errorMessage = string.Empty;
exception = sniException;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,13 @@ public SNITCPHandle(string serverName, int port, long timerExpire, bool parallel
string firstCachedIP;
string secondCachedIP;

if (SqlConnectionIPAddressPreference.IPv6First == ipPreference) {
if (SqlConnectionIPAddressPreference.IPv6First == ipPreference)
{
firstCachedIP = cachedDNSInfo.AddrIPv6;
secondCachedIP = cachedDNSInfo.AddrIPv4;
} else {
}
else
{
firstCachedIP = cachedDNSInfo.AddrIPv4;
secondCachedIP = cachedDNSInfo.AddrIPv6;
}
Expand Down Expand Up @@ -339,8 +342,8 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName);

string IPv4String = null;
string IPv6String = null;
string IPv6String = null;

// Returning null socket is handled by the caller function.
if (ipAddresses == null || ipAddresses.Length == 0)
{
Expand Down Expand Up @@ -434,7 +437,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo

// If we have already got a valid Socket, or the platform default was prefered
// we won't do the second traversal.
if (availableSocket != null || ipPreference == SqlConnectionIPAddressPreference.UsePlatformDefault)
if (availableSocket is not null || ipPreference == SqlConnectionIPAddressPreference.UsePlatformDefault)
{
break;
}
Expand Down Expand Up @@ -590,7 +593,7 @@ public override uint EnableSsl(uint options)
catch (AuthenticationException aue)
{
SqlClientEventSource.Log.TrySNITraceEvent(s_className, EventType.ERR, "Connection Id {0}, Authentication exception occurred: {1}", args0: _connectionId, args1: aue?.Message);
return ReportTcpSNIError(aue);
return ReportTcpSNIError(aue, SNIError.CertificateValidationErrorCode);
}
catch (InvalidOperationException ioe)
{
Expand Down Expand Up @@ -882,10 +885,10 @@ public override uint CheckConnection()
return TdsEnums.SNI_SUCCESS;
}

private uint ReportTcpSNIError(Exception sniException)
private uint ReportTcpSNIError(Exception sniException, uint nativeErrorCode = 0)
{
_status = TdsEnums.SNI_ERROR;
return SNICommon.ReportSNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, sniException);
return SNICommon.ReportSNIError(SNIProviders.TCP_PROV, SNICommon.InternalExceptionError, sniException, nativeErrorCode);
}

private uint ReportTcpSNIError(uint nativeError, uint sniError, string errorMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ internal void Connect(
// On Instance failure re-connect and flush SNI named instance cache.
_physicalStateObj.SniContext = SniContext.Snix_Connect;

_physicalStateObj.CreatePhysicalSNIHandle(serverInfo.ExtendedServerName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref _sniSpnBuffer, true, true, fParallel,
_physicalStateObj.CreatePhysicalSNIHandle(serverInfo.ExtendedServerName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref _sniSpnBuffer, true, true, fParallel,
_connHandler.ConnectionOptions.IPAddressPreference, FQDNforDNSCahce, ref _connHandler.pendingSQLDNSObject, integratedSecurity);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
Expand Down Expand Up @@ -1432,8 +1432,8 @@ internal SqlError ProcessSNIError(TdsParserStateObject stateObj)
SqlClientEventSource.Log.TryAdvancedTraceErrorEvent("<sc.TdsParser.ProcessSNIError |ERR|ADV > SNI Error Message. Native Error = {0}, Line Number ={1}, Function ={2}, Exception ={3}, Server = {4}",
(int)details.nativeError, (int)details.lineNumber, details.function, details.exception, _server);

return new SqlError((int)details.nativeError, 0x00, TdsEnums.FATAL_ERROR_CLASS,
_server, errorMessage, details.function, (int)details.lineNumber, details.nativeError, details.exception);
return new SqlError(infoNumber: (int)details.nativeError, errorState: 0x00, TdsEnums.FATAL_ERROR_CLASS, _server,
errorMessage, details.function, (int)details.lineNumber, win32ErrorCode: details.nativeError, details.exception);
}
finally
{
Expand Down