Skip to content

Commit

Permalink
chore: gcs fixes (#330)
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Gershman <romange@gmail.com>
  • Loading branch information
romange committed Oct 28, 2024
1 parent 74a5f95 commit f102aa0
Show file tree
Hide file tree
Showing 10 changed files with 302 additions and 101 deletions.
47 changes: 45 additions & 2 deletions examples/gcs_demo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,32 @@ using absl::GetFlag;

ABSL_FLAG(string, bucket, "", "");
ABSL_FLAG(string, prefix, "", "");
ABSL_FLAG(uint32_t, write, 0, "");
ABSL_FLAG(uint32_t, write, 0, "If write > 0, then write this many files to GCS");
ABSL_FLAG(uint32_t, read, 0, "If read > 0, then read this many files from GCS");
ABSL_FLAG(uint32_t, connect_ms, 2000, "");
ABSL_FLAG(bool, epoll, false, "Whether to use epoll instead of io_uring");

static io::Result<string> ReadToString(io::ReadonlyFile* file) {
string res_str;
while (true) {
constexpr size_t kBufSize = 1U << 20;
size_t offset = res_str.size();
res_str.resize(offset + kBufSize);
io::MutableBytes mb{reinterpret_cast<uint8_t*>(res_str.data() + offset),
kBufSize};
io::Result<size_t> res = file->Read(offset, mb);
if (!res) {
return nonstd::make_unexpected(res.error());
}
size_t read_sz = *res;
if (read_sz < kBufSize) {
res_str.resize(offset + read_sz);
break;
}
}
return res_str;
}

void Run(SSL_CTX* ctx) {
fb2::ProactorBase* pb = fb2::ProactorBase::me();
cloud::GCPCredsProvider provider;
Expand All @@ -38,10 +60,14 @@ void Run(SSL_CTX* ctx) {
if (GetFlag(FLAGS_write) > 0) {
auto src = io::ReadFileToString("/proc/self/exe");
CHECK(src);
LOG(INFO) << "Writing " << src->size() << " bytes to " << prefix;
for (unsigned i = 0; i < GetFlag(FLAGS_write); ++i) {
string dest_key = absl::StrCat(prefix, "_", i);
cloud::GcsWriteFileOptions opts;
opts.creds_provider = &provider;
opts.pool = conn_pool;
io::Result<io::WriteFile*> dest_res =
cloud::OpenWriteGcsFile(bucket, dest_key, &provider, conn_pool);
cloud::OpenWriteGcsFile(bucket, dest_key, opts);
CHECK(dest_res) << "Could not open " << dest_key << " " << dest_res.error().message();
unique_ptr<io::WriteFile> dest(*dest_res);
error_code ec = dest->Write(*src);
Expand All @@ -50,6 +76,23 @@ void Run(SSL_CTX* ctx) {
CHECK(!ec);
CONSOLE_INFO << "Written " << dest_key;
}
} else if (GetFlag(FLAGS_read) > 0) {
for (unsigned i = 0; i < GetFlag(FLAGS_read); ++i) {
string dest_key = prefix;
cloud::GcsReadFileOptions opts;
opts.creds_provider = &provider;
opts.pool = conn_pool;
io::Result<io::ReadonlyFile*> dest_res =
cloud::OpenReadGcsFile(bucket, dest_key, opts);
CHECK(dest_res) << "Could not open " << dest_key << " " << dest_res.error().message();
unique_ptr<io::ReadonlyFile> dest(*dest_res);
io::Result<string> dest_str = ReadToString(dest.get());
if (dest_str) {
CONSOLE_INFO << "Read " << dest_str->size() << " bytes from " << dest_key;
} else {
LOG(ERROR) << "Error reading " << dest_key << " " << dest_str.error().message();
}
}
} else {
auto cb = [](cloud::GCS::ObjectItem item) {
cout << "Object: " << item.key << ", size: " << item.size << endl;
Expand Down
1 change: 0 additions & 1 deletion util/cloud/gcp/gcp_creds_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class GCPCredsProvider {
bool use_instance_metadata_ = false;
unsigned connect_ms_ = 0;

fb2::ProactorBase* pb_ = nullptr;
std::string account_id_;
std::string project_id_;

Expand Down
47 changes: 30 additions & 17 deletions util/cloud/gcp/gcp_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <boost/beast/http/string_body.hpp>

#include "base/logging.h"

#include "util/cloud/gcp/gcp_creds_provider.h"
#include "util/http/http_client.h"

Expand Down Expand Up @@ -72,38 +71,48 @@ RobustSender::RobustSender(http::ClientPool* pool, GCPCredsProvider* provider)
: pool_(pool), provider_(provider) {
}

auto RobustSender::Send(unsigned num_iterations,
detail::HttpRequestBase* req) -> io::Result<HeaderParserPtr> {
error_code RobustSender::Send(unsigned num_iterations, detail::HttpRequestBase* req,
SenderResult* result) {
error_code ec;
for (unsigned i = 0; i < num_iterations; ++i) { // Iterate for possible token refresh.
auto res = pool_->GetHandle();
if (!res)
return nonstd::make_unexpected(res.error());

auto client_handle = std::move(res.value());
return res.error();

result->client_handle = std::move(res.value());
auto* client_handle = result->client_handle.get();
VLOG(1) << "HttpReq " << client_handle->host() << ": " << req->GetHeaders() << ", ["
<< client_handle->native_handle() << "]";

RETURN_UNEXPECTED(req->Send(client_handle.get()));
HeaderParserPtr parser(new h2::response_parser<h2::empty_body>());
RETURN_UNEXPECTED(client_handle->ReadHeader(parser.get()));
RETURN_ERROR(req->Send(client_handle));
result->eb_parser.reset(new h2::response_parser<h2::empty_body>());

// no limit. Prevent from this parser to throw an error due to large body.
// result->eb_parser->body_limit(boost::optional<uint64_t>());
auto header_err = client_handle->ReadHeader(result->eb_parser.get());

// Unfortunately earlier versions of boost (1.74-) have a bug that do not support the body_limit
// directive above. Therefore, we fix it here.
if (header_err == h2::error::body_limit) {
header_err.clear();
}
RETURN_ERROR(header_err);
{
const auto& msg = parser->get();
const auto& msg = result->eb_parser->get();
VLOG(1) << "RespHeader" << i << ": " << msg;

if (!parser->keep_alive()) {
if (!result->eb_parser->keep_alive()) {
LOG(FATAL) << "TBD: Schedule reconnect due to conn-close header";
}

if (IsResponseOK(msg.result())) {
return parser;
return {};
}
}

// We have some kind of error, possibly with body that needs to be drained.
h2::response_parser<h2::string_body> drainer(std::move(*parser));
RETURN_UNEXPECTED(client_handle->Recv(&drainer));
h2::response_parser<h2::string_body> drainer(std::move(*result->eb_parser));
RETURN_ERROR(client_handle->Recv(&drainer));
const auto& msg = drainer.get();

if (DoesServerPushback(msg.result())) {
Expand All @@ -115,21 +124,25 @@ auto RobustSender::Send(unsigned num_iterations,

if (IsUnauthorized(msg)) {
VLOG(1) << "Refreshing token";
RETURN_UNEXPECTED(provider_->RefreshToken(client_handle->proactor()));
RETURN_ERROR(provider_->RefreshToken(client_handle->proactor()));
req->SetHeader(h2::field::authorization, AuthHeader(provider_->access_token()));

continue;
}

if (msg.result() == h2::status::forbidden) {
return nonstd::make_unexpected(make_error_code(errc::operation_not_permitted));
return make_error_code(errc::operation_not_permitted);
}

if (msg.result() == h2::status::not_found) {
return make_error_code(errc::no_such_file_or_directory);
}

ec = make_error_code(errc::bad_message);
LOG(DFATAL) << "Unexpected response " << msg << "\n" << msg.body() << "\n";
}

return nonstd::make_unexpected(ec);
return ec;
}

} // namespace util::cloud
18 changes: 15 additions & 3 deletions util/cloud/gcp/gcp_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class EmptyRequestImpl : public HttpRequestBase {
req_.set(f, boost::string_view{value.data(), value.size()});
}

// Request headers.
const boost::beast::http::header<true>& GetHeaders() const final {
return req_.base();
}
Expand Down Expand Up @@ -99,12 +100,14 @@ class RobustSender {
RobustSender& operator=(const RobustSender&) = delete;

public:
using HeaderParserPtr =
std::unique_ptr<boost::beast::http::response_parser<boost::beast::http::empty_body>>;
struct SenderResult {
std::unique_ptr<boost::beast::http::response_parser<boost::beast::http::empty_body>> eb_parser;
http::ClientPool::ClientHandle client_handle;
};

RobustSender(http::ClientPool* pool, GCPCredsProvider* provider);

io::Result<HeaderParserPtr> Send(unsigned num_iterations, detail::HttpRequestBase* req);
std::error_code Send(unsigned num_iterations, detail::HttpRequestBase* req, SenderResult* result);

private:
http::ClientPool* pool_;
Expand All @@ -122,4 +125,13 @@ std::string AuthHeader(std::string_view access_token);
} \
} while (false)

#define RETURN_ERROR(x) \
do { \
auto ec = (x); \
if (ec) { \
VLOG(1) << "Error calling " << #x << ": " << ec.message(); \
return ec; \
} \
} while (false)

} // namespace util::cloud
49 changes: 18 additions & 31 deletions util/cloud/gcp/gcs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,6 @@ auto Unexpected(std::errc code) {

const char kInstanceTokenUrl[] = "/computeMetadata/v1/instance/service-accounts/default/token";

#define RETURN_ERROR(x) \
do { \
auto ec = (x); \
if (ec) { \
VLOG(1) << "Error calling " << #x << ": " << ec.message(); \
return ec; \
} \
} while (false)

io::Result<string> ExpandFilePath(string_view path) {
io::Result<io::StatShortVec> res = io::StatFiles(path);
Expand Down Expand Up @@ -351,14 +343,18 @@ error_code GCPCredsProvider::RefreshToken(fb2::ProactorBase* pb) {
return {};
}

GCS::GCS(GCPCredsProvider* provider, SSL_CTX* ssl_cntx, fb2::ProactorBase* pb)
: creds_provider_(*provider), ssl_ctx_(ssl_cntx) {
client_pool_.reset(new http::ClientPool(GCS_API_DOMAIN, ssl_ctx_, pb));
client_pool_->SetOnConnect([](int fd) {
unique_ptr<http::ClientPool> GCS::CreateApiConnectionPool(SSL_CTX* ssl_ctx, fb2::ProactorBase* pb) {
unique_ptr<http::ClientPool> res(new http::ClientPool(GCS_API_DOMAIN, ssl_ctx, pb));
res->SetOnConnect([](int fd) {
auto ec = EnableKeepAlive(fd);
LOG_IF(WARNING, ec) << "Error setting keep alive " << ec.message() << " " << fd;
});
return res;
}

GCS::GCS(GCPCredsProvider* provider, SSL_CTX* ssl_cntx, fb2::ProactorBase* pb)
: creds_provider_(*provider), ssl_ctx_(ssl_cntx) {
client_pool_ = CreateApiConnectionPool(ssl_ctx_, pb);
// TODO: to make it configurable.
client_pool_->set_connect_timeout(2000);
}
Expand All @@ -377,15 +373,11 @@ error_code GCS::ListBuckets(ListBucketCb cb) {
RobustSender sender(client_pool_.get(), &creds_provider_);

while (true) {
io::Result<RobustSender::HeaderParserPtr> parse_res = sender.Send(2, &empty_req);
if (!parse_res)
return parse_res.error();
RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res);
h2::response_parser<h2::string_body> resp(std::move(*empty_parser));
auto res = client_pool_->GetHandle();
if (!res)
return res.error();
auto client = std::move(*res);
RobustSender::SenderResult result;
RETURN_ERROR(sender.Send(2, &empty_req, &result));

h2::response_parser<h2::string_body> resp(std::move(*result.eb_parser));
auto client = std::move(result.client_handle);

RETURN_ERROR(client->Recv(&resp));

Expand Down Expand Up @@ -437,16 +429,11 @@ error_code GCS::List(string_view bucket, string_view prefix, bool recursive, Lis
rj::Document doc;
RobustSender sender(client_pool_.get(), &creds_provider_);
while (true) {
io::Result<RobustSender::HeaderParserPtr> parse_res = sender.Send(2, &empty_req);
if (!parse_res)
return parse_res.error();
RobustSender::HeaderParserPtr empty_parser = std::move(*parse_res);
h2::response_parser<h2::string_body> resp(std::move(*empty_parser));

auto res = client_pool_->GetHandle();
if (!res)
return res.error();
auto client = std::move(*res);
RobustSender::SenderResult parse_res;
RETURN_ERROR(sender.Send(2, &empty_req, &parse_res));

h2::response_parser<h2::string_body> resp(std::move(*parse_res.eb_parser));
auto client = std::move(parse_res.client_handle);

RETURN_ERROR(client->Recv(&resp));

Expand Down
7 changes: 6 additions & 1 deletion util/cloud/gcp/gcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,12 @@ class GCS {
std::error_code List(std::string_view bucket, std::string_view prefix, bool recursive,
ListObjectCb cb);

http::ClientPool* GetConnectionPool() { return client_pool_.get(); }
http::ClientPool* GetConnectionPool() {
return client_pool_.get();
}

static std::unique_ptr<http::ClientPool> CreateApiConnectionPool(SSL_CTX* ssl_ctx,
fb2::ProactorBase* pb);

private:
GCPCredsProvider& creds_provider_;
Expand Down
Loading

0 comments on commit f102aa0

Please sign in to comment.