Skip to content

Commit

Permalink
Merge branch 'main' into robinholzi/fix/yearbook-generation-splits
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxiBoether authored Jun 19, 2024
2 parents a23a9e5 + 29018c7 commit 3a2dfd4
Show file tree
Hide file tree
Showing 15 changed files with 317 additions and 85 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies:
- types-protobuf==5.26.*
- evidently==0.4.27
- alibi-detect==0.12.*
- tenacity
- jsonschema
- psycopg2
- sqlalchemy>=2.0
Expand Down
84 changes: 67 additions & 17 deletions modyn/evaluator/internal/dataset/evaluation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
grpc_connection_established,
instantiate_class,
)
from tenacity import Retrying, after_log, before_log, retry, stop_after_attempt, wait_random_exponential
from torch.utils.data import IterableDataset, get_worker_info
from torchvision import transforms

Expand Down Expand Up @@ -78,6 +79,13 @@ def _setup_composed_transform(self) -> None:
if len(self._transform_list) > 0:
self._transform = transforms.Compose(self._transform_list)

@retry(
stop=stop_after_attempt(5),
wait=wait_random_exponential(multiplier=1, min=2, max=60),
before=before_log(logger, logging.ERROR),
after=after_log(logger, logging.ERROR),
reraise=True,
)
def _init_grpc(self) -> None:
storage_channel = grpc.insecure_channel(
self._storage_address,
Expand All @@ -103,22 +111,64 @@ def _silence_pil() -> None: # pragma: no cover

def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable[list[int]]:
self._info("Getting keys from storage", worker_id)
req_keys = GetDataPerWorkerRequest(
dataset_id=self._dataset_id,
worker_id=worker_id,
total_workers=total_workers,
start_timestamp=self._start_timestamp,
end_timestamp=self._end_timestamp,
)
resp_keys: GetDataPerWorkerResponse
for resp_keys in self._storagestub.GetDataPerWorker(req_keys):
yield resp_keys.keys

def _get_data_from_storage(self, keys: list[int]) -> Iterable[list[tuple[int, bytes, int]]]:
request = GetRequest(dataset_id=self._dataset_id, keys=keys)
response: GetResponse
for response in self._storagestub.Get(request):
yield list(zip(response.keys, response.samples, response.labels))
last_processed_index = -1
for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
):
with attempt:
try:
req_keys = GetDataPerWorkerRequest(
dataset_id=self._dataset_id,
worker_id=worker_id,
total_workers=total_workers,
start_timestamp=self._start_timestamp,
end_timestamp=self._end_timestamp,
)
resp_keys: GetDataPerWorkerResponse
for index, resp_keys in enumerate(self._storagestub.GetDataPerWorker(req_keys)):
if index <= last_processed_index:
continue # Skip already processed responses
yield resp_keys.keys
last_processed_index = index

except grpc.RpcError as e:
self._info(
"gRPC error occurred, last index = " + f"{last_processed_index}: {e.code()} - {e.details()}",
worker_id,
)
self._info(f"Stringified exception: {str(e)}", worker_id)
self._info(
f"Error occured while asking {self._dataset_id} for worker data:\n{worker_id}", worker_id
)
self._init_grpc()
raise e

