Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add custom root cert pinning support #1194

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 162 additions & 1 deletion lib/http/HttpClient_WinInet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,76 @@ class WinInetRequestWrapper
return true;
}

bool ValidateCustomRootCert()
{
if(!m_parent.IsCustomRootCheckRequired()) {
return true;
}
// Pointer to certificate chain obtained via InternetQueryOption :
// Ref. https://blogs.msdn.microsoft.com/alejacma/2012/01/18/how-to-use-internet_option_server_cert_chain_context-with-internetqueryoption-in-c/
PCCERT_CHAIN_CONTEXT pCertChainCtx = nullptr;
DWORD dwCertChainContextSize = sizeof(PCCERT_CHAIN_CONTEXT);
// Proceed to process the result if API call succeeds. That option is available in MSIE 8.x+ since Windows 7.1 and Win Server 2008 R2.
// In case if API call fails, then proceed without cert validation. This behavior is identical to default old behavior to avoid
// regressions for downlevel OS.
if (::InternetQueryOption(m_hWinInetRequest, INTERNET_OPTION_SERVER_CERT_CHAIN_CONTEXT, (LPVOID)&pCertChainCtx, &dwCertChainContextSize))
{
bool fError = false;
DWORD thumbPrintSize = CERTIFICATE_THUMBPRINT_SHA256_SIZE;
std::array<std::byte, CERTIFICATE_THUMBPRINT_SHA256_SIZE> thumbPrint;
do
{
if(pCertChainCtx == nullptr || pCertChainCtx->TrustStatus.dwErrorStatus != CERT_TRUST_NO_ERROR || pCertChainCtx->cChain == 0 || pCertChainCtx->rgpChain == nullptr)
{
fError = true;
break;
}
// Select the root certificate context in certificate chain.
PCERT_SIMPLE_CHAIN pCertChain = pCertChainCtx->rgpChain[pCertChainCtx->cChain - 1];
if(pCertChain->rgpElement == nullptr || pCertChain->cElement == 0)
{
fError = true;
break;
}
char aStringBuffer[MAX_PATH];
// For each certificate other than the root certificate validate the Subject Org
for(DWORD i = 0; i < pCertChain->cElement - 1; i++) {
PCCERT_CONTEXT pCertContext = pCertChain->rgpElement[i]->pCertContext;
if(CertGetNameStringA(pCertContext, CERT_NAME_ATTR_TYPE, 0, (void*)szOID_ORGANIZATION_NAME, aStringBuffer, MAX_PATH) == 1 ||
!m_parent.IsTrustedSubjectOrg(aStringBuffer)) {
fError = true;
break;
}
}
if(fError) {
break;
}
PCCERT_CONTEXT pRootCertContext = pCertChain->rgpElement[pCertChain->cElement - 1]->pCertContext;
if(pRootCertContext == nullptr || pRootCertContext->cbCertEncoded == 0 || pRootCertContext->pbCertEncoded == nullptr)
{
fError = true;
break;
}
if(!CryptHashCertificate(0, CALG_SHA_256, 0, pRootCertContext->pbCertEncoded, pRootCertContext->cbCertEncoded, (BYTE*)&thumbPrint[0], &thumbPrintSize) && thumbPrintSize != CERTIFICATE_THUMBPRINT_SHA256_SIZE) {
fError = true;
break;
}

} while(0);
if (pCertChainCtx != nullptr)
{
CertFreeCertificateChain(pCertChainCtx);
}
return fError ? false : m_parent.IsTrustedRootCert(thumbPrint);
}
else
{
// Downlevel OS prior to Win 7 and Win 2008 Server R2 do not support cert chain retrieval
LOG_TRACE("InternetQueryOption() failed to obtain cert chain %x", GetLastError());
}
return true;
}

// Asynchronously send HTTP request and invoke response callback.
// Ownership semantics: send(...) method self-destroys *this* upon
// receiving WinInet callback. There must be absolutely no methods
Expand Down Expand Up @@ -348,6 +418,22 @@ class WinInetRequestWrapper
}
}

if (m_parent.IsMsRootCheckRequired())
{ /* Perform optional MS Root certificate check for certain end-point URLs */
if (!isMsRootCert())
{
// Request cannot be completed: end-point certificate is not MS-Rooted
dwError = ERROR_INTERNET_SEC_INVALID_CERT;
}
}
else if (m_parent.IsCustomRootCheckRequired())
{ /* Perform optional Custom Root Certificate certificate check for certain end-point URLs */
if (!ValidateCustomRootCert())
{
// Request cannot be completed: end-point certificate is not Custom-Rooted
dwError = ERROR_INTERNET_SEC_INVALID_CERT;
}
}
std::unique_ptr<SimpleHttpResponse> response(new SimpleHttpResponse(m_id));

// SUCCESS with no IO_PENDING means we're done with the response body: try to parse the response headers.
Expand Down Expand Up @@ -467,9 +553,11 @@ class WinInetRequestWrapper
unsigned HttpClient_WinInet::s_nextRequestId = 0;

