Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ydb/mvp/oidc_proxy/bin/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

int main(int argc, char **argv) {
try {
return NMVP::TMVP(argc, argv).Run();
return NMVP::NOIDC::TMVP(argc, argv).Run();
} catch (const yexception& e) {
Cerr << "Caught exception: " << e.what() << Endl;
return 1;
Expand Down
12 changes: 8 additions & 4 deletions ydb/mvp/oidc_proxy/mvp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
#include "mvp.h"
#include "oidc_client.h"

using namespace NMVP;
NActors::IActor* CreateMemProfiler();

namespace NMVP {
namespace NOIDC {

namespace {

Expand All @@ -42,7 +45,7 @@ TString AddSchemeToUserToken(const TString& token, const TString& scheme) {
const ui16 TMVP::DefaultHttpPort = 8788;
const ui16 TMVP::DefaultHttpsPort = 8789;

const TString& NMVP::GetEServiceName(NActors::NLog::EComponent component) {
const TString& GetEServiceName(NActors::NLog::EComponent component) {
static const TString loggerName("LOGGER");
static const TString mvpName("MVP");
static const TString grpcName("GRPC");
Expand All @@ -66,8 +69,6 @@ void TMVP::OnTerminate(int) {
AtomicSet(Quit, true);
}

NActors::IActor* CreateMemProfiler();

int TMVP::Init() {
ActorSystem.Start();

Expand Down Expand Up @@ -415,3 +416,6 @@ THolder<NActors::TActorSystemSetup> TMVP::BuildActorSystemSetup(int argc, char**
}

TAtomic TMVP::Quit = false;

} // NOIDC
} // NMVP
4 changes: 3 additions & 1 deletion ydb/mvp/oidc_proxy/mvp.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/system/rwlock.h>
#include <contrib/libs/yaml-cpp/include/yaml-cpp/yaml.h>
#include "openid_connect.h"
#include "oidc_settings.h"

namespace NMVP {
namespace NOIDC {

const TString& GetEServiceName(NActors::NLog::EComponent component);

Expand Down Expand Up @@ -71,4 +72,5 @@ class TMVP {
int Shutdown();
};

} // namespace NOIDC
} // namespace NMVP
10 changes: 8 additions & 2 deletions ydb/mvp/oidc_proxy/oidc_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@
#include "oidc_protected_page_handler.h"
#include "oidc_session_create_handler.h"

namespace NMVP {
namespace NOIDC {

void InitOIDC(NActors::TActorSystem& actorSystem,
const NActors::TActorId& httpProxyId,
const TOpenIdConnectSettings& settings) {
actorSystem.Send(httpProxyId, new NHttp::TEvHttpProxy::TEvRegisterHandler(
"/auth/callback",
actorSystem.Register(new NMVP::TSessionCreateHandler(httpProxyId, settings))
actorSystem.Register(new TSessionCreateHandler(httpProxyId, settings))
)
);

actorSystem.Send(httpProxyId, new NHttp::TEvHttpProxy::TEvRegisterHandler(
"/",
actorSystem.Register(new NMVP::TProtectedPageHandler(httpProxyId, settings))
actorSystem.Register(new TProtectedPageHandler(httpProxyId, settings))
)
);
}

} // NOIDC
} // NMVP
14 changes: 12 additions & 2 deletions ydb/mvp/oidc_proxy/oidc_client.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
#pragma once
namespace NActors {

#include <ydb/mvp/core/core_ydb.h>
#include "openid_connect.h"
class TActorSystem;
struct TActorId;

} // NActors
namespace NMVP {
namespace NOIDC {

struct TOpenIdConnectSettings;

void InitOIDC(NActors::TActorSystem& actorSystem, const NActors::TActorId& httpProxyId, const TOpenIdConnectSettings& settings);

} // NOIDC
} // NMVP
238 changes: 238 additions & 0 deletions ydb/mvp/oidc_proxy/oidc_protected_page.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
#include <ydb/library/actors/core/actor.h>
#include <ydb/library/actors/http/http.h>
#include <ydb/mvp/core/mvp_log.h>
#include <ydb/core/util/wildcard.h>
#include "openid_connect.h"
#include "oidc_protected_page.h"

namespace NMVP {
namespace NOIDC {

THandlerSessionServiceCheck::THandlerSessionServiceCheck(const NActors::TActorId& sender,
const NHttp::THttpIncomingRequestPtr& request,
const NActors::TActorId& httpProxyId,
const TOpenIdConnectSettings& settings)
: Sender(sender)
, Request(request)
, HttpProxyId(httpProxyId)
, Settings(settings)
, ProtectedPageUrl(Request->URL.SubStr(1))
{}

void THandlerSessionServiceCheck::Bootstrap(const NActors::TActorContext& ctx) {
if (!CheckRequestedHost()) {
ctx.Send(Sender, new NHttp::TEvHttpProxy::TEvHttpOutgoingResponse(CreateResponseForbiddenHost()));
Die(ctx);
return;
}
NHttp::THeaders headers(Request->Headers);
IsAjaxRequest = DetectAjaxRequest(headers);
TStringBuf authHeader = headers.Get(AUTH_HEADER_NAME);
if (Request->Method == "OPTIONS" || IsAuthorizedRequest(authHeader)) {
ForwardUserRequest(TString(authHeader), ctx);
} else {
StartOidcProcess(ctx);
}
}

void THandlerSessionServiceCheck::HandleProxy(NHttp::TEvHttpProxy::TEvHttpIncomingResponse::TPtr event, const NActors::TActorContext& ctx) {
NHttp::THttpOutgoingResponsePtr httpResponse;
if (event->Get()->Response != nullptr) {
NHttp::THttpIncomingResponsePtr response = event->Get()->Response;
LOG_DEBUG_S(ctx, EService::MVP, "Incoming response for protected resource: " << response->Status);
if (NeedSendSecureHttpRequest(response)) {
SendSecureHttpRequest(response, ctx);
return;
}
NHttp::THeadersBuilder headers = GetResponseHeaders(response);
TStringBuf contentType = headers.Get("Content-Type").NextTok(';');
if (contentType == "text/html") {
TString newBody = FixReferenceInHtml(response->Body, response->GetRequest()->Host);
httpResponse = Request->CreateResponse( response->Status, response->Message, headers, newBody);
} else {
httpResponse = Request->CreateResponse( response->Status, response->Message, headers, response->Body);
}
} else {
static constexpr size_t MAX_LOGGED_SIZE = 1024;
LOG_DEBUG_S(ctx, EService::MVP, "Can not process request to protected resource:\n" << event->Get()->Request->GetRawData().substr(0, MAX_LOGGED_SIZE));
httpResponse = CreateResponseForNotExistingResponseFromProtectedResource(event->Get()->GetError());
}
ctx.Send(Sender, new NHttp::TEvHttpProxy::TEvHttpOutgoingResponse(httpResponse));
Die(ctx);
}

bool THandlerSessionServiceCheck::CheckRequestedHost() {
size_t pos = ProtectedPageUrl.find('/');
if (pos == TString::npos) {
return false;
}
TStringBuf scheme, host, uri;
if (!NHttp::CrackURL(ProtectedPageUrl, scheme, host, uri)) {
return false;
}
if (!scheme.empty() && (scheme != "http" && scheme != "https")) {
return false;
}
RequestedPageScheme = scheme;
auto it = std::find_if(Settings.AllowedProxyHosts.cbegin(), Settings.AllowedProxyHosts.cend(), [&host] (const TString& wildcard) {
return NKikimr::IsMatchesWildcard(host, wildcard);
});
return it != Settings.AllowedProxyHosts.cend();
}

bool THandlerSessionServiceCheck::IsAuthorizedRequest(TStringBuf authHeader) {
if (authHeader.empty()) {
return false;
}
return to_lower(ToString(authHeader)).StartsWith(IAM_TOKEN_SCHEME_LOWER);
}

void THandlerSessionServiceCheck::ForwardUserRequest(TStringBuf authHeader, const NActors::TActorContext& ctx, bool secure) {
LOG_DEBUG_S(ctx, EService::MVP, "Forward user request bypass OIDC");
NHttp::THttpOutgoingRequestPtr httpRequest = NHttp::THttpOutgoingRequest::CreateRequest(Request->Method, ProtectedPageUrl);
ForwardRequestHeaders(httpRequest);
if (!authHeader.empty()) {
httpRequest->Set(AUTH_HEADER_NAME, authHeader);
}
if (Request->HaveBody()) {
httpRequest->SetBody(Request->Body);
}
if (RequestedPageScheme.empty()) {
httpRequest->Secure = secure;
}
ctx.Send(HttpProxyId, new NHttp::TEvHttpProxy::TEvHttpOutgoingRequest(httpRequest));
}

TString THandlerSessionServiceCheck::FixReferenceInHtml(TStringBuf html, TStringBuf host, TStringBuf findStr) {
TStringBuilder result;
size_t n = html.find(findStr);
if (n == TStringBuf::npos) {
return TString(html);
}
size_t len = findStr.length() + 1;
size_t pos = 0;
while (n != TStringBuf::npos) {
result << html.SubStr(pos, n + len - pos);
if (html[n + len] == '/') {
result << "/" << host;
if (html[n + len + 1] == '\'' || html[n + len + 1] == '\"') {
result << "/internal";
n++;
}
}
pos = n + len;
n = html.find(findStr, pos);
}
result << html.SubStr(pos);
return result;
}

TString THandlerSessionServiceCheck::FixReferenceInHtml(TStringBuf html, TStringBuf host) {
TStringBuf findString = "href=";
auto result = FixReferenceInHtml(html, host, findString);
findString = "src=";
return FixReferenceInHtml(result, host, findString);
}

void THandlerSessionServiceCheck::ForwardRequestHeaders(NHttp::THttpOutgoingRequestPtr& request) const {
static const TVector<TStringBuf> HEADERS_WHITE_LIST = {
"Connection",
"Accept-Language",
"Cache-Control",
"Sec-Fetch-Dest",
"Sec-Fetch-Mode",
"Sec-Fetch-Site",
"Sec-Fetch-User",
"Upgrade-Insecure-Requests",
"Content-Type",
"Origin"
};
NHttp::THeadersBuilder headers(Request->Headers);
for (const auto& header : HEADERS_WHITE_LIST) {
if (headers.Has(header)) {
request->Set(header, headers.Get(header));
}
}
request->Set("Accept-Encoding", "deflate");
}

NHttp::THeadersBuilder THandlerSessionServiceCheck::GetResponseHeaders(const NHttp::THttpIncomingResponsePtr& response) {
static const TVector<TStringBuf> HEADERS_WHITE_LIST = {
"Content-Type",
"Connection",
"X-Worker-Name",
"Set-Cookie",
"Access-Control-Allow-Origin",
"Access-Control-Allow-Credentials",
"Access-Control-Allow-Headers",
"Access-Control-Allow-Methods"
};
NHttp::THeadersBuilder headers(response->Headers);
NHttp::THeadersBuilder resultHeaders;
for (const auto& header : HEADERS_WHITE_LIST) {
if (headers.Has(header)) {
resultHeaders.Set(header, headers.Get(header));
}
}
static const TString LOCATION_HEADER_NAME = "Location";
if (headers.Has(LOCATION_HEADER_NAME)) {
resultHeaders.Set(LOCATION_HEADER_NAME, GetFixedLocationHeader(headers.Get(LOCATION_HEADER_NAME)));
}
return resultHeaders;
}

void THandlerSessionServiceCheck::SendSecureHttpRequest(const NHttp::THttpIncomingResponsePtr& response, const NActors::TActorContext& ctx) {
NHttp::THttpOutgoingRequestPtr request = response->GetRequest();
LOG_DEBUG_S(ctx, EService::MVP, "Try to send request to HTTPS port");
NHttp::THeadersBuilder headers {request->Headers};
ForwardUserRequest(headers.Get(AUTH_HEADER_NAME), ctx, true);
}

TString THandlerSessionServiceCheck::GetFixedLocationHeader(TStringBuf location) {
TStringBuf scheme, host, uri;
NHttp::CrackURL(ProtectedPageUrl, scheme, host, uri);
if (location.StartsWith("//")) {
return TStringBuilder() << '/' << (scheme.empty() ? "" : TString(scheme) + "://") << location.SubStr(2);
} else if (location.StartsWith('/')) {
return TStringBuilder() << '/'
<< (scheme.empty() ? "" : TString(scheme) + "://")
<< host << location;
} else {
TStringBuf locScheme, locHost, locUri;
NHttp::CrackURL(location, locScheme, locHost, locUri);
if (!locScheme.empty()) {
return TStringBuilder() << '/' << location;
}
}
return TString(location);
}

NHttp::THttpOutgoingResponsePtr THandlerSessionServiceCheck::CreateResponseForbiddenHost() {
NHttp::THeadersBuilder headers;
headers.Set("Content-Type", "text/html");
SetCORS(Request, &headers);

TStringBuf scheme, host, uri;
NHttp::CrackURL(ProtectedPageUrl, scheme, host, uri);
TStringBuilder html;
html << "<html><head><title>403 Forbidden</title></head><body bgcolor=\"white\"><center><h1>";
html << "403 Forbidden host: " << host;
html << "</h1></center></body></html>";

return Request->CreateResponse("403", "Forbidden", headers, html);
}

NHttp::THttpOutgoingResponsePtr THandlerSessionServiceCheck::CreateResponseForNotExistingResponseFromProtectedResource(const TString& errorMessage) {
NHttp::THeadersBuilder headers;
headers.Set("Content-Type", "text/html");
SetCORS(Request, &headers);

TStringBuilder html;
html << "<html><head><title>400 Bad Request</title></head><body bgcolor=\"white\"><center><h1>";
html << "400 Bad Request. Can not process request to protected resource: " << errorMessage;
html << "</h1></center></body></html>";
return Request->CreateResponse("400", "Bad Request", headers, html);
}

} // NOIDC
} // NMVP
Loading