Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flush periodically and make shared lib #23

Merged
merged 2 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ file(GLOB protos "proto/*.proto")

protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${protos})

add_library(tensorboard_logger STATIC
add_library(tensorboard_logger SHARED
"src/crc.cc"
"src/tensorboard_logger.cc"
${PROTO_SRCS}
Expand Down
48 changes: 46 additions & 2 deletions include/tensorboard_logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include <fstream>
#include <string>
#include <vector>
#include <atomic>
#include <thread>
#include <mutex>

#include "crc.h"
#include "event.pb.h"
Expand All @@ -20,10 +23,37 @@ const std::string kProjectorConfigFile = "projector_config.pbtxt";
const std::string kProjectorPluginName = "projector";
const std::string kTextPluginName = "text";


struct TensorBoardLoggerOptions
{
// Log is flushed whenever this many entries have been written since the last
// forced flush.
size_t max_queue_size_ = 100000;
TensorBoardLoggerOptions &max_queue_size(size_t max_queue_size) {
max_queue_size_ = max_queue_size;
return *this;
}

// Log is flushed with this period.
size_t flush_period_s_ = 60;
TensorBoardLoggerOptions &flush_period_s(size_t flush_period_s) {
flush_period_s_ = flush_period_s;
return *this;
}

bool resume_ = false;
TensorBoardLoggerOptions &resume(bool resume) {
resume_ = resume;
return *this;
}
};

class TensorBoardLogger {
public:

explicit TensorBoardLogger(const std::string &log_file,
bool resume = false) {
const TensorBoardLoggerOptions &options={}) {
this->options = options;
auto basename = get_basename(log_file);
if (basename.find("tfevents") == std::string::npos) {
throw std::runtime_error(
Expand All @@ -33,19 +63,26 @@ class TensorBoardLogger {
bucket_limits_ = nullptr;
ofs_ = new std::ofstream(
log_file, std::ios::out |
(resume ? std::ios::app : std::ios::trunc) |
(options.resume_ ? std::ios::app : std::ios::trunc) |
std::ios::binary);
if (!ofs_->is_open()) {
throw std::runtime_error("failed to open log_file " + log_file);
}
log_dir_ = get_parent_dir(log_file);

flushing_thread = std::thread(&TensorBoardLogger::flusher, this);
}
~TensorBoardLogger() {
ofs_->close();
if (bucket_limits_ != nullptr) {
delete bucket_limits_;
bucket_limits_ = nullptr;
}

stop = true;
if (flushing_thread.joinable()) {
flushing_thread.join();
}
}
int add_scalar(const std::string &tag, int step, double value);
int add_scalar(const std::string &tag, int step, float value);
Expand Down Expand Up @@ -153,10 +190,17 @@ class TensorBoardLogger {
int generate_default_buckets();
int add_event(int64_t step, Summary *summary);
int write(Event &event);
void flusher();

std::string log_dir_;
std::ofstream *ofs_;
std::vector<double> *bucket_limits_;
TensorBoardLoggerOptions options;

std::atomic<bool> stop{false};
size_t queue_size{0};
std::thread flushing_thread;
std::mutex file_object_mtx{};
}; // class TensorBoardLogger

#endif // TENSORBOARD_LOGGER_H
27 changes: 26 additions & 1 deletion src/tensorboard_logger.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,24 @@ int TensorBoardLogger::add_images(
return add_event(step, summary);
}

void TensorBoardLogger::flusher()
{
auto period = std::chrono::seconds(options.flush_period_s_);
auto next_flush_time = std::chrono::high_resolution_clock::now() + period;

while (!stop)
{
if (std::chrono::high_resolution_clock::now() < next_flush_time) {
std::this_thread::sleep_for(std::chrono::seconds(1));
continue;
}

std::lock_guard<std::mutex> lock{file_object_mtx};
ofs_->flush();
next_flush_time = std::chrono::high_resolution_clock::now() + period;
}
}

int TensorBoardLogger::add_audio(const string &tag, int step,
const string &encoded_audio, float sample_rate,
int num_channels, int length_frame,
Expand Down Expand Up @@ -301,11 +319,18 @@ int TensorBoardLogger::write(Event &event) {
masked_crc32c((char *)&buf_len, sizeof(buf_len)); // NOLINT
uint32_t data_crc = masked_crc32c(buf.c_str(), buf.size());

std::lock_guard<std::mutex> lock{file_object_mtx};

ofs_->write((char *)&buf_len, sizeof(buf_len)); // NOLINT
ofs_->write((char *)&len_crc, sizeof(len_crc)); // NOLINT
ofs_->write(buf.c_str(), buf.size());
ofs_->write((char *)&data_crc, sizeof(data_crc)); // NOLINT
ofs_->flush();

if (queue_size++ > options.max_queue_size_) {
ofs_->flush();
queue_size = 0;
}

return 0;
}

Expand Down