HttpClient_WinInet::HttpClient_WinInet() :
m_msRootCheck(false)
m_msRootCheck(false),
m_fcustomRootCheck(false)
{
m_hInternet = ::InternetOpen(NULL, INTERNET_OPEN_TYPE_PRECONFIG, NULL, NULL, INTERNET_FLAG_ASYNC);
m_customRootCerts = {};
}

HttpClient_WinInet::~HttpClient_WinInet()
Expand Down Expand Up @@ -562,6 +650,79 @@ bool HttpClient_WinInet::IsMsRootCheckRequired()
return m_msRootCheck;
}

/// <summary>
/// Enforces Custom root server certificate check.
/// </summary>
/// <param name="enforceMsRoot">if set to <c>true</c> [enforce verification that server cert is in trusted list].</param>
void HttpClient_WinInet::SetCustomRootCheck(bool enforceCustomRoot)
{
m_fcustomRootCheck = enforceCustomRoot;
}

/// <summary>
/// Determines whether Custom Root certificate check required.
/// </summary>
/// <returns>
/// <c>true</c> if [Root certificate is one in the trusted list]; otherwise, <c>false</c>.
/// </returns>
bool HttpClient_WinInet::IsCustomRootCheckRequired()
{
return m_fcustomRootCheck && m_customRootCerts.size() > 0;
}

bool HttpClient_WinInet::IsTrustedRootCert(std::array<std::byte, CERTIFICATE_THUMBPRINT_SHA256_SIZE> &aCertThumbprint)
{
if(!m_fcustomRootCheck || m_customRootCerts.size() == 0) {
return true;
}
bool fTrusted = false;
for (auto &aRoot : m_customRootCerts) {
if(aRoot == aCertThumbprint) {
fTrusted = true;
break;
}
}
return fTrusted;
}

bool HttpClient_WinInet::AddCustomRootCertSHA256Thumbprint(std::array<std::byte, CERTIFICATE_THUMBPRINT_SHA256_SIZE> &aCertThumbprint)
{
try {
m_customRootCerts.push_back(aCertThumbprint);
return true;
} catch (...) {
return false;
}
}

bool HttpClient_WinInet::AddCustomTrustedSubjectOrg(std::string& aTrustedOrg)
{
try {
m_customTrustedSubjectOrgs.push_back(aTrustedOrg);
return true;
} catch (...) {
return false;
}
}

bool HttpClient_WinInet::IsTrustedSubjectOrg(char* aOrg)
{
if(!m_fcustomRootCheck || m_customTrustedSubjectOrgs.size() == 0) {
return true;
}
if(aOrg == nullptr) {
return false;
}
bool fTrusted = false;
for (auto &aTrustedOrg : m_customTrustedSubjectOrgs) {
if(aOrg == aTrustedOrg) {
fTrusted = true;
break;
}
}
return fTrusted;
}

} MAT_NS_END
#pragma warning(pop)
#endif // HAVE_MAT_DEFAULT_HTTP_CLIENT
Expand Down
14 changes: 13 additions & 1 deletion lib/http/HttpClient_WinInet.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "IHttpClient.hpp"
#include "pal/PAL.hpp"
#include <cstddef>

#include "ILogManager.hpp"

Expand All @@ -17,6 +18,7 @@ namespace MAT_NS_BEGIN {
#ifndef _WININET_
typedef void* HINTERNET;
#endif
#define CERTIFICATE_THUMBPRINT_SHA256_SIZE 32

class WinInetRequestWrapper;

Expand All @@ -34,7 +36,14 @@ class HttpClient_WinInet : public IHttpClient {
void SetMsRootCheck(bool enforceMsRoot);
bool IsMsRootCheckRequired();

protected:
void SetCustomRootCheck(bool enforceCustomRoot);
bool IsCustomRootCheckRequired();
bool AddCustomRootCertSHA256Thumbprint(std::array<std::byte, CERTIFICATE_THUMBPRINT_SHA256_SIZE>& aCertThumbprint);
bool IsTrustedRootCert(std::array<std::byte, CERTIFICATE_THUMBPRINT_SHA256_SIZE>& aCertThumbprint);
bool AddCustomTrustedSubjectOrg(std::string& aTrustedOrg);
bool IsTrustedSubjectOrg(char* aTrustedOrg);

protected:
void erase(std::string const& id);

protected:
Expand All @@ -43,6 +52,9 @@ class HttpClient_WinInet : public IHttpClient {
std::map<std::string, WinInetRequestWrapper*> m_requests;
static unsigned s_nextRequestId;
bool m_msRootCheck;
bool m_fcustomRootCheck;
std::vector<std::array<std::byte, CERTIFICATE_THUMBPRINT_SHA256_SIZE>> m_customRootCerts;
std::vector<std::string> m_customTrustedSubjectOrgs;
friend class WinInetRequestWrapper;
};

Expand Down
Loading