Skip to content

Commit

Permalink
Multithreaded error handling in storage (#524)
Browse files Browse the repository at this point in the history
If a thread in the get implementation throws an error, the storage
server currently dies because C++ does not propagate exceptions across
thread boundaries to the caller. With this change, we are able to return
a gRPC error instead. Also, we now throw in case the file path is empty.
  • Loading branch information
MaxiBoether authored Jun 19, 2024
1 parent 0a511bb commit 601c2e1
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 9 deletions.
5 changes: 4 additions & 1 deletion modyn/evaluator/internal/pytorch_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def _prepare_dataloader(self, evaluation_info: EvaluationInfo) -> torch.utils.da
)
self._debug("Creating DataLoader.")
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=evaluation_info.batch_size, num_workers=evaluation_info.num_dataloaders, timeout=60
dataset,
batch_size=evaluation_info.batch_size,
num_workers=evaluation_info.num_dataloaders,
timeout=60 if evaluation_info.num_dataloaders > 0 else 0,
)

return dataloader
Expand Down
52 changes: 45 additions & 7 deletions modyn/storage/include/internal/grpc/storage_service_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
#include <spdlog/spdlog.h>
#include <yaml-cpp/yaml.h>

#include <algorithm>
#include <cctype>
#include <exception>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
#include <variant>
Expand Down Expand Up @@ -319,16 +323,27 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
get_samples_and_send<WriterT>(begin, end, writer, &writer_mutex, &dataset_data, &config_, sample_batch_size_);

} else {
std::vector<std::exception_ptr> thread_exceptions(retrieval_threads_);
std::mutex exception_mutex;
std::vector<std::pair<std::vector<int64_t>::const_iterator, std::vector<int64_t>::const_iterator>>
its_per_thread = get_keys_per_thread(request_keys, retrieval_threads_);
std::vector<std::thread> retrieval_threads_vector(retrieval_threads_);
for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) {
const std::vector<int64_t>::const_iterator begin = its_per_thread[thread_id].first;
const std::vector<int64_t>::const_iterator end = its_per_thread[thread_id].second;

retrieval_threads_vector[thread_id] =
std::thread(StorageServiceImpl::get_samples_and_send<WriterT>, begin, end, writer, &writer_mutex,
&dataset_data, &config_, sample_batch_size_);
retrieval_threads_vector[thread_id] = std::thread([thread_id, begin, end, writer, &writer_mutex, &dataset_data,
&thread_exceptions, &exception_mutex, this]() {
try {
get_samples_and_send<WriterT>(begin, end, writer, &writer_mutex, &dataset_data, &config_,
sample_batch_size_);
} catch (const std::exception& e) {
const std::lock_guard<std::mutex> lock(exception_mutex);
spdlog::error(
fmt::format("Error in thread {} started by send_sample_data_from_keys: {}", thread_id, e.what()));
thread_exceptions[thread_id] = std::current_exception();
}
});
}

for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) {
Expand All @@ -337,6 +352,17 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
}
}
retrieval_threads_vector.clear();
// In order for the gRPC call to return an error, we need to rethrow the threaded exceptions.
for (auto& e_ptr : thread_exceptions) {
if (e_ptr) {
try {
std::rethrow_exception(e_ptr);
} catch (const std::exception& e) {
SPDLOG_ERROR("Error while unwinding thread: {}\nPropagating it up the call chain.", e.what());
throw;
}
}
}
}
}

Expand Down Expand Up @@ -529,6 +555,12 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
// keys than this
try {
const uint64_t num_keys = sample_keys.size();

if (num_keys == 0) {
SPDLOG_ERROR("num_keys is 0, this should not have happened. Exiting send_sample_data_for_keys_and_file");
return;
}

std::vector<int64_t> sample_labels(num_keys);
std::vector<uint64_t> sample_indices(num_keys);
std::vector<int64_t> sample_fileids(num_keys);
Expand All @@ -539,15 +571,16 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
session << sample_query, soci::into(sample_labels), soci::into(sample_indices), soci::into(sample_fileids),
soci::use(dataset_data.dataset_id);

int64_t current_file_id = sample_fileids[0];
int64_t current_file_id = sample_fileids.at(0);
uint64_t current_file_start_idx = 0;
std::string current_file_path;
session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id",
soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id);

if (current_file_path.empty()) {
SPDLOG_ERROR(fmt::format("Could not obtain full path of file id {} in dataset {}", current_file_id,
dataset_data.dataset_id));
if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) {
SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query));
throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}",
current_file_id, dataset_data.dataset_id));
}
const YAML::Node file_wrapper_config_node = YAML::Load(dataset_data.file_wrapper_config);
auto filesystem_wrapper =
Expand Down Expand Up @@ -594,6 +627,11 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
current_file_path = "",
session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id",
soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id);
if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) {
SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query));
throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}",
current_file_id, dataset_data.dataset_id));
}
file_wrapper->set_file_path(current_file_path);
current_file_start_idx = sample_idx;
}
Expand Down
6 changes: 5 additions & 1 deletion modyn/trainer_server/internal/dataset/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def prepare_dataloaders(
)
logger.debug("Creating DataLoader.")
train_dataloader = torch.utils.data.DataLoader(
train_set, batch_size=batch_size, num_workers=num_dataloaders, drop_last=drop_last, timeout=60
train_set,
batch_size=batch_size,
num_workers=num_dataloaders,
drop_last=drop_last,
timeout=60 if num_dataloaders > 0 else 0,
)

# TODO(#50): what to do with the val set in the general case?
Expand Down

0 comments on commit 601c2e1

Please sign in to comment.