Skip to content

Commit

Permalink
Switch from A version to W version of SSPI calls. (tvrprasad#30)
Browse files Browse the repository at this point in the history
This will allow for being able to specify International Domain Names.

tvrprasad#5
  • Loading branch information
tvrprasad authored Dec 20, 2016
1 parent 77c665e commit 9820ecd
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 24 deletions.
132 changes: 110 additions & 22 deletions src_native/sspi_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,33 @@
// This is in the prioritized order in terms of which package to use. The
// first package from this list that's supported by the client OS will be
// used to connect to the server.
char SspiImpl::s_supportedPackages[s_numSupportedPackages][SspiImpl::c_maxPackageNameLength] =
WCHAR SspiImpl::s_supportedPackages[s_numSupportedPackages][SspiImpl::c_maxPackageNameLength] =
{
L"Negotiate",
L"Kerberos",
L"NTLM"
};

// This should have 1-1 correspondence with s_supportedPackages above. This is to simplify
// returning available packages in Initialize function.
char SspiImpl::s_supportedPackagesUtf8[s_numSupportedPackages][SspiImpl::c_maxPackageNameLength] =
{
"Negotiate",
"Kerberos",
"NTLM"
};

// This is the default security package to use if none specified by the app.
const char* SspiImpl::s_defaultPackage = nullptr;
const WCHAR* SspiImpl::s_defaultPackage = nullptr;

// Maximum token size across all packages.
int SspiImpl::s_packageMaxTokenSize = -1;

SspiImpl::SspiImpl(const char* spn, const char* securityPackage) :
m_spn(spn),
m_spnMultiByte(),
m_securityPackage(),
m_securityPackageMultiByte(),
m_utEnableCannedResponse(false),
m_utForceCompleteAuth(false)
{
Expand All @@ -47,20 +58,18 @@ SECURITY_STATUS SspiImpl::Initialize(
DebugLog("%d: Worker thread: SspiImpl::Initialize.\n", GetCurrentThreadId());

errorString->assign("");

const int c_errorStringBufferSize = 256;
char errorStringLocal[c_errorStringBufferSize];

unsigned long numPackages;
PSecPkgInfoA psecPkgInfo;
PSecPkgInfoW psecPkgInfo;

SECURITY_STATUS securityStatus = EnumerateSecurityPackagesA(&numPackages, &psecPkgInfo);
SECURITY_STATUS securityStatus = EnumerateSecurityPackagesW(&numPackages, &psecPkgInfo);
if (securityStatus != SEC_E_OK)
{
snprintf(
errorStringLocal,
c_errorStringBufferSize,
"EnumerateSecurityPackagesA failed with error code: 0x%X.",
"EnumerateSecurityPackagesW failed with error code: 0x%X.",
securityStatus);

errorString->assign(errorStringLocal);
Expand All @@ -71,9 +80,9 @@ SECURITY_STATUS SspiImpl::Initialize(
{
for (unsigned long packagesIndex = 0; packagesIndex < numPackages; packagesIndex++)
{
if (_strcmpi(s_supportedPackages[supportedPackagesIndex], psecPkgInfo[packagesIndex].Name) == 0)
if (_wcsicmp(s_supportedPackages[supportedPackagesIndex], psecPkgInfo[packagesIndex].Name) == 0)
{
availablePackages->push_back(psecPkgInfo[packagesIndex].Name);
availablePackages->push_back(s_supportedPackagesUtf8[supportedPackagesIndex]);
if (s_packageMaxTokenSize < static_cast<int>(psecPkgInfo[packagesIndex].cbMaxToken))
{
s_packageMaxTokenSize = psecPkgInfo[packagesIndex].cbMaxToken;
Expand Down Expand Up @@ -107,9 +116,9 @@ SECURITY_STATUS SspiImpl::Initialize(
errorStringLocal,
c_errorStringBufferSize,
"No supported security package (%s, %s or %s) available on client.",
s_supportedPackages[0],
s_supportedPackages[1],
s_supportedPackages[2]);
s_supportedPackagesUtf8[0],
s_supportedPackagesUtf8[1],
s_supportedPackagesUtf8[2]);

errorString->assign(errorStringLocal);
return securityStatus;
Expand Down Expand Up @@ -141,7 +150,6 @@ SECURITY_STATUS SspiImpl::GetNextBlob(

errorString->assign("");

const int c_errorStringBufferSize = 256;
char errorStringLocal[c_errorStringBufferSize];

// Lifetime owned by caller. See comments in the header file for details.
Expand All @@ -151,19 +159,45 @@ SECURITY_STATUS SspiImpl::GetNextBlob(

if (!SecIsValidHandle(&m_credHandle))
{
const char* securityPackage;
securityStatus = ConvertUtf8ToMultiByte(
"spn",
m_spn.c_str(),
&m_spnMultiByte,
errorStringLocal,
c_errorStringBufferSize);

if (securityStatus != S_OK)
{
errorString->assign(errorStringLocal);
return securityStatus;
}

const WCHAR* securityPackage;
if (m_securityPackage.empty())
{
securityPackage = s_defaultPackage;
}
else
{
securityPackage = m_securityPackage.c_str();
securityStatus = ConvertUtf8ToMultiByte(
"securityPackage",
m_securityPackage.c_str(),
&m_securityPackageMultiByte,
errorStringLocal,
c_errorStringBufferSize);

if (securityStatus != S_OK)
{
errorString->assign(errorStringLocal);
return securityStatus;
}

securityPackage = m_securityPackageMultiByte.get();
}

securityStatus = AcquireCredentialsHandleA(
securityStatus = AcquireCredentialsHandleW(
nullptr, // Principal - logged in user.
const_cast<char*>(securityPackage), // Security package to use.
const_cast<WCHAR*>(securityPackage), // Security package to use.
SECPKG_CRED_OUTBOUND, // Client credential token sent to server.
nullptr, // Locally unique user identifier.
nullptr, // Auth data - use default credentials.
Expand All @@ -177,7 +211,7 @@ SECURITY_STATUS SspiImpl::GetNextBlob(
snprintf(
errorStringLocal,
c_errorStringBufferSize,
"AcquireCredentialsHandleA failed with error code: 0x%X.",
"AcquireCredentialsHandleW failed with error code: 0x%X.",
securityStatus);

errorString->assign(errorStringLocal);
Expand Down Expand Up @@ -207,10 +241,10 @@ SECURITY_STATUS SspiImpl::GetNextBlob(

ULONG contextAttr;

securityStatus = InitializeSecurityContextA(
securityStatus = InitializeSecurityContextW(
&m_credHandle, // Credential handle.
SecIsValidHandle(&m_ctxtHandle) ? &m_ctxtHandle : nullptr, // Context handle - input.
const_cast<char*>(m_spn.c_str()), // Service Principal name (SPN).
const_cast<WCHAR*>(m_spnMultiByte.get()), // Service Principal name (SPN).
ISC_REQ_DELEGATE | ISC_REQ_MUTUAL_AUTH | ISC_REQ_INTEGRITY | ISC_REQ_EXTENDED_ERROR,
// Context bit flags.
0, // Reserved - unused.
Expand All @@ -230,7 +264,7 @@ SECURITY_STATUS SspiImpl::GetNextBlob(
snprintf(
errorStringLocal,
c_errorStringBufferSize,
"InitializeSecurityContextA failed with error code: 0x%X.",
"InitializeSecurityContextW failed with error code: 0x%X.",
securityStatus);

errorString->assign(errorStringLocal);
Expand Down Expand Up @@ -270,7 +304,61 @@ void SspiImpl::FreeBlob(char* blob)
delete[] blob;
}

void::SspiImpl::DeleteCredHandle()
// static
HRESULT SspiImpl::ConvertUtf8ToMultiByte(
const char* paramName,
const char* utf8Str,
std::unique_ptr<WCHAR[]>* multiByteStr,
char* errorString,
int errorStringBufferSize)
{
int retval = MultiByteToWideChar(
CP_UTF8, // Code page UTF8.
MB_ERR_INVALID_CHARS, // Fail on invalid characters.
utf8Str, // UTF8 string.
-1, // Indicates null-termination.
nullptr, // Output buffer, ignored when getting required buffer size.
0); // Output buffer size, in characters set to 0 to get required buffer size.

if (!retval) {
HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
snprintf(
errorString,
errorStringBufferSize,
"MultiByteToWideChar failed to get required buffer size for '%s'. Error code: 0x%X.",
paramName,
hr);

return hr;
}

// retval includes null space for terminating null character.
multiByteStr->reset(new WCHAR[retval]);

retval = MultiByteToWideChar(
CP_UTF8, // Code page UTF8.
MB_ERR_INVALID_CHARS, // Fail on invalid characters.
utf8Str, // UTF8 string.
-1, // Indicates null-termination.
multiByteStr->get(), // Output buffer.
retval); // Output buffer size, in characters set to 0 to get required buffer size.

if (!retval) {
HRESULT hr = HRESULT_FROM_WIN32(GetLastError());
snprintf(
errorString,
errorStringBufferSize,
"MultiByteToWideChar failed to convert UTF8 to WideChar for '%s'. Error code: 0x%X.",
paramName,
hr);

return hr;
}

return S_OK;
}

void SspiImpl::DeleteCredHandle()
{
if (SecIsValidHandle(&m_credHandle))
{
Expand Down
18 changes: 16 additions & 2 deletions src_native/sspi_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#define SECURITY_WIN32

#include <memory>
#include <Windows.h>
#include <Sspi.h>
#include <string>
Expand Down Expand Up @@ -47,21 +48,34 @@ class SspiImpl
SspiImpl(const SspiImpl&);
SspiImpl& operator=(const SspiImpl&);

static HRESULT ConvertUtf8ToMultiByte(
const char* paramName,
const char* utf8Str,
std::unique_ptr<WCHAR[]>* multiByteStr,
char* errorString,
int errorStringBufferSize);

void DeleteCredHandle();
void DeleteCtxtHandle();

static const int c_maxPackageNameLength = 32;
static const int s_numSupportedPackages = 3;
static char s_supportedPackages[s_numSupportedPackages][c_maxPackageNameLength];
static WCHAR s_supportedPackages[s_numSupportedPackages][c_maxPackageNameLength];
static char s_supportedPackagesUtf8[s_numSupportedPackages][c_maxPackageNameLength];

static const char* s_defaultPackage;
static const WCHAR* s_defaultPackage;
static int s_packageMaxTokenSize;

static const int c_errorStringBufferSize = 256;

CredHandle m_credHandle;
CtxtHandle m_ctxtHandle;

std::string m_spn;
std::unique_ptr<WCHAR[]> m_spnMultiByte;

std::string m_securityPackage;
std::unique_ptr<WCHAR[]> m_securityPackageMultiByte;

// Everything below is for unit testing purposes only.
SECURITY_STATUS UtSetCannedResponse(
Expand Down

0 comments on commit 9820ecd

Please sign in to comment.