def _get_data_from_storage(
    self, keys: list[int], worker_id: Optional[int] = None
) -> Iterable[list[tuple[int, bytes, int]]]:
    """Stream the samples for ``keys`` from storage, retrying on gRPC errors.

    The ``Get`` request is issued inside a tenacity ``Retrying`` loop with at
    most 5 attempts and a randomized exponential wait between attempts
    (multiplier=1, min=2, max=60). If an attempt fails with ``grpc.RpcError``,
    the error is logged, ``self._init_grpc()`` is called to set up the
    connection again, and the error is re-raised so that tenacity can
    schedule the next attempt.

    Args:
        keys: Keys to request from the storage for ``self._dataset_id``.
        worker_id: Only forwarded to ``self._info`` for log messages.

    Yields:
        One list per streamed response, holding the zipped
        ``(key, sample, label)`` entries of that response.

    Raises:
        grpc.RpcError: Re-raised as-is (``reraise=True``) once the final
            attempt has failed.
    """
    # Index of the last response already handed to the consumer. It lives
    # outside the retry loop so that it survives across attempts.
    last_processed_index = -1
    for attempt in Retrying(
        stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
    ):
        with attempt:
            try:
                # Every attempt re-issues the full request from the beginning.
                request = GetRequest(dataset_id=self._dataset_id, keys=keys)
                response: GetResponse
                for index, response in enumerate(self._storagestub.Get(request)):
                    # NOTE(review): skipping by position presumes a repeated Get call
                    # streams its responses in the same order and batching as the
                    # failed one - confirm against the storage implementation.
                    if index <= last_processed_index:
                        continue  # Skip already processed responses
                    yield list(zip(response.keys, response.samples, response.labels))
                    # Only advanced after the consumer resumed the generator, i.e.
                    # after the yielded batch was actually taken.
                    last_processed_index = index

            except grpc.RpcError as e:  # We catch and reraise to log and reconnect
                self._info(
                    f"gRPC error occurred, last index = {last_processed_index}: {e.code()} - {e.details()}",
                    worker_id,
                )
                self._info(f"Stringified exception: {str(e)}", worker_id)
                self._info(f"Error occured while asking {self._dataset_id} for keys:\n{keys}", worker_id)
                # Reconnect before handing the error to tenacity; an exception raised
                # here would propagate instead of the original RpcError.
                self._init_grpc()
                raise e

def __iter__(self) -> Generator:
worker_info = get_worker_info()
Expand All @@ -144,6 +194,6 @@ def __iter__(self) -> Generator:

# TODO(#175): we might want to do/accelerate prefetching here.
for keys in self._get_keys_from_storage(worker_id, total_workers):
for data in self._get_data_from_storage(keys):
for data in self._get_data_from_storage(keys, worker_id):
for key, sample, label in data:
yield key, self._transform(sample), label
5 changes: 4 additions & 1 deletion modyn/evaluator/internal/pytorch_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def _prepare_dataloader(self, evaluation_info: EvaluationInfo) -> torch.utils.da
)
self._debug("Creating DataLoader.")
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=evaluation_info.batch_size, num_workers=evaluation_info.num_dataloaders, timeout=60
dataset,
batch_size=evaluation_info.batch_size,
num_workers=evaluation_info.num_dataloaders,
timeout=60 if evaluation_info.num_dataloaders > 0 else 0,
)

return dataloader
Expand Down
52 changes: 45 additions & 7 deletions modyn/storage/include/internal/grpc/storage_service_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
#include <spdlog/spdlog.h>
#include <yaml-cpp/yaml.h>

#include <algorithm>
#include <cctype>
#include <exception>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
#include <variant>
Expand Down Expand Up @@ -319,16 +323,27 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
get_samples_and_send<WriterT>(begin, end, writer, &writer_mutex, &dataset_data, &config_, sample_batch_size_);

} else {
std::vector<std::exception_ptr> thread_exceptions(retrieval_threads_);
std::mutex exception_mutex;
std::vector<std::pair<std::vector<int64_t>::const_iterator, std::vector<int64_t>::const_iterator>>
its_per_thread = get_keys_per_thread(request_keys, retrieval_threads_);
std::vector<std::thread> retrieval_threads_vector(retrieval_threads_);
for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) {
const std::vector<int64_t>::const_iterator begin = its_per_thread[thread_id].first;
const std::vector<int64_t>::const_iterator end = its_per_thread[thread_id].second;

retrieval_threads_vector[thread_id] =
std::thread(StorageServiceImpl::get_samples_and_send<WriterT>, begin, end, writer, &writer_mutex,
&dataset_data, &config_, sample_batch_size_);
retrieval_threads_vector[thread_id] = std::thread([thread_id, begin, end, writer, &writer_mutex, &dataset_data,
&thread_exceptions, &exception_mutex, this]() {
try {
get_samples_and_send<WriterT>(begin, end, writer, &writer_mutex, &dataset_data, &config_,
sample_batch_size_);
} catch (const std::exception& e) {
const std::lock_guard<std::mutex> lock(exception_mutex);
spdlog::error(
fmt::format("Error in thread {} started by send_sample_data_from_keys: {}", thread_id, e.what()));
thread_exceptions[thread_id] = std::current_exception();
}
});
}

