diff --git a/src_native/sspi_impl.cpp b/src_native/sspi_impl.cpp index c9d47f9..38c31ea 100644 --- a/src_native/sspi_impl.cpp +++ b/src_native/sspi_impl.cpp @@ -9,7 +9,16 @@ // This is in the prioritized order in terms of which package to use. The // first package from this list that's supported by the client OS will be // used to connect to the server. -char SspiImpl::s_supportedPackages[s_numSupportedPackages][SspiImpl::c_maxPackageNameLength] = +WCHAR SspiImpl::s_supportedPackages[s_numSupportedPackages][SspiImpl::c_maxPackageNameLength] = +{ + L"Negotiate", + L"Kerberos", + L"NTLM" +}; + +// This should have 1-1 correspondence with s_supportedPackages above. This is to simplify +// returning available packages in Initialize function. +char SspiImpl::s_supportedPackagesUtf8[s_numSupportedPackages][SspiImpl::c_maxPackageNameLength] = { "Negotiate", "Kerberos", @@ -17,14 +26,16 @@ char SspiImpl::s_supportedPackages[s_numSupportedPackages][SspiImpl::c_maxPackag }; // This is the default security package to use if none specified by the app. -const char* SspiImpl::s_defaultPackage = nullptr; +const WCHAR* SspiImpl::s_defaultPackage = nullptr; // Maximum token size across all packages. int SspiImpl::s_packageMaxTokenSize = -1; SspiImpl::SspiImpl(const char* spn, const char* securityPackage) : m_spn(spn), + m_spnMultiByte(), m_securityPackage(), + m_securityPackageMultiByte(), m_utEnableCannedResponse(false), m_utForceCompleteAuth(false) { @@ -47,20 +58,18 @@ SECURITY_STATUS SspiImpl::Initialize( DebugLog("%d: Worker thread: SspiImpl::Initialize.\n", GetCurrentThreadId()); errorString->assign(""); - - const int c_errorStringBufferSize = 256; char errorStringLocal[c_errorStringBufferSize]; unsigned long numPackages; - PSecPkgInfoA psecPkgInfo; + PSecPkgInfoW psecPkgInfo; - SECURITY_STATUS securityStatus = EnumerateSecurityPackagesA(&numPackages, &psecPkgInfo); + SECURITY_STATUS securityStatus = EnumerateSecurityPackagesW(&numPackages, &psecPkgInfo); if (securityStatus != SEC_E_OK) { snprintf( errorStringLocal, c_errorStringBufferSize, - "EnumerateSecurityPackagesA failed with error code: 0x%X.", + "EnumerateSecurityPackagesW failed with error code: 0x%X.", securityStatus); errorString->assign(errorStringLocal); @@ -71,9 +80,9 @@ SECURITY_STATUS SspiImpl::Initialize( { for (unsigned long packagesIndex = 0; packagesIndex < numPackages; packagesIndex++) { - if (_strcmpi(s_supportedPackages[supportedPackagesIndex], psecPkgInfo[packagesIndex].Name) == 0) + if (_wcsicmp(s_supportedPackages[supportedPackagesIndex], psecPkgInfo[packagesIndex].Name) == 0) { - availablePackages->push_back(psecPkgInfo[packagesIndex].Name); + availablePackages->push_back(s_supportedPackagesUtf8[supportedPackagesIndex]); if (s_packageMaxTokenSize < static_cast(psecPkgInfo[packagesIndex].cbMaxToken)) { s_packageMaxTokenSize = psecPkgInfo[packagesIndex].cbMaxToken; @@ -107,9 +116,9 @@ SECURITY_STATUS SspiImpl::Initialize( errorStringLocal, c_errorStringBufferSize, "No supported security package (%s, %s or %s) available on client.", - s_supportedPackages[0], - s_supportedPackages[1], - s_supportedPackages[2]); + s_supportedPackagesUtf8[0], + s_supportedPackagesUtf8[1], + s_supportedPackagesUtf8[2]); errorString->assign(errorStringLocal); return securityStatus; @@ -141,7 +150,6 @@ SECURITY_STATUS SspiImpl::GetNextBlob( errorString->assign(""); - const int c_errorStringBufferSize = 256; char errorStringLocal[c_errorStringBufferSize]; // Lifetime owned by caller. See comments in the header file for details. @@ -151,19 +159,45 @@ SECURITY_STATUS SspiImpl::GetNextBlob( if (!SecIsValidHandle(&m_credHandle)) { - const char* securityPackage; + securityStatus = ConvertUtf8ToMultiByte( + "spn", + m_spn.c_str(), + &m_spnMultiByte, + errorStringLocal, + c_errorStringBufferSize); + + if (securityStatus != S_OK) + { + errorString->assign(errorStringLocal); + return securityStatus; + } + + const WCHAR* securityPackage; if (m_securityPackage.empty()) { securityPackage = s_defaultPackage; } else { - securityPackage = m_securityPackage.c_str(); + securityStatus = ConvertUtf8ToMultiByte( + "securityPackage", + m_securityPackage.c_str(), + &m_securityPackageMultiByte, + errorStringLocal, + c_errorStringBufferSize); + + if (securityStatus != S_OK) + { + errorString->assign(errorStringLocal); + return securityStatus; + } + + securityPackage = m_securityPackageMultiByte.get(); } - securityStatus = AcquireCredentialsHandleA( + securityStatus = AcquireCredentialsHandleW( nullptr, // Principal - logged in user. - const_cast(securityPackage), // Security package to use. + const_cast(securityPackage), // Security package to use. SECPKG_CRED_OUTBOUND, // Client credential token sent to server. nullptr, // Locally unique user identifier. nullptr, // Auth data - use default credentials. @@ -177,7 +211,7 @@ SECURITY_STATUS SspiImpl::GetNextBlob( snprintf( errorStringLocal, c_errorStringBufferSize, - "AcquireCredentialsHandleA failed with error code: 0x%X.", + "AcquireCredentialsHandleW failed with error code: 0x%X.", securityStatus); errorString->assign(errorStringLocal); @@ -207,10 +241,10 @@ SECURITY_STATUS SspiImpl::GetNextBlob( ULONG contextAttr; - securityStatus = InitializeSecurityContextA( + securityStatus = InitializeSecurityContextW( &m_credHandle, // Credential handle. SecIsValidHandle(&m_ctxtHandle) ? &m_ctxtHandle : nullptr, // Context handle - input. - const_cast(m_spn.c_str()), // Service Principal name (SPN). + const_cast(m_spnMultiByte.get()), // Service Principal name (SPN). ISC_REQ_DELEGATE | ISC_REQ_MUTUAL_AUTH | ISC_REQ_INTEGRITY | ISC_REQ_EXTENDED_ERROR, // Context bit flags. 0, // Reserved - unused. @@ -230,7 +264,7 @@ SECURITY_STATUS SspiImpl::GetNextBlob( snprintf( errorStringLocal, c_errorStringBufferSize, - "InitializeSecurityContextA failed with error code: 0x%X.", + "InitializeSecurityContextW failed with error code: 0x%X.", securityStatus); errorString->assign(errorStringLocal); @@ -270,7 +304,61 @@ void SspiImpl::FreeBlob(char* blob) delete[] blob; } -void::SspiImpl::DeleteCredHandle() +// static +HRESULT SspiImpl::ConvertUtf8ToMultiByte( + const char* paramName, + const char* utf8Str, + std::unique_ptr* multiByteStr, + char* errorString, + int errorStringBufferSize) +{ + int retval = MultiByteToWideChar( + CP_UTF8, // Code page UTF8. + MB_ERR_INVALID_CHARS, // Fail on invalid characters. + utf8Str, // UTF8 string. + -1, // Indicates null-termination. + nullptr, // Output buffer, ignored when getting required buffer size. + 0); // Output buffer size, in characters set to 0 to get required buffer size. + + if (!retval) { + HRESULT hr = HRESULT_FROM_WIN32(GetLastError()); + snprintf( + errorString, + errorStringBufferSize, + "MultiByteToWideChar failed to get required buffer size for '%s'. Error code: 0x%X.", + paramName, + hr); + + return hr; + } + + // retval includes null space for terminating null character. + multiByteStr->reset(new WCHAR[retval]); + + retval = MultiByteToWideChar( + CP_UTF8, // Code page UTF8. + MB_ERR_INVALID_CHARS, // Fail on invalid characters. + utf8Str, // UTF8 string. + -1, // Indicates null-termination. + multiByteStr->get(), // Output buffer. + retval); // Output buffer size, in characters set to 0 to get required buffer size. + + if (!retval) { + HRESULT hr = HRESULT_FROM_WIN32(GetLastError()); + snprintf( + errorString, + errorStringBufferSize, + "MultiByteToWideChar failed to convert UTF8 to WideChar for '%s'. Error code: 0x%X.", + paramName, + hr); + + return hr; + } + + return S_OK; +} + +void SspiImpl::DeleteCredHandle() { if (SecIsValidHandle(&m_credHandle)) { diff --git a/src_native/sspi_impl.h b/src_native/sspi_impl.h index 6ab5fb5..e4a88f2 100644 --- a/src_native/sspi_impl.h +++ b/src_native/sspi_impl.h @@ -2,6 +2,7 @@ #define SECURITY_WIN32 +#include #include #include #include @@ -47,21 +48,34 @@ class SspiImpl SspiImpl(const SspiImpl&); SspiImpl& operator=(const SspiImpl&); + static HRESULT ConvertUtf8ToMultiByte( + const char* paramName, + const char* utf8Str, + std::unique_ptr* multiByteStr, + char* errorString, + int errorStringBufferSize); + void DeleteCredHandle(); void DeleteCtxtHandle(); static const int c_maxPackageNameLength = 32; static const int s_numSupportedPackages = 3; - static char s_supportedPackages[s_numSupportedPackages][c_maxPackageNameLength]; + static WCHAR s_supportedPackages[s_numSupportedPackages][c_maxPackageNameLength]; + static char s_supportedPackagesUtf8[s_numSupportedPackages][c_maxPackageNameLength]; - static const char* s_defaultPackage; + static const WCHAR* s_defaultPackage; static int s_packageMaxTokenSize; + static const int c_errorStringBufferSize = 256; + CredHandle m_credHandle; CtxtHandle m_ctxtHandle; std::string m_spn; + std::unique_ptr m_spnMultiByte; + std::string m_securityPackage; + std::unique_ptr m_securityPackageMultiByte; // Everything below is for unit testing purposes only. SECURITY_STATUS UtSetCannedResponse(