More robust gRPC connections, using tenacity where possible (#521)
Sometimes we face outages or random disconnections during training. This
fixes them in the places where I encountered them last night. I tried to
integrate `tenacity` as suggested by @robinholzi, but it's not always
possible, since the retry logic involves keeping track of already-completed
work, which I don't want to put into class state.

Part 6/n of porting over SIGMOD changes.
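
For reference, the two retry patterns used throughout this commit can be sketched as follows. This is a minimal, hypothetical example built only on `tenacity` and `grpc`; `connect` and `stream_with_resume` are placeholder names, not Modyn APIs. Stateless setup (such as re-opening a channel) can use the `@retry` decorator, while streaming RPCs keep their resume index in a local variable so a retry does not re-yield responses that were already delivered.

```python
import logging

import grpc
from tenacity import Retrying, retry, stop_after_attempt, wait_random_exponential

logger = logging.getLogger(__name__)


@retry(
    stop=stop_after_attempt(5),
    wait=wait_random_exponential(multiplier=1, min=2, max=60),
    reraise=True,
)
def connect(address: str) -> grpc.Channel:
    # Stateless: re-running this on failure has no side effects, so the decorator suffices.
    return grpc.insecure_channel(address)


def stream_with_resume(stub_call, request):
    """Consume a streaming RPC, resuming after transient failures without re-yielding data."""
    last_processed_index = -1  # progress is tracked in a local variable, not on an object
    for attempt in Retrying(
        stop=stop_after_attempt(5),
        wait=wait_random_exponential(multiplier=1, min=2, max=60),
        reraise=True,
    ):
        with attempt:
            try:
                for index, response in enumerate(stub_call(request)):
                    if index <= last_processed_index:
                        continue  # already yielded before the previous failure
                    yield response
                    last_processed_index = index
            except grpc.RpcError as exc:
                logger.error("gRPC error after index %d: %s", last_processed_index, exc)
                raise  # re-raise so tenacity can schedule the next attempt
```

Keeping `last_processed_index` local is what allows the retry logic to avoid class state, as described above.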
MaxiBoether authored Jun 18, 2024
1 parent 75868bb commit 0a511bb
Showing 11 changed files with 261 additions and 75 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -24,6 +24,7 @@ dependencies:
- types-protobuf==5.26.*
- evidently==0.4.27
- alibi-detect==0.12.*
- tenacity
- jsonschema
- psycopg2
- sqlalchemy>=2.0
84 changes: 67 additions & 17 deletions modyn/evaluator/internal/dataset/evaluation_dataset.py
@@ -16,6 +16,7 @@
grpc_connection_established,
instantiate_class,
)
from tenacity import Retrying, after_log, before_log, retry, stop_after_attempt, wait_random_exponential
from torch.utils.data import IterableDataset, get_worker_info
from torchvision import transforms

@@ -78,6 +79,13 @@ def _setup_composed_transform(self) -> None:
if len(self._transform_list) > 0:
self._transform = transforms.Compose(self._transform_list)

@retry(
stop=stop_after_attempt(5),
wait=wait_random_exponential(multiplier=1, min=2, max=60),
before=before_log(logger, logging.ERROR),
after=after_log(logger, logging.ERROR),
reraise=True,
)
def _init_grpc(self) -> None:
storage_channel = grpc.insecure_channel(
self._storage_address,
@@ -103,22 +111,64 @@ def _silence_pil() -> None: # pragma: no cover

def _get_keys_from_storage(self, worker_id: int, total_workers: int) -> Iterable[list[int]]:
self._info("Getting keys from storage", worker_id)
req_keys = GetDataPerWorkerRequest(
dataset_id=self._dataset_id,
worker_id=worker_id,
total_workers=total_workers,
start_timestamp=self._start_timestamp,
end_timestamp=self._end_timestamp,
)
resp_keys: GetDataPerWorkerResponse
for resp_keys in self._storagestub.GetDataPerWorker(req_keys):
yield resp_keys.keys

def _get_data_from_storage(self, keys: list[int]) -> Iterable[list[tuple[int, bytes, int]]]:
request = GetRequest(dataset_id=self._dataset_id, keys=keys)
response: GetResponse
for response in self._storagestub.Get(request):
yield list(zip(response.keys, response.samples, response.labels))
last_processed_index = -1
for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
):
with attempt:
try:
req_keys = GetDataPerWorkerRequest(
dataset_id=self._dataset_id,
worker_id=worker_id,
total_workers=total_workers,
start_timestamp=self._start_timestamp,
end_timestamp=self._end_timestamp,
)
resp_keys: GetDataPerWorkerResponse
for index, resp_keys in enumerate(self._storagestub.GetDataPerWorker(req_keys)):
if index <= last_processed_index:
continue # Skip already processed responses
yield resp_keys.keys
last_processed_index = index

except grpc.RpcError as e:
self._info(
"gRPC error occurred, last index = " + f"{last_processed_index}: {e.code()} - {e.details()}",
worker_id,
)
self._info(f"Stringified exception: {str(e)}", worker_id)
self._info(
f"Error occured while asking {self._dataset_id} for worker data:\n{worker_id}", worker_id
)
self._init_grpc()
raise e

def _get_data_from_storage(
self, keys: list[int], worker_id: Optional[int] = None
) -> Iterable[list[tuple[int, bytes, int]]]:
last_processed_index = -1
for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
):
with attempt:
try:
request = GetRequest(dataset_id=self._dataset_id, keys=keys)
response: GetResponse
for index, response in enumerate(self._storagestub.Get(request)):
if index <= last_processed_index:
continue # Skip already processed responses
yield list(zip(response.keys, response.samples, response.labels))
last_processed_index = index

except grpc.RpcError as e: # We catch and reraise to log and reconnect
self._info(
f"gRPC error occurred, last index = {last_processed_index}: {e.code()} - {e.details()}",
worker_id,
)
self._info(f"Stringified exception: {str(e)}", worker_id)
self._info(f"Error occured while asking {self._dataset_id} for keys:\n{keys}", worker_id)
self._init_grpc()
raise e

def __iter__(self) -> Generator:
worker_info = get_worker_info()
@@ -144,6 +194,6 @@ def __iter__(self) -> Generator:

# TODO(#175): we might want to do/accelerate prefetching here.
for keys in self._get_keys_from_storage(worker_id, total_workers):
for data in self._get_data_from_storage(keys):
for data in self._get_data_from_storage(keys, worker_id):
for key, sample, label in data:
yield key, self._transform(sample), label
15 changes: 14 additions & 1 deletion modyn/supervisor/internal/grpc_handler.py
@@ -47,6 +47,7 @@
from modyn.supervisor.internal.eval.result_writer import AbstractEvaluationResultWriter
from modyn.supervisor.internal.utils import EvaluationStatusReporter
from modyn.utils import grpc_common_config, grpc_connection_established
from tenacity import Retrying, stop_after_attempt, wait_random_exponential

logger = logging.getLogger(__name__)

@@ -270,6 +271,7 @@ def prepare_evaluation_request(

return EvaluateModelRequest(**start_evaluation_kwargs)

# pylint: disable=too-many-branches
def wait_for_evaluation_completion(
self, training_id: int, evaluations: dict[int, EvaluationStatusReporter]
) -> None:
@@ -292,7 +294,18 @@ def wait_for_evaluation_completion(
current_evaluation_id = working_queue.popleft()
current_evaluation_reporter = evaluations[current_evaluation_id]
req = EvaluationStatusRequest(evaluation_id=current_evaluation_id)
res: EvaluationStatusResponse = self.evaluator.get_evaluation_status(req)

for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
):
with attempt:
try:
res: EvaluationStatusResponse = self.evaluator.get_evaluation_status(req)
except grpc.RpcError as e: # We catch and reraise to easily reconnect
logger.error(e)
logger.error(f"[Training {training_id}]: gRPC connection error, trying to reconnect.")
self.init_evaluator()
raise e

if not res.valid:
exception_msg = f"Evaluation {current_evaluation_id} is invalid at server:\n{res}\n"
19 changes: 16 additions & 3 deletions modyn/supervisor/internal/pipeline_executor/evaluation_executor.py
@@ -10,6 +10,7 @@
from multiprocessing import Queue
from pathlib import Path

import grpc
import pandas as pd
from modyn.config.schema.pipeline import ModynPipelineConfig
from modyn.config.schema.system import ModynConfig
@@ -31,6 +32,7 @@
from modyn.supervisor.internal.utils.evaluation_status_reporter import EvaluationStatusReporter
from modyn.utils.utils import current_time_micros, dynamic_module_import
from pydantic import BaseModel
from tenacity import Retrying, stop_after_attempt, wait_random_exponential

eval_strategy_module = dynamic_module_import("modyn.supervisor.internal.eval.strategies")

@@ -56,13 +58,13 @@ def __init__(
pipeline_logdir: Path,
config: ModynConfig,
pipeline: ModynPipelineConfig,
grpc: GRPCHandler,
grpc_handler: GRPCHandler,
):
self.pipeline_id = pipeline_id
self.pipeline_logdir = pipeline_logdir
self.config = config
self.pipeline = pipeline
self.grpc = grpc
self.grpc = grpc_handler
self.context: AfterPipelineEvalContext | None = None
self.eval_handlers = (
[EvalHandler(eval_handler_config) for eval_handler_config in pipeline.evaluation.handlers]
@@ -272,7 +274,18 @@ def _single_evaluation(self, log: StageLog, eval_status_queue: Queue, eval_req:
eval_req.interval_start,
eval_req.interval_end,
)
response: EvaluateModelResponse = self.grpc.evaluator.evaluate_model(request)
for attempt in Retrying(
stop=stop_after_attempt(5), wait=wait_random_exponential(multiplier=1, min=2, max=60), reraise=True
):
with attempt:
try:
response: EvaluateModelResponse = self.grpc.evaluator.evaluate_model(request)
except grpc.RpcError as e: # We catch and reraise to reconnect
logger.error(e)
logger.error("gRPC connection error, trying to reconnect...")
self.grpc.init_evaluator()
raise e

if not response.evaluation_started:
log.info.failure_reason = EvaluationAbortedReason.DESCRIPTOR.values_by_number[
response.eval_aborted_reason
@@ -61,7 +61,7 @@ def evaluation_executor(
pipeline_logdir=tmp_dir_tests,
config=dummy_system_config,
pipeline=pipeline_config,
grpc=GRPCHandler(eval_state_config.config.model_dump(by_alias=True)),
grpc_handler=GRPCHandler(eval_state_config.config.model_dump(by_alias=True)),
)


@@ -1,5 +1,5 @@
# pylint: disable=unused-argument, no-name-in-module
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import grpc
import pytest
@@ -10,6 +10,7 @@
UsesWeightsResponse,
)
from modyn.trainer_server.internal.dataset.key_sources import SelectorKeySource
from tenacity import RetryCallState


def test_init():
@@ -138,3 +139,28 @@ def test_unweighted_key_source(test_grp, test_connection):
keys, weights = keysource.get_keys_and_weights(0, 0)
assert weights == [-1.0, -2.0, -3.0]
assert keys == [1, 2, 3]


def test_retry_reconnection_callback():
pipeline_id = 12
trigger_id = 1
selector_address = "localhost:1234"
keysource = SelectorKeySource(pipeline_id, trigger_id, selector_address)

# Create a mock RetryCallState
mock_retry_state = MagicMock(spec=RetryCallState)
mock_retry_state.attempt_number = 3
mock_retry_state.outcome = MagicMock()
mock_retry_state.outcome.failed = True
mock_retry_state.args = [keysource]

# Mock the _connect_to_selector method to raise an exception
with patch.object(
keysource, "_connect_to_selector", side_effect=ConnectionError("Connection failed")
) as mock_method:
# Call the retry_reconnection_callback with the mock state
with pytest.raises(ConnectionError):
SelectorKeySource.retry_reconnection_callback(mock_retry_state)

# Check that the method tried to reconnect
mock_method.assert_called()
53 changes: 52 additions & 1 deletion modyn/tests/trainer_server/internal/data/test_online_dataset.py
@@ -1,7 +1,7 @@
# pylint: disable=unused-argument, no-name-in-module, too-many-locals

import platform
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import grpc
import pytest
@@ -218,6 +218,57 @@ def test_get_data_from_storage(
assert set(result_labels) == set(labels)


class MockRpcError(grpc.RpcError):
def code(self):
return grpc.StatusCode.UNAVAILABLE

def details(self):
return "Mocked gRPC error for testing retry logic."


@patch("modyn.trainer_server.internal.dataset.key_sources.selector_key_source.SelectorStub", MockSelectorStub)
@patch("modyn.trainer_server.internal.dataset.online_dataset.StorageStub")
@patch(
"modyn.trainer_server.internal.dataset.key_sources.selector_key_source.grpc_connection_established",
return_value=True,
)
@patch("modyn.trainer_server.internal.dataset.online_dataset.grpc_connection_established", return_value=True)
@patch.object(grpc, "insecure_channel", return_value=None)
def test_get_data_from_storage_with_retry(
test_insecure_channel,
test_grpc_connection_established,
test_grpc_connection_established_selector,
mock_storage_stub,
):
# Arrange
online_dataset = OnlineDataset(
pipeline_id=1,
trigger_id=1,
dataset_id="MNIST",
bytes_parser=get_mock_bytes_parser(),
serialized_transforms=[],
storage_address="localhost:1234",
selector_address="localhost:1234",
training_id=42,
tokenizer=None,
log_path=None,
shuffle=False,
num_prefetched_partitions=0,
parallel_prefetch_requests=1,
)
online_dataset._init_grpc = MagicMock() # cannot patch this with annotations due to tenacity
mock_storage_stub.Get.side_effect = [MockRpcError(), MockRpcError(), MagicMock()]
online_dataset._storagestub = mock_storage_stub

try:
for _ in online_dataset._get_data_from_storage(list(range(10))):
pass
except Exception as e:
assert isinstance(e, RuntimeError), "Expected a RuntimeError after max retries."

assert mock_storage_stub.Get.call_count == 3, "StorageStub.Get should have been retried twice before succeeding."


@patch("modyn.trainer_server.internal.dataset.key_sources.selector_key_source.SelectorStub", MockSelectorStub)
@patch("modyn.trainer_server.internal.dataset.online_dataset.StorageStub", MockStorageStub)
@patch(
@@ -33,11 +33,11 @@ def get_num_data_partitions(self) -> int:
def end_of_trigger_cleaning(self) -> None:
self._trigger_sample_storage.clean_trigger_data(self._pipeline_id, self._trigger_id)

def __getstate__(self):
def __getstate__(self) -> dict:
state = self.__dict__.copy()
del state["_trigger_sample_storage"] # not pickable
return state

def __setstate__(self, state):
def __setstate__(self, state: dict) -> None:
self.__dict__.update(state)
self._trigger_sample_storage = TriggerSampleStorage(self.offline_dataset_path)