for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) {
Expand All @@ -337,6 +352,17 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
}
}
retrieval_threads_vector.clear();
// In order for the gRPC call to return an error, we need to rethrow the threaded exceptions.
for (auto& e_ptr : thread_exceptions) {
if (e_ptr) {
try {
std::rethrow_exception(e_ptr);
} catch (const std::exception& e) {
SPDLOG_ERROR("Error while unwinding thread: {}\nPropagating it up the call chain.", e.what());
throw;
}
}
}
}
}

Expand Down Expand Up @@ -529,6 +555,12 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
// keys than this
try {
const uint64_t num_keys = sample_keys.size();

if (num_keys == 0) {
SPDLOG_ERROR("num_keys is 0, this should not have happened. Exiting send_sample_data_for_keys_and_file");
return;
}

std::vector<int64_t> sample_labels(num_keys);
std::vector<uint64_t> sample_indices(num_keys);
std::vector<int64_t> sample_fileids(num_keys);
Expand All @@ -539,15 +571,16 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
session << sample_query, soci::into(sample_labels), soci::into(sample_indices), soci::into(sample_fileids),
soci::use(dataset_data.dataset_id);

int64_t current_file_id = sample_fileids[0];
int64_t current_file_id = sample_fileids.at(0);
uint64_t current_file_start_idx = 0;
std::string current_file_path;
session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id",
soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id);

