From c0b0ae8a82267eca6bca8ea0a8d45ab7023c5442 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maximilian=20B=C3=B6ther?= <2116466+MaxiBoether@users.noreply.github.com>
Date: Thu, 20 Jun 2024 18:32:53 +0200
Subject: [PATCH] Fix recovery in datasets (#527)

Before, we had recovery logic based on reply indices. However, while
working on storage, I realized that those responses arrive
non-deterministically from multiple threads. Hence, we cannot rely on
the ordering. We need to keep track of the sample IDs we already
yielded. I changed the logic to keep a plain list, which is cheap to
append to, and to convert it to a set / hash table only once we have
failed and actually need to do many `in` checks.
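For reference, here is the pattern in isolation (a simplified,
self-contained sketch, not Modyn code; `stream_keys` and
`TransientError` merely stand in for the storage gRPC stream and
`grpc.RpcError`):

    import random


    class TransientError(Exception):
        """Stand-in for grpc.RpcError in this sketch."""


    def stream_keys(keys: list[int], fail_at: int | None):
        """Yield keys in batches of two, in shuffled order (replies are
        non-deterministic); abort after `fail_at` batches if requested."""
        shuffled = random.sample(keys, k=len(keys))
        for batch_no, i in enumerate(range(0, len(shuffled), 2)):
            if batch_no == fail_at:
                raise TransientError("stream dropped")
            yield shuffled[i : i + 2]


    def get_keys_with_recovery(keys: list[int]):
        processed: set[int] | list[int] = []  # plain list while nothing failed
        has_failed = False
        for _ in range(5):  # bounded retries, as with tenacity's Retrying
            try:
                # Simulate one mid-stream failure on the very first attempt.
                for batch in stream_keys(keys, fail_at=None if has_failed else 2):
                    if not has_failed:
                        assert isinstance(processed, list)
                        processed.extend(batch)  # cheap append on the happy path
                        yield batch
                    else:
                        # Ordering is not reproducible, so drop IDs already yielded.
                        assert isinstance(processed, set)
                        new = [key for key in batch if key not in processed]
                        processed.update(batch)
                        yield new
                return
            except TransientError:
                has_failed = True
                processed = set(processed)  # convert once many `in` checks are needed


    yielded = [key for batch in get_keys_with_recovery(list(range(10))) for key in batch]
    assert sorted(yielded) == list(range(10))  # each key exactly once, despite the failure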
---
 .../internal/dataset/evaluation_dataset.py   | 55 ++++++++++++++-----
 .../pipeline_executor/pipeline_executor.py   |  1 -
 .../internal/dataset/online_dataset.py       | 35 +++++++++---
 3 files changed, 67 insertions(+), 24 deletions(-)

diff --git a/modyn/evaluator/internal/dataset/evaluation_dataset.py b/modyn/evaluator/internal/dataset/evaluation_dataset.py
index 1517ebbf5..c33eb5f05 100644
--- a/modyn/evaluator/internal/dataset/evaluation_dataset.py
+++ b/modyn/evaluator/internal/dataset/evaluation_dataset.py
@@ -111,7 +111,8 @@ def _silence_pil() -> None:  # pragma: no cover
 
     def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable[list[int]]:
         self._info("Getting keys from storage", worker_id)
-        last_processed_index = -1
+        processed_keys: set[int] | list[int] = []
+        has_failed = False
         for attempt in Retrying(
             stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
         ):
@@ -125,15 +126,24 @@ def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable
                         end_timestamp=self._end_timestamp,
                     )
                     resp_keys: GetDataPerWorkerResponse
-                    for index, resp_keys in enumerate(self._storagestub.GetDataPerWorker(req_keys)):
-                        if index <= last_processed_index:
-                            continue  # Skip already processed responses
-                        yield resp_keys.keys
-                        last_processed_index = index
+                    for resp_keys in self._storagestub.GetDataPerWorker(req_keys):
+                        if not has_failed:
+                            assert isinstance(processed_keys, list)
+                            processed_keys.extend(resp_keys.keys)
+                            yield resp_keys.keys
+                        else:
+                            assert isinstance(processed_keys, set)
+                            new_keys = [key for key in resp_keys.keys if key not in processed_keys]
+                            processed_keys.update(resp_keys.keys)
+                            yield new_keys
 
                 except grpc.RpcError as e:
+                    has_failed = True
+                    # Convert processed keys to set on first failure
+                    processed_keys = set(processed_keys) if isinstance(processed_keys, list) else processed_keys
+
                     self._info(
-                        "gRPC error occurred, last index = " + f"{last_processed_index}: {e.code()} - {e.details()}",
+                        "gRPC error occurred, processed_keys = " + f"{processed_keys}\n{e.code()} - {e.details()}",
                         worker_id,
                     )
                     self._info(f"Stringified exception: {str(e)}", worker_id)
@@ -146,7 +156,8 @@ def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable
     def _get_data_from_storage(
         self, keys: list[int], worker_id: Optional[int] = None
     ) -> Iterable[list[tuple[int, bytes, int]]]:
-        last_processed_index = -1
+        processed_keys: set[int] | list[int] = []
+        has_failed = False
         for attempt in Retrying(
             stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
         ):
@@ -154,15 +165,31 @@ def _get_data_from_storage(
                 try:
                     request = GetRequest(dataset_id=self._dataset_id, keys=keys)
                     response: GetResponse
-                    for index, response in enumerate(self._storagestub.Get(request)):
-                        if index <= last_processed_index:
-                            continue  # Skip already processed responses
-                        yield list(zip(response.keys, response.samples, response.labels))
-                        last_processed_index = index
+                    for response in self._storagestub.Get(request):
+                        if not has_failed:
+                            assert isinstance(processed_keys, list)
+                            processed_keys.extend(response.keys)
+                            yield list(zip(response.keys, response.samples, response.labels))
+                        else:
+                            assert isinstance(processed_keys, set)
+                            new_keys: list[int] = [key for key in response.keys if key not in processed_keys]
+                            new_samples: list[bytes] = [
+                                sample
+                                for key, sample in zip(response.keys, response.samples)
+                                if key not in processed_keys
+                            ]
+                            new_labels: list[int] = [
+                                label for key, label in zip(response.keys, response.labels) if key not in processed_keys
+                            ]
+                            processed_keys.update(response.keys)
+                            yield list(zip(new_keys, new_samples, new_labels))
 
                 except grpc.RpcError as e:  # We catch and reraise to log and reconnect
+                    has_failed = True
+                    # Convert processed keys to set on first failure
+                    processed_keys = set(processed_keys) if isinstance(processed_keys, list) else processed_keys
                     self._info(
-                        f"gRPC error occurred, last index = {last_processed_index}: {e.code()} - {e.details()}",
+                        "gRPC error occurred, processed_keys = " + f"{processed_keys}\n{e.code()} - {e.details()}",
                         worker_id,
                     )
                     self._info(f"Stringified exception: {str(e)}", worker_id)
diff --git a/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py b/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py
index 0e8a024ff..52412502b 100644
--- a/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py
+++ b/modyn/supervisor/internal/pipeline_executor/pipeline_executor.py
@@ -477,7 +477,6 @@ def _handle_triggers(
         Returns:
             The list of the actually processed triggers
         """
-        logger.info(f"Processing {len(s.triggers)} triggers in this batch.")
         s.pipeline_status_queue.put(pipeline_stage_msg(PipelineStage.HANDLE_TRIGGERS, MsgType.GENERAL))
 
         previous_trigger_index = 0
diff --git a/modyn/trainer_server/internal/dataset/online_dataset.py b/modyn/trainer_server/internal/dataset/online_dataset.py
index 9d8536ee6..932c9ae55 100644
--- a/modyn/trainer_server/internal/dataset/online_dataset.py
+++ b/modyn/trainer_server/internal/dataset/online_dataset.py
@@ -156,7 +156,8 @@ def _debug(self, msg: str, worker_id: Optional[int]) -> None:  # pragma: no cove
     def _get_data_from_storage(
         self, selector_keys: list[int], worker_id: Optional[int] = None
     ) -> Iterator[Tuple[list[int], list[bytes], list[int], int]]:
-        last_processed_index = -1
+        processed_keys: set[int] | list[int] = []
+        has_failed = False
         for attempt in Retrying(
             stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
         ):
@@ -168,18 +169,34 @@ def _get_data_from_storage(
 
                     response: GetResponse
                     stopw.start("ResponseTime", overwrite=True)
 
-                    for index, response in enumerate(self._storagestub.Get(req)):
-                        if index <= last_processed_index:
-                            continue  # Skip already processed responses
-                        yield list(response.keys), list(response.samples), list(response.labels), stopw.stop(
-                            "ResponseTime"
-                        )
-                        last_processed_index = index  # Update the last processed index
+                    for response in self._storagestub.Get(req):
+                        response_time = stopw.stop("ResponseTime")
+                        keys = list(response.keys)
+                        if not has_failed:
+                            assert isinstance(processed_keys, list)
+                            processed_keys.extend(keys)
+                            yield keys, list(response.samples), list(response.labels), response_time
+                        else:  # If we have failed, we need to filter out yielded samples
+                            # Note that the returned order by storage is non-deterministic
+                            assert isinstance(processed_keys, set)
+                            new_keys: list[int] = [key for key in keys if key not in processed_keys]
+                            new_samples: list[bytes] = [
+                                sample for key, sample in zip(keys, response.samples) if key not in processed_keys
+                            ]
+                            new_labels: list[int] = [
+                                label for key, label in zip(keys, response.labels) if key not in processed_keys
+                            ]
+                            processed_keys.update(keys)
+                            yield new_keys, new_samples, new_labels, response_time
+
+                        stopw.start("ResponseTime", overwrite=True)
 
                 except grpc.RpcError as e:  # We catch and reraise to reconnect to rpc and do logging
+                    has_failed = True
+                    # Convert processed keys to set on first failure
+                    processed_keys = set(processed_keys) if isinstance(processed_keys, list) else processed_keys
                     self._info(
-                        "gRPC error occurred, last index = " + f"{last_processed_index}: {e.code()} - {e.details()}",
+                        "gRPC error occurred, processed_keys = " + f"{processed_keys}\n{e.code()} - {e.details()}",
                         worker_id,
                     )
                     self._info(f"Stringified exception: {str(e)}", worker_id)