From 0a511bb904296d9e1c0faba233bda3a3e51d483a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20B=C3=B6ther?= <2116466+MaxiBoether@users.noreply.github.com> Date: Tue, 18 Jun 2024 22:07:57 +0200 Subject: [PATCH 1/3] More robust gRPC connections, using `tenacity` where possible (#521) Sometimes, we face outages/random disconnections during training. This fixes it in places where I encountered it last night. I tried to integrate `tenacity` as suggested by @robinholzi, but it's not always possible since the retry logic involves keeping track of already done work, which I don't want to put into class state Part 6/n of porting over SIGMOD changes. --- environment.yml | 1 + .../internal/dataset/evaluation_dataset.py | 84 +++++++++++++++---- modyn/supervisor/internal/grpc_handler.py | 15 +++- .../pipeline_executor/evaluation_executor.py | 19 ++++- .../test_evaluation_executor.py | 2 +- .../key_sources/test_selector_key_source.py | 28 ++++++- .../internal/data/test_online_dataset.py | 53 +++++++++++- .../dataset/key_sources/local_key_source.py | 4 +- .../key_sources/selector_key_source.py | 54 +++++++----- .../internal/dataset/online_dataset.py | 75 ++++++++++------- .../internal/trainer/pytorch_trainer.py | 1 + 11 files changed, 261 insertions(+), 75 deletions(-) diff --git a/environment.yml b/environment.yml index 08a511f6e..5de4819b7 100644 --- a/environment.yml +++ b/environment.yml @@ -24,6 +24,7 @@ dependencies: - types-protobuf==5.26.* - evidently==0.4.27 - alibi-detect==0.12.* + - tenacity - jsonschema - psycopg2 - sqlalchemy>=2.0 diff --git a/modyn/evaluator/internal/dataset/evaluation_dataset.py b/modyn/evaluator/internal/dataset/evaluation_dataset.py index b6c703491..b679f82c5 100644 --- a/modyn/evaluator/internal/dataset/evaluation_dataset.py +++ b/modyn/evaluator/internal/dataset/evaluation_dataset.py @@ -16,6 +16,7 @@ grpc_connection_established, instantiate_class, ) +from tenacity import Retrying, after_log, before_log, retry, stop_after_attempt, wait_random_exponential from torch.utils.data import IterableDataset, get_worker_info from torchvision import transforms @@ -78,6 +79,13 @@ def _setup_composed_transform(self) -> None: if len(self._transform_list) > 0: self._transform = transforms.Compose(self._transform_list) + @retry( + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=1, min=2, max=60), + before=before_log(logger, logging.ERROR), + after=after_log(logger, logging.ERROR), + reraise=True, + ) def _init_grpc(self) -> None: storage_channel = grpc.insecure_channel( self._storage_address, @@ -103,22 +111,64 @@ def _silence_pil() -> None: # pragma: no cover def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable[list[int]]: self._info("Getting keys from storage", worker_id) - req_keys = GetDataPerWorkerRequest( - dataset_id=self._dataset_id, - worker_id=worker_id, - total_workers=total_workers, - start_timestamp=self._start_timestamp, - end_timestamp=self._end_timestamp, - ) - resp_keys: GetDataPerWorkerResponse - for resp_keys in self._storagestub.GetDataPerWorker(req_keys): - yield resp_keys.keys - - def _get_data_from_storage(self, keys: list[int]) -> Iterable[list[tuple[int, bytes, int]]]: - request = GetRequest(dataset_id=self._dataset_id, keys=keys) - response: GetResponse - for response in self._storagestub.Get(request): - yield list(zip(response.keys, response.samples, response.labels)) + last_processed_index = -1 + for attempt in Retrying( + stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True + ): + with attempt: + try: + req_keys = GetDataPerWorkerRequest( + dataset_id=self._dataset_id, + worker_id=worker_id, + total_workers=total_workers, + start_timestamp=self._start_timestamp, + end_timestamp=self._end_timestamp, + ) + resp_keys: GetDataPerWorkerResponse + for index, resp_keys in enumerate(self._storagestub.GetDataPerWorker(req_keys)): + if index <= last_processed_index: + continue # Skip already processed responses + yield resp_keys.keys + last_processed_index = index + + except grpc.RpcError as e: + self._info( + "gRPC error occurred, last index = " + f"{last_processed_index}: {e.code()} - {e.details()}", + worker_id, + ) + self._info(f"Stringified exception: {str(e)}", worker_id) + self._info( + f"Error occured while asking {self._dataset_id} for worker data:\n{worker_id}", worker_id + ) + self._init_grpc() + raise e + + def _get_data_from_storage( + self, keys: list[int], worker_id: Optional[int] = None + ) -> Iterable[list[tuple[int, bytes, int]]]: + last_processed_index = -1 + for attempt in Retrying( + stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True + ): + with attempt: + try: + request = GetRequest(dataset_id=self._dataset_id, keys=keys) + response: GetResponse + for index, response in enumerate(self._storagestub.Get(request)): + if index <= last_processed_index: + continue # Skip already processed responses + yield list(zip(response.keys, response.samples, response.labels)) + last_processed_index = index + + except grpc.RpcError as e: # We catch and reraise to log and reconnect + self._info( + f"gRPC error occurred, last index = {last_processed_index}: {e.code()} - {e.details()}", + worker_id, + ) + self._info(f"Stringified exception: {str(e)}", worker_id) + self._info(f"Error occured while asking {self._dataset_id} for keys:\n{keys}", worker_id) + self._init_grpc() + raise e def __iter__(self) -> Generator: worker_info = get_worker_info() @@ -144,6 +194,6 @@ def __iter__(self) -> Generator: # TODO(#175): we might want to do/accelerate prefetching here. for keys in self._get_keys_from_storage(worker_id, total_workers): - for data in self._get_data_from_storage(keys): + for data in self._get_data_from_storage(keys, worker_id): for key, sample, label in data: yield key, self._transform(sample), label diff --git a/modyn/supervisor/internal/grpc_handler.py b/modyn/supervisor/internal/grpc_handler.py index 176acf514..d22743a35 100644 --- a/modyn/supervisor/internal/grpc_handler.py +++ b/modyn/supervisor/internal/grpc_handler.py @@ -47,6 +47,7 @@ from modyn.supervisor.internal.eval.result_writer import AbstractEvaluationResultWriter from modyn.supervisor.internal.utils import EvaluationStatusReporter from modyn.utils import grpc_common_config, grpc_connection_established +from tenacity import Retrying, stop_after_attempt, wait_random_exponential logger = logging.getLogger(__name__) @@ -270,6 +271,7 @@ def prepare_evaluation_request( return EvaluateModelRequest(**start_evaluation_kwargs) + # pylint: disable=too-many-branches def wait_for_evaluation_completion( self, training_id: int, evaluations: dict[int, EvaluationStatusReporter] ) -> None: @@ -292,7 +294,18 @@ def wait_for_evaluation_completion( current_evaluation_id = working_queue.popleft() current_evaluation_reporter = evaluations[current_evaluation_id] req = EvaluationStatusRequest(evaluation_id=current_evaluation_id) - res: EvaluationStatusResponse = self.evaluator.get_evaluation_status(req) + + for attempt in Retrying( + stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True + ): + with attempt: + try: + res: EvaluationStatusResponse = self.evaluator.get_evaluation_status(req) + except grpc.RpcError as e: # We catch and reraise to easily reconnect + logger.error(e) + logger.error(f"[Training {training_id}]: gRPC connection error, trying to reconnect.") + self.init_evaluator() + raise e if not res.valid: exception_msg = f"Evaluation {current_evaluation_id} is invalid at server:\n{res}\n" diff --git a/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py b/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py index ac702b681..96e647f04 100644 --- a/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py +++ b/modyn/supervisor/internal/pipeline_executor/evaluation_executor.py @@ -10,6 +10,7 @@ from multiprocessing import Queue from pathlib import Path +import grpc import pandas as pd from modyn.config.schema.pipeline import ModynPipelineConfig from modyn.config.schema.system import ModynConfig @@ -31,6 +32,7 @@ from modyn.supervisor.internal.utils.evaluation_status_reporter import EvaluationStatusReporter from modyn.utils.utils import current_time_micros, dynamic_module_import from pydantic import BaseModel +from tenacity import Retrying, stop_after_attempt, wait_random_exponential eval_strategy_module = dynamic_module_import("modyn.supervisor.internal.eval.strategies") @@ -56,13 +58,13 @@ def __init__( pipeline_logdir: Path, config: ModynConfig, pipeline: ModynPipelineConfig, - grpc: GRPCHandler, + grpc_handler: GRPCHandler, ): self.pipeline_id = pipeline_id self.pipeline_logdir = pipeline_logdir self.config = config self.pipeline = pipeline - self.grpc = grpc + self.grpc = grpc_handler self.context: AfterPipelineEvalContext | None = None self.eval_handlers = ( [EvalHandler(eval_handler_config) for eval_handler_config in pipeline.evaluation.handlers] @@ -272,7 +274,18 @@ def _single_evaluation(self, log: StageLog, eval_status_queue: Queue, eval_req: eval_req.interval_start, eval_req.interval_end, ) - response: EvaluateModelResponse = self.grpc.evaluator.evaluate_model(request) + for attempt in Retrying( + stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True + ): + with attempt: + try: + response: EvaluateModelResponse = self.grpc.evaluator.evaluate_model(request) + except grpc.RpcError as e: # We catch and reraise to reconnect + logger.error(e) + logger.error("gRPC connection error, trying to reconnect...") + self.grpc.init_evaluator() + raise e + if not response.evaluation_started: log.info.failure_reason = EvaluationAbortedReason.DESCRIPTOR.values_by_number[ response.eval_aborted_reason diff --git a/modyn/tests/supervisor/internal/pipeline_executor/test_evaluation_executor.py b/modyn/tests/supervisor/internal/pipeline_executor/test_evaluation_executor.py index 3aadd1f97..fec5c9a28 100644 --- a/modyn/tests/supervisor/internal/pipeline_executor/test_evaluation_executor.py +++ b/modyn/tests/supervisor/internal/pipeline_executor/test_evaluation_executor.py @@ -61,7 +61,7 @@ def evaluation_executor( pipeline_logdir=tmp_dir_tests, config=dummy_system_config, pipeline=pipeline_config, - grpc=GRPCHandler(eval_state_config.config.model_dump(by_alias=True)), + grpc_handler=GRPCHandler(eval_state_config.config.model_dump(by_alias=True)), ) diff --git a/modyn/tests/trainer_server/internal/data/key_sources/test_selector_key_source.py b/modyn/tests/trainer_server/internal/data/key_sources/test_selector_key_source.py index 4a77aad57..9f2b6695c 100644 --- a/modyn/tests/trainer_server/internal/data/key_sources/test_selector_key_source.py +++ b/modyn/tests/trainer_server/internal/data/key_sources/test_selector_key_source.py @@ -1,5 +1,5 @@ # pylint: disable=unused-argument, no-name-in-module -from unittest.mock import patch +from unittest.mock import MagicMock, patch import grpc import pytest @@ -10,6 +10,7 @@ UsesWeightsResponse, ) from modyn.trainer_server.internal.dataset.key_sources import SelectorKeySource +from tenacity import RetryCallState def test_init(): @@ -138,3 +139,28 @@ def test_unweighted_key_source(test_grp, test_connection): keys, weights = keysource.get_keys_and_weights(0, 0) assert weights == [-1.0, -2.0, -3.0] assert keys == [1, 2, 3] + + +def test_retry_reconnection_callback(): + pipeline_id = 12 + trigger_id = 1 + selector_address = "localhost:1234" + keysource = SelectorKeySource(pipeline_id, trigger_id, selector_address) + + # Create a mock RetryCallState + mock_retry_state = MagicMock(spec=RetryCallState) + mock_retry_state.attempt_number = 3 + mock_retry_state.outcome = MagicMock() + mock_retry_state.outcome.failed = True + mock_retry_state.args = [keysource] + + # Mock the _connect_to_selector method to raise an exception + with patch.object( + keysource, "_connect_to_selector", side_effect=ConnectionError("Connection failed") + ) as mock_method: + # Call the retry_reconnection_callback with the mock state + with pytest.raises(ConnectionError): + SelectorKeySource.retry_reconnection_callback(mock_retry_state) + + # Check that the method tried to reconnect + mock_method.assert_called() diff --git a/modyn/tests/trainer_server/internal/data/test_online_dataset.py b/modyn/tests/trainer_server/internal/data/test_online_dataset.py index 904acfaad..a710e5034 100644 --- a/modyn/tests/trainer_server/internal/data/test_online_dataset.py +++ b/modyn/tests/trainer_server/internal/data/test_online_dataset.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument, no-name-in-module, too-many-locals import platform -from unittest.mock import patch +from unittest.mock import MagicMock, patch import grpc import pytest @@ -218,6 +218,57 @@ def test_get_data_from_storage( assert set(result_labels) == set(labels) +class MockRpcError(grpc.RpcError): + def code(self): + return grpc.StatusCode.UNAVAILABLE + + def details(self): + return "Mocked gRPC error for testing retry logic." + + +@patch("modyn.trainer_server.internal.dataset.key_sources.selector_key_source.SelectorStub", MockSelectorStub) +@patch("modyn.trainer_server.internal.dataset.online_dataset.StorageStub") +@patch( + "modyn.trainer_server.internal.dataset.key_sources.selector_key_source.grpc_connection_established", + return_value=True, +) +@patch("modyn.trainer_server.internal.dataset.online_dataset.grpc_connection_established", return_value=True) +@patch.object(grpc, "insecure_channel", return_value=None) +def test_get_data_from_storage_with_retry( + test_insecure_channel, + test_grpc_connection_established, + test_grpc_connection_established_selector, + mock_storage_stub, +): + # Arrange + online_dataset = OnlineDataset( + pipeline_id=1, + trigger_id=1, + dataset_id="MNIST", + bytes_parser=get_mock_bytes_parser(), + serialized_transforms=[], + storage_address="localhost:1234", + selector_address="localhost:1234", + training_id=42, + tokenizer=None, + log_path=None, + shuffle=False, + num_prefetched_partitions=0, + parallel_prefetch_requests=1, + ) + online_dataset._init_grpc = MagicMock() # cannot patch this with annotations due to tenacity + mock_storage_stub.Get.side_effect = [MockRpcError(), MockRpcError(), MagicMock()] + online_dataset._storagestub = mock_storage_stub + + try: + for _ in online_dataset._get_data_from_storage(list(range(10))): + pass + except Exception as e: + assert isinstance(e, RuntimeError), "Expected a RuntimeError after max retries." + + assert mock_storage_stub.Get.call_count == 3, "StorageStub.Get should have been retried twice before succeeding." + + @patch("modyn.trainer_server.internal.dataset.key_sources.selector_key_source.SelectorStub", MockSelectorStub) @patch("modyn.trainer_server.internal.dataset.online_dataset.StorageStub", MockStorageStub) @patch( diff --git a/modyn/trainer_server/internal/dataset/key_sources/local_key_source.py b/modyn/trainer_server/internal/dataset/key_sources/local_key_source.py index a62d0dae8..88ce34c0c 100644 --- a/modyn/trainer_server/internal/dataset/key_sources/local_key_source.py +++ b/modyn/trainer_server/internal/dataset/key_sources/local_key_source.py @@ -33,11 +33,11 @@ def get_num_data_partitions(self) -> int: def end_of_trigger_cleaning(self) -> None: self._trigger_sample_storage.clean_trigger_data(self._pipeline_id, self._trigger_id) - def __getstate__(self): + def __getstate__(self) -> dict: state = self.__dict__.copy() del state["_trigger_sample_storage"] # not pickable return state - def __setstate__(self, state): + def __setstate__(self, state: dict) -> None: self.__dict__.update(state) self._trigger_sample_storage = TriggerSampleStorage(self.offline_dataset_path) diff --git a/modyn/trainer_server/internal/dataset/key_sources/selector_key_source.py b/modyn/trainer_server/internal/dataset/key_sources/selector_key_source.py index 8cdb99470..fdad1acf6 100644 --- a/modyn/trainer_server/internal/dataset/key_sources/selector_key_source.py +++ b/modyn/trainer_server/internal/dataset/key_sources/selector_key_source.py @@ -1,6 +1,5 @@ import logging -import time -from typing import Optional +from typing import Any, Optional import grpc @@ -15,6 +14,7 @@ from modyn.selector.internal.grpc.generated.selector_pb2_grpc import SelectorStub from modyn.trainer_server.internal.dataset.key_sources import AbstractKeySource from modyn.utils import MAX_MESSAGE_SIZE, flatten, grpc_connection_established +from tenacity import after_log, before_log, retry, stop_after_attempt, wait_random_exponential logger = logging.getLogger(__name__) @@ -27,6 +27,20 @@ def __init__(self, pipeline_id: int, trigger_id: int, selector_address: str) -> self._selectorstub = None # connection is made when the pytorch worker is started self._uses_weights: Optional[bool] = None # get via gRPC, so unavailable if the connection is not yet made. + @staticmethod + def retry_reconnection_callback(retry_state: Any) -> None: + self: SelectorKeySource = retry_state.args[0] + logger.error(f"Retry attempt {retry_state.attempt_number}. State = \n {str(retry_state)}") + self._connect_to_selector() + + @retry( + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=1, min=2, max=60), + before=before_log(logger, logging.ERROR), + after=after_log(logger, logging.ERROR), + reraise=True, + retry_error_callback=retry_reconnection_callback, + ) def get_keys_and_weights(self, worker_id: int, partition_id: int) -> tuple[list[int], Optional[list[float]]]: assert self._selectorstub is not None assert self._uses_weights is not None @@ -90,26 +104,24 @@ def init_worker(self) -> None: self._selectorstub = self._connect_to_selector() self._uses_weights = self.uses_weights() + @retry( + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=1, min=2, max=60), + before=before_log(logger, logging.ERROR), + after=after_log(logger, logging.ERROR), + reraise=True, + ) def _connect_to_selector(self) -> SelectorStub: # pragma: no cover - max_retries = 5 - retry_delay = 1 # seconds - - for attempt in range(1, max_retries + 1): - selector_channel = grpc.insecure_channel( - self._selector_address, - options=[ - ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), - ], - ) - if grpc_connection_established(selector_channel): - return SelectorStub(selector_channel) - - logger.info(f"gRPC connection attempt {attempt} failed. Retrying in {retry_delay} seconds...") - time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff - - logger.error(f"Failed to establish gRPC connection after {max_retries} attempts.") + selector_channel = grpc.insecure_channel( + self._selector_address, + options=[ + ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), + ], + ) + if grpc_connection_established(selector_channel): + return SelectorStub(selector_channel) + raise ConnectionError(f"Could not establish gRPC connection to selector at address {self._selector_address}.") def end_of_trigger_cleaning(self) -> None: diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py index 6784fd86d..3196eb3ea 100644 --- a/modyn/trainer_server/internal/dataset/online_dataset.py +++ b/modyn/trainer_server/internal/dataset/online_dataset.py @@ -8,7 +8,6 @@ import pathlib import random import threading -import time from typing import Any, Callable, Generator, Iterator, Optional, Tuple, cast import grpc @@ -26,6 +25,7 @@ grpc_connection_established, instantiate_class, ) +from tenacity import Retrying, after_log, before_log, retry, stop_after_attempt, wait_random_exponential from torch.utils.data import IterableDataset, get_worker_info from torchvision import transforms @@ -126,23 +126,22 @@ def _init_transforms(self) -> None: self._transform = self._bytes_parser_function self._setup_composed_transform() + @retry( + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=1, min=2, max=60), + before=before_log(logger, logging.ERROR), + after=after_log(logger, logging.ERROR), + reraise=True, + ) def _init_grpc(self, worker_id: Optional[int] = None) -> None: # pragma: no cover - max_retries = 5 - retry_delay = 1 # seconds - - for attempt in range(1, max_retries + 1): - self._storage_channel = grpc.insecure_channel(self._storage_address, options=grpc_common_config()) - if grpc_connection_established(self._storage_channel): - self._storagestub = StorageStub(self._storage_channel) - return - # no connection established - - self._info(f"gRPC connection attempt {attempt} failed. Retrying in {retry_delay} seconds...", worker_id) - time.sleep(retry_delay) - retry_delay *= 2 # Exponential backoff + self._storage_channel = grpc.insecure_channel(self._storage_address, options=grpc_common_config()) + if grpc_connection_established(self._storage_channel): + self._storagestub = StorageStub(self._storage_channel) + return - self._info(f"Failed to establish gRPC connection after {max_retries} attempts.", worker_id) - raise ConnectionError(f"Could not establish gRPC connection to storage at address {self._storage_address}.") + raise ConnectionError( + f"[Worker {worker_id}]: Could not establish gRPC connection to storage at address {self._storage_address}." + ) def _silence_pil(self) -> None: # pragma: no cover pil_logger = logging.getLogger("PIL") @@ -156,20 +155,40 @@ def _debug(self, msg: str, worker_id: Optional[int]) -> None: # pragma: no cove def _get_data_from_storage( self, selector_keys: list[int], worker_id: Optional[int] = None - ) -> Iterator[tuple[list[int], list[bytes], list[int], int]]: - req = GetRequest(dataset_id=self._dataset_id, keys=selector_keys) - stopw = Stopwatch() - - response: GetResponse - stopw.start("ResponseTime", overwrite=True) - for _, response in enumerate(self._storagestub.Get(req)): - yield list(response.keys), list(response.samples), list(response.labels), stopw.stop("ResponseTime") - if not grpc_connection_established(self._storage_channel): - self._info("gRPC connection lost, trying to reconnect!", worker_id) - self._init_grpc(worker_id=worker_id) - stopw.start("ResponseTime", overwrite=True) + ) -> Iterator[Tuple[list[int], list[bytes], list[int], int]]: + last_processed_index = -1 + + for attempt in Retrying( + stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True + ): + with attempt: + try: + req = GetRequest(dataset_id=self._dataset_id, keys=selector_keys) + stopw = Stopwatch() + + response: GetResponse + stopw.start("ResponseTime", overwrite=True) + for index, response in enumerate(self._storagestub.Get(req)): + if index <= last_processed_index: + continue # Skip already processed responses + yield list(response.keys), list(response.samples), list(response.labels), stopw.stop( + "ResponseTime" + ) + last_processed_index = index # Update the last processed index + stopw.start("ResponseTime", overwrite=True) + + except grpc.RpcError as e: # We catch and reraise to reconnect to rpc and do logging + self._info( + "gRPC error occurred, last index = " + f"{last_processed_index}: {e.code()} - {e.details()}", + worker_id, + ) + self._info(f"Stringified exception: {str(e)}", worker_id) + self._info(f"Error occured while asking {self._dataset_id} for keys:\n{selector_keys}", worker_id) + self._init_grpc(worker_id=worker_id) + raise e # pylint: disable=too-many-locals + def _get_data( self, data_container: dict, diff --git a/modyn/trainer_server/internal/trainer/pytorch_trainer.py b/modyn/trainer_server/internal/trainer/pytorch_trainer.py index 581ad4bd8..a388f0196 100644 --- a/modyn/trainer_server/internal/trainer/pytorch_trainer.py +++ b/modyn/trainer_server/internal/trainer/pytorch_trainer.py @@ -736,6 +736,7 @@ def _setup_downsampling( strategy_name, downsampler_config, modyn_config, self._criterion_nored ) assert "sample_then_batch" in downsampler_config + self._log["received_downsampler_config"] = downsampler_config if downsampler_config["sample_then_batch"]: self._downsampling_mode = DownsamplingMode.SAMPLE_THEN_BATCH assert "downsampling_period" in downsampler_config From 601c2e1e995a1950a0456dae1dce919b2bb969c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20B=C3=B6ther?= <2116466+MaxiBoether@users.noreply.github.com> Date: Wed, 19 Jun 2024 11:16:28 +0200 Subject: [PATCH 2/3] Multithreaded error handling in storage (#524) If a thread in the get implementation throws an error, the storage currently dies because C++ does not propagate exceptions to the caller. With this change, we're able to return a gRPC error instead. Also, we throw now in case the file path is empty. --- modyn/evaluator/internal/pytorch_evaluator.py | 5 +- .../internal/grpc/storage_service_impl.hpp | 52 ++++++++++++++++--- .../internal/dataset/data_utils.py | 6 ++- 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/modyn/evaluator/internal/pytorch_evaluator.py b/modyn/evaluator/internal/pytorch_evaluator.py index aca6b3b2c..f6b21649c 100644 --- a/modyn/evaluator/internal/pytorch_evaluator.py +++ b/modyn/evaluator/internal/pytorch_evaluator.py @@ -70,7 +70,10 @@ def _prepare_dataloader(self, evaluation_info: EvaluationInfo) -> torch.utils.da ) self._debug("Creating DataLoader.") dataloader = torch.utils.data.DataLoader( - dataset, batch_size=evaluation_info.batch_size, num_workers=evaluation_info.num_dataloaders, timeout=60 + dataset, + batch_size=evaluation_info.batch_size, + num_workers=evaluation_info.num_dataloaders, + timeout=60 if evaluation_info.num_dataloaders > 0 else 0, ) return dataloader diff --git a/modyn/storage/include/internal/grpc/storage_service_impl.hpp b/modyn/storage/include/internal/grpc/storage_service_impl.hpp index 203474537..b4c680434 100644 --- a/modyn/storage/include/internal/grpc/storage_service_impl.hpp +++ b/modyn/storage/include/internal/grpc/storage_service_impl.hpp @@ -5,7 +5,11 @@ #include #include +#include +#include +#include #include +#include #include #include #include @@ -319,6 +323,8 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { get_samples_and_send(begin, end, writer, &writer_mutex, &dataset_data, &config_, sample_batch_size_); } else { + std::vector thread_exceptions(retrieval_threads_); + std::mutex exception_mutex; std::vector::const_iterator, std::vector::const_iterator>> its_per_thread = get_keys_per_thread(request_keys, retrieval_threads_); std::vector retrieval_threads_vector(retrieval_threads_); @@ -326,9 +332,18 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { const std::vector::const_iterator begin = its_per_thread[thread_id].first; const std::vector::const_iterator end = its_per_thread[thread_id].second; - retrieval_threads_vector[thread_id] = - std::thread(StorageServiceImpl::get_samples_and_send, begin, end, writer, &writer_mutex, - &dataset_data, &config_, sample_batch_size_); + retrieval_threads_vector[thread_id] = std::thread([thread_id, begin, end, writer, &writer_mutex, &dataset_data, + &thread_exceptions, &exception_mutex, this]() { + try { + get_samples_and_send(begin, end, writer, &writer_mutex, &dataset_data, &config_, + sample_batch_size_); + } catch (const std::exception& e) { + const std::lock_guard lock(exception_mutex); + spdlog::error( + fmt::format("Error in thread {} started by send_sample_data_from_keys: {}", thread_id, e.what())); + thread_exceptions[thread_id] = std::current_exception(); + } + }); } for (uint64_t thread_id = 0; thread_id < retrieval_threads_; ++thread_id) { @@ -337,6 +352,17 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { } } retrieval_threads_vector.clear(); + // In order for the gRPC call to return an error, we need to rethrow the threaded exceptions. + for (auto& e_ptr : thread_exceptions) { + if (e_ptr) { + try { + std::rethrow_exception(e_ptr); + } catch (const std::exception& e) { + SPDLOG_ERROR("Error while unwinding thread: {}\nPropagating it up the call chain.", e.what()); + throw; + } + } + } } } @@ -529,6 +555,12 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { // keys than this try { const uint64_t num_keys = sample_keys.size(); + + if (num_keys == 0) { + SPDLOG_ERROR("num_keys is 0, this should not have happened. Exiting send_sample_data_for_keys_and_file"); + return; + } + std::vector sample_labels(num_keys); std::vector sample_indices(num_keys); std::vector sample_fileids(num_keys); @@ -539,15 +571,16 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { session << sample_query, soci::into(sample_labels), soci::into(sample_indices), soci::into(sample_fileids), soci::use(dataset_data.dataset_id); - int64_t current_file_id = sample_fileids[0]; + int64_t current_file_id = sample_fileids.at(0); uint64_t current_file_start_idx = 0; std::string current_file_path; session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); - if (current_file_path.empty()) { - SPDLOG_ERROR(fmt::format("Could not obtain full path of file id {} in dataset {}", current_file_id, - dataset_data.dataset_id)); + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); } const YAML::Node file_wrapper_config_node = YAML::Load(dataset_data.file_wrapper_config); auto filesystem_wrapper = @@ -594,6 +627,11 @@ class StorageServiceImpl final : public modyn::storage::Storage::Service { current_file_path = "", session << "SELECT path FROM files WHERE file_id = :file_id AND dataset_id = :dataset_id", soci::into(current_file_path), soci::use(current_file_id), soci::use(dataset_data.dataset_id); + if (current_file_path.empty() || current_file_path.find_first_not_of(' ') == std::string::npos) { + SPDLOG_ERROR(fmt::format("Sample query is {}", sample_query)); + throw modyn::utils::ModynException(fmt::format("Could not obtain full path of file id {} in dataset {}", + current_file_id, dataset_data.dataset_id)); + } file_wrapper->set_file_path(current_file_path); current_file_start_idx = sample_idx; } diff --git a/modyn/trainer_server/internal/dataset/data_utils.py b/modyn/trainer_server/internal/dataset/data_utils.py index e2445ef5f..1b79f601c 100644 --- a/modyn/trainer_server/internal/dataset/data_utils.py +++ b/modyn/trainer_server/internal/dataset/data_utils.py @@ -67,7 +67,11 @@ def prepare_dataloaders( ) logger.debug("Creating DataLoader.") train_dataloader = torch.utils.data.DataLoader( - train_set, batch_size=batch_size, num_workers=num_dataloaders, drop_last=drop_last, timeout=60 + train_set, + batch_size=batch_size, + num_workers=num_dataloaders, + drop_last=drop_last, + timeout=60 if num_dataloaders > 0 else 0, ) # TODO(#50): what to do with the val set in the general case? From 29018c7e2a3b5a9c16b04f4874f27b2334a2b877 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20B=C3=B6ther?= <2116466+MaxiBoether@users.noreply.github.com> Date: Wed, 19 Jun 2024 11:50:03 +0200 Subject: [PATCH 3/3] Fix device for initial RHO-LOSS tensor (#525) The initial tensor resides on CPU. For me, it failed when running RHO-LOSS, since the very first ` self.rho_loss = torch.cat([self.rho_loss, training_loss - irreducible_loss]).to(training_loss.dtype)` was on two different devices, the CPU (initial tensor) and the GPU (training loss and IR loss). This PR fixes that by moving the initial tensor to the correct device. --- .../remote_downsamplers/remote_rho_loss_downsampling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_rho_loss_downsampling.py b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_rho_loss_downsampling.py index 4b707f5a4..c1ccb84f2 100644 --- a/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_rho_loss_downsampling.py +++ b/modyn/trainer_server/internal/trainer/remote_downsamplers/remote_rho_loss_downsampling.py @@ -36,10 +36,11 @@ def __init__( ) self.rho_loss: torch.Tensor = torch.tensor([]) self.number_of_points_seen = 0 + self._device = device def init_downsampler(self) -> None: self.index_sampleid_map: list[int] = [] - self.rho_loss = torch.tensor([]) + self.rho_loss = torch.tensor([]).to(self._device) self.number_of_points_seen = 0 def inform_samples(