if (current_file_path.empty()) {
SPDLOG_ERROR(fmt::format("Could not obtain full path of file id {} in dataset {}", current_file_id,
dataset_data.dataset_id));
if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) {
SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query));
throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}",
current_file_id, dataset_data.dataset_id));
}
const YAML::Node file_wrapper_config_node = YAML::Load(dataset_data.file_wrapper_config);
auto filesystem_wrapper =
Expand Down Expand Up @@ -594,6 +627,11 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service {
current_file_path = "",
session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id",
soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id);
if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) {
SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query));
throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}",
current_file_id, dataset_data.dataset_id));
}
file_wrapper->set_file_path(current_file_path);
current_file_start_idx = sample_idx;
}
Expand Down
15 changes: 14 additions & 1 deletion modyn/supervisor/internal/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from modyn.supervisor.internal.eval.result_writer import AbstractEvaluationResultWriter
from modyn.supervisor.internal.utils import EvaluationStatusReporter
from modyn.utils import grpc_common_config, grpc_connection_established
from tenacity import Retrying, stop_after_attempt, wait_random_exponential

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -270,6 +271,7 @@ def prepare_evaluation_request(

return EvaluateModelRequest(**start_evaluation_kwargs)

# pylint: disable=too-many-branches
def wait_for_evaluation_completion(
self, training_id: int, evaluations: dict[int, EvaluationStatusReporter]
) -> None:
Expand All @@ -292,7 +294,18 @@ def wait_for_evaluation_completion(
current_evaluation_id = working_queue.popleft()
current_evaluation_reporter = evaluations[current_evaluation_id]
req = EvaluationStatusRequest(evaluation_id=current_evaluation_id)
res: EvaluationStatusResponse = self.evaluator.get_evaluation_status(req)

for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
):
with attempt:
try:
res: EvaluationStatusResponse = self.evaluator.get_evaluation_status(req)
except grpc.RpcError as e: # We catch and reraise to easily reconnect
logger.error(e)
logger.error(f"[Training {training_id}]: gRPC connection error, trying to reconnect.")
self.init_evaluator()
raise e

if not res.valid:
exception_msg = f"Evaluation {current_evaluation_id} is invalid at server:\n{res}\n"
Expand Down
19 changes: 16 additions & 3 deletions modyn/supervisor/internal/pipeline_executor/evaluation_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from multiprocessing import Queue
from pathlib import Path

import grpc
import pandas as pd
from modyn.config.schema.pipeline import ModynPipelineConfig
from modyn.config.schema.system import ModynConfig
Expand All @@ -31,6 +32,7 @@
from modyn.supervisor.internal.utils.evaluation_status_reporter import EvaluationStatusReporter
from modyn.utils.utils import current_time_micros, dynamic_module_import
from pydantic import BaseModel
from tenacity import Retrying, stop_after_attempt, wait_random_exponential

eval_strategy_module = dynamic_module_import("modyn.supervisor.internal.eval.strategies")

Expand All @@ -56,13 +58,13 @@ def __init__(
pipeline_logdir: Path,
config: ModynConfig,
pipeline: ModynPipelineConfig,
grpc: GRPCHandler,
grpc_handler: GRPCHandler,
):
self.pipeline_id = pipeline_id
self.pipeline_logdir = pipeline_logdir
self.config = config
self.pipeline = pipeline
self.grpc = grpc
self.grpc = grpc_handler
self.context: AfterPipelineEvalContext | None = None
self.eval_handlers = (
[EvalHandler(eval_handler_config) for eval_handler_config in pipeline.evaluation.handlers]
Expand Down Expand Up @@ -272,7 +274,18 @@ def _single_evaluation(self, log: StageLog, eval_status_queue: Queue, eval_req:
eval_req.interval_start,
eval_req.interval_end,
)
response: EvaluateModelResponse = self.grpc.evaluator.evaluate_model(request)
for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
):
with attempt:
try:
response: EvaluateModelResponse = self.grpc.evaluator.evaluate_model(request)
except grpc.RpcError as e: # We catch and reraise to reconnect
logger.error(e)
logger.error("gRPC connection error, trying to reconnect...")
self.grpc.init_evaluator()
raise e

if not response.evaluation_started:
log.info.failure_reason = EvaluationAbortedReason.DESCRIPTOR.values_by_number[
response.eval_aborted_reason
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def evaluation_executor(
pipeline_logdir=tmp_dir_tests,
config=dummy_system_config,
pipeline=pipeline_config,
grpc=GRPCHandler(eval_state_config.config.model_dump(by_alias=True)),
grpc_handler=GRPCHandler(eval_state_config.config.model_dump(by_alias=True)),
)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=unused-argument, no-name-in-module
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import grpc
import pytest
Expand All @@ -10,6 +10,7 @@
UsesWeightsResponse,
)
from modyn.trainer_server.internal.dataset.key_sources import SelectorKeySource
from tenacity import RetryCallState


def test_init():
Expand Down Expand Up @@ -138,3 +139,28 @@ def test_unweighted_key_source(test_grp, test_connection):
keys, weights = keysource.get_keys_and_weights(0, 0)
assert weights == [-1.0, -2.0, -3.0]
assert keys == [1, 2, 3]


def test_retry_reconnection_callback():
    """A failing reconnect inside the retry callback must surface to the caller."""
    source = SelectorKeySource(12, 1, "localhost:1234")

    # Stand-in for the tenacity call state of a failed third attempt on `source`
    retry_state = MagicMock(spec=RetryCallState)
    retry_state.attempt_number = 3
    retry_state.outcome = MagicMock()
    retry_state.outcome.failed = True
    retry_state.args = [source]

    # Make every reconnection attempt blow up with a ConnectionError
    failing_connect = patch.object(
        source, "_connect_to_selector", side_effect=ConnectionError("Connection failed")
    )
    with failing_connect as reconnect_mock:
        # The callback is expected to let the reconnection error propagate
        with pytest.raises(ConnectionError):
            SelectorKeySource.retry_reconnection_callback(retry_state)

        # ...and it must actually have attempted the reconnect
        reconnect_mock.assert_called()
Loading

0 comments on commit 3a2dfd4

Please sign in to comment.