Fix recovery in datasets (#527)
Before, we had recovery logic based on reply indices. However, while
working on storage, I realized that those responses arrive
non-deterministically from multiple threads. Hence, we cannot rely on
their ordering and instead need to keep track of the sample IDs we have
already yielded. I changed the logic to keep a plain list, which is
cheap to append to, and to only convert it to a set / hash table once we
have failed for the first time and actually need to perform many `in` checks.
MaxiBoether authored Jun 20, 2024
1 parent cb0be37 commit c0b0ae8
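To illustrate the idea before the diff: a minimal, self-contained sketch of the pattern with hypothetical function and parameter names (not the Modyn API). Keys are appended to a plain list while the stream succeeds; only after the first failure is the list promoted to a set, so retried responses can be filtered with fast `in` checks.

from collections.abc import Callable, Iterable, Iterator

def fetch_with_recovery(
    fetch_responses: Callable[[], Iterable[list[int]]], max_attempts: int = 5
) -> Iterator[list[int]]:
    # Cheap list appends on the happy path; no membership checks needed yet.
    processed_keys: set[int] | list[int] = []
    has_failed = False
    for attempt in range(max_attempts):
        try:
            for keys in fetch_responses():
                if not has_failed:
                    processed_keys.extend(keys)
                    yield keys
                else:
                    # After a failure, filter out keys that were already yielded,
                    # since the retried stream may return them in a different order.
                    new_keys = [key for key in keys if key not in processed_keys]
                    processed_keys.update(keys)
                    yield new_keys
            return  # all responses consumed without another error
        except ConnectionError:  # stand-in for grpc.RpcError in the real code
            has_failed = True
            if isinstance(processed_keys, list):
                processed_keys = set(processed_keys)  # one-time conversion
            if attempt == max_attempts - 1:
                raise

The actual changes below apply this pattern to the gRPC streaming calls in the evaluator and trainer-server datasets, using tenacity's Retrying for the retry loop.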
Showing 3 changed files with 67 additions and 24 deletions.
55 changes: 41 additions & 14 deletions modyn/evaluator/internal/dataset/evaluation_dataset.py
@@ -111,7 +111,8 @@ def _silence_pil() -> None: # pragma: no cover

def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable[list[int]]:
self._info("Getting keys from storage", worker_id)
last_processed_index = -1
processed_keys: set[int] | list[int] = []
has_failed = False
for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
):
@@ -125,15 +126,24 @@ def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable
end_timestamp=self._end_timestamp,
)
resp_keys: GetDataPerWorkerResponse
for index, resp_keys in enumerate(self._storagestub.GetDataPerWorker(req_keys)):
if index <= last_processed_index:
continue # Skip already processed responses
yield resp_keys.keys
last_processed_index = index
for resp_keys in self._storagestub.GetDataPerWorker(req_keys):
if not has_failed:
assert isinstance(processed_keys, list)
processed_keys.extend(resp_keys.keys)
yield resp_keys.keys
else:
assert isinstance(processed_keys, set)
new_keys = [key for key in resp_keys.keys if key not in processed_keys]
processed_keys.update(resp_keys.keys)
yield new_keys

except grpc.RpcError as e:
has_failed = True
# Convert processed keys to set on first failure
processed_keys = set(processed_keys) if isinstance(processed_keys, list) else processed_keys

self._info(
"gRPC error occurred, last index = " + f"{last_processed_index}: {e.code()} - {e.details()}",
"gRPC error occurred, processed_keys = " + f"{processed_keys}\n{e.code()} - {e.details()}",
worker_id,
)
self._info(f"Stringified exception: {str(e)}", worker_id)
@@ -146,23 +156,40 @@ def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable
def _get_data_from_storage(
self, keys: list[int], worker_id: Optional[int] = None
) -> Iterable[list[tuple[int, bytes, int]]]:
last_processed_index = -1
processed_keys: set[int] | list[int] = []
has_failed = False
for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
):
with attempt:
try:
request = GetRequest(dataset_id=self._dataset_id, keys=keys)
response: GetResponse
for index, response in enumerate(self._storagestub.Get(request)):
if index <= last_processed_index:
continue # Skip already processed responses
yield list(zip(response.keys, response.samples, response.labels))
last_processed_index = index
for response in self._storagestub.Get(request):
if not has_failed:
assert isinstance(processed_keys, list)
processed_keys.extend(response.keys)
yield list(zip(response.keys, response.samples, response.labels))
else:
assert isinstance(processed_keys, set)
new_keys: list[int] = [key for key in response.keys if key not in processed_keys]
new_samples: list[bytes] = [
sample
for key, sample in zip(response.keys, response.samples)
if key not in processed_keys
]
new_labels: list[int] = [
label for key, label in zip(response.keys, response.labels) if key not in processed_keys
]
processed_keys.update(keys)
yield list(zip(new_keys, new_samples, new_labels))

except grpc.RpcError as e: # We catch and reraise to log and reconnect
has_failed = True
# Convert processed keys to set on first failure
processed_keys = set(processed_keys) if isinstance(processed_keys, list) else processed_keys
self._info(
f"gRPC error occurred, last index = {last_processed_index}: {e.code()} - {e.details()}",
"gRPC error occurred, processed_keys = " + f"{processed_keys}\n{e.code()} - {e.details()}",
worker_id,
)
self._info(f"Stringified exception: {str(e)}", worker_id)
@@ -477,7 +477,6 @@ def _handle_triggers(
Returns:
The list of the actually processed triggers
"""
logger.info(f"Processing {len(s.triggers)} triggers in this batch.")
s.pipeline_status_queue.put(pipeline_stage_msg(PipelineStage.HANDLE_TRIGGERS, MsgType.GENERAL))

previous_trigger_index = 0
35 changes: 26 additions & 9 deletions modyn/trainer_server/internal/dataset/online_dataset.py
@@ -156,7 +156,8 @@ def _debug(self, msg: str, worker_id: Optional[int]) -> None: # pragma: no cove
def _get_data_from_storage(
self, selector_keys: list[int], worker_id: Optional[int] = None
) -> Iterator[Tuple[list[int], list[bytes], list[int], int]]:
last_processed_index = -1
processed_keys: set[int] | list[int] = []
has_failed = False

for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
@@ -168,18 +169,34 @@ def _get_data_from_storage(

response: GetResponse
stopw.start("ResponseTime", overwrite=True)
for index, response in enumerate(self._storagestub.Get(req)):
if index <= last_processed_index:
continue # Skip already processed responses
yield list(response.keys), list(response.samples), list(response.labels), stopw.stop(
"ResponseTime"
)
last_processed_index = index # Update the last processed index
for response in self._storagestub.Get(req):
response_time = stopw.stop("ResponseTime")
keys = list(response.keys)
if not has_failed:
assert isinstance(processed_keys, list)
processed_keys.extend(keys)
yield keys, list(response.samples), list(response.labels), response_time
else: # If we have failed, we need to filter out yielded samples
# Note that the returned order by storage is non-deterministic
assert isinstance(processed_keys, set)
new_keys: list[int] = [key for key in keys if key not in processed_keys]
new_samples: list[bytes] = [
sample for key, sample in zip(keys, response.samples) if key not in processed_keys
]
new_labels: list[int] = [
label for key, label in zip(keys, response.labels) if key not in processed_keys
]
processed_keys.update(keys)
yield new_keys, new_samples, new_labels, response_time

stopw.start("ResponseTime", overwrite=True)

except grpc.RpcError as e: # We catch and reraise to reconnect to rpc and do logging
has_failed = True
# Convert processed keys to set on first failure
processed_keys = set(processed_keys) if isinstance(processed_keys, list) else processed_keys
self._info(
"gRPC error occurred, last index = " + f"{last_processed_index}: {e.code()} - {e.details()}",
"gRPC error occurred, processed_keys = " + f"{processed_keys}\n{e.code()} - {e.details()}",
worker_id,
)
self._info(f"Stringified exception: {str(e)}", worker_id)
