Batch evaluation intervals into a single request and a single evaluation process #554

Merged · 25 commits · Jul 3, 2024
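What this PR changes, as reflected in the diff below: DatasetInfo loses its single start_timestamp/end_timestamp pair in favor of a repeated evaluation_intervals field, so one EvaluateModelRequest (and one evaluation process) now covers every interval at once. A minimal sketch of the new request shape; the message and field names come from the diff, while the import path, model id, and dataset id are assumptions for illustration:

# Sketch only: message/field names are taken from this PR's diff; the import
# path, model_id, and dataset_id below are assumptions, and the field set is abridged.
from modyn.evaluator.internal.grpc.generated.evaluator_pb2 import (
    DatasetInfo,
    EvaluateModelRequest,
    EvaluationInterval,
)

ts1, ts2 = 1_000, 2_000  # hypothetical split timestamps

# Before this PR: one request per (start, end) pair.
# After this PR: all intervals batched into a single request.
request = EvaluateModelRequest(
    model_id=42,  # hypothetical trained model id
    dataset_info=DatasetInfo(
        dataset_id="image_dataset",  # hypothetical dataset id
        num_dataloaders=1,
        evaluation_intervals=[
            EvaluationInterval(start_timestamp=start, end_timestamp=end)
            for start, end in [(None, ts1), (ts1, ts2), (ts2, None)]
        ],
    ),
    device="cpu",
    batch_size=2,
)

Per-interval outcomes come back in interval_responses on the immediate response, and completed results are later matched to their interval via interval_index, as the updated test below asserts.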
integrationtests/evaluator/integrationtest_evaluator.py (95 changes: 63 additions & 32 deletions)
@@ -13,6 +13,7 @@
     EvaluateModelRequest,
     EvaluateModelResponse,
     EvaluationAbortedReason,
+    EvaluationInterval,
     EvaluationResultRequest,
     EvaluationResultResponse,
     EvaluationStatusRequest,
@@ -27,6 +28,8 @@
 from modyn.model_storage.internal.grpc.generated.model_storage_pb2 import RegisterModelRequest, RegisterModelResponse
 from modyn.model_storage.internal.grpc.generated.model_storage_pb2_grpc import ModelStorageStub
 from modyn.models import ResNet18
+from modyn.storage.internal.grpc.generated.storage_pb2 import GetDatasetSizeRequest, GetDatasetSizeResponse
+from modyn.storage.internal.grpc.generated.storage_pb2_grpc import StorageStub
 from modyn.utils import calculate_checksum

 TEST_MODELS_PATH = MODYN_MODELS_PATH / "test_models"
@@ -63,6 +66,21 @@ def prepare_dataset(dataset_helper: ImageDatasetHelper) -> Tuple[int, int, int,
     split_ts2 = int(time.time()) + 1
     time.sleep(2)
     dataset_helper.add_images_to_dataset(start_number=12, end_number=22)
+    # Wait until the storage server has processed all 22 images before returning.
+
+    storage_channel = connect_to_server("storage")
+    storage = StorageStub(storage_channel)
+    timeout = 60
+    start_time = time.time()
+    request = GetDatasetSizeRequest(dataset_id=DATASET_ID)
+    resp: GetDatasetSizeResponse = storage.GetDatasetSize(request)
+    assert resp.success
+    while resp.num_keys != 22:
+        time.sleep(2)
+        if time.time() - start_time > timeout:
+            raise TimeoutError("Could not get the dataset size in time")
+        resp = storage.GetDatasetSize(request)
+        assert resp.success
     return split_ts1, split_ts2, 5, 7, 10
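The loop added above (poll GetDatasetSize until num_keys reaches 22, raising on timeout) is a generic wait-for-condition pattern. A hypothetical helper sketch; the wait_until name and signature are mine, not part of Modyn:

import time
from typing import Callable


def wait_until(predicate: Callable[[], bool], timeout_s: float = 60.0, poll_s: float = 2.0) -> None:
    # Poll `predicate` until it returns True; raise if the timeout elapses first.
    start = time.time()
    while not predicate():
        if time.time() - start > timeout_s:
            raise TimeoutError(f"condition not met within {timeout_s:.0f}s")
        time.sleep(poll_s)


# Usage against the storage stub from the diff above (names from this PR):
# wait_until(lambda: storage.GetDatasetSize(request).num_keys == 22)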


@@ -114,9 +132,9 @@ def prepare_model() -> int:
     return register_response.model_id


-def evaluate_model(
-    model_id: int, start_timestamp: Optional[int], end_timestamp: Optional[int], evaluator: EvaluatorStub
-) -> EvaluateModelResponse:
+def evaluate_model(model_id: int, evaluator: EvaluatorStub, intervals=None) -> EvaluateModelResponse:
+    if intervals is None:
+        intervals = [(None, None)]
     eval_transform_function = (
         "import torch\n"
         "def evaluation_transformer_function(model_output: torch.Tensor) -> torch.Tensor:\n\t"
@@ -135,8 +153,10 @@ def evaluate_model(
         dataset_info=DatasetInfo(
             dataset_id=DATASET_ID,
             num_dataloaders=1,
-            start_timestamp=start_timestamp,
-            end_timestamp=end_timestamp,
+            evaluation_intervals=[
+                EvaluationInterval(start_timestamp=start_timestamp, end_timestamp=end_timestamp)
+                for start_timestamp, end_timestamp in intervals
+            ],
         ),
         device="cpu",
         batch_size=2,
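A note on the (0, split_ts1) case exercised in the test below: with plain proto3 scalar fields, an unset field is indistinguishable on the wire from one explicitly set to 0, which is why 0 and None are expected to behave identically for start_timestamp. A sketch under that assumption; the import path is assumed, and the equivalence holds only if the fields are not declared with explicit presence (`optional`):

# Assumption: start_timestamp/end_timestamp are plain proto3 int scalars without
# explicit presence, so unset and explicit 0 serialize identically. Import path assumed.
from modyn.evaluator.internal.grpc.generated.evaluator_pb2 import EvaluationInterval

unset = EvaluationInterval(start_timestamp=None)  # None leaves the field unset
zero = EvaluationInterval(start_timestamp=0)      # explicit zero value
assert unset.SerializeToString() == zero.SerializeToString()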
@@ -153,38 +173,49 @@


 def test_evaluator(dataset_helper: ImageDatasetHelper) -> None:
-    def validate_eval_result(eval_result_resp: EvaluationResultResponse):
-        assert eval_result_resp.valid
-        assert len(eval_result_resp.evaluation_data) == 1
-        assert eval_result_resp.evaluation_data[0].metric == "Accuracy"
-
     evaluator_channel = connect_to_server("evaluator")
     evaluator = EvaluatorStub(evaluator_channel)
     split_ts1, split_ts2, split1_size, split2_size, split3_size = prepare_dataset(dataset_helper)
     model_id = prepare_model()
-    eval_model_resp = evaluate_model(model_id, split_ts2, split_ts1, evaluator)
-    assert not eval_model_resp.evaluation_started, "Evaluation should not start if start_timestamp > end_timestamp"
-    assert eval_model_resp.dataset_size == 0
-    assert eval_model_resp.eval_aborted_reason == EvaluationAbortedReason.EMPTY_DATASET
-
-    # (start_timestamp, end_timestamp, expected_dataset_size)
-    test_cases = [
-        (None, split_ts1, split1_size),
-        (None, split_ts2, split1_size + split2_size),
-        (split_ts1, split_ts2, split2_size),
-        (split_ts1, None, split2_size + split3_size),
-        (split_ts2, None, split3_size),
-        (None, None, split1_size + split2_size + split3_size),
-        (0, split_ts1, split1_size),  # verify that 0 has the same effect as None for start_timestamp
+    intervals = [
+        (None, split_ts1),
+        (None, split_ts2),
+        (split_ts2, split_ts1),
+        (split_ts1, split_ts2),
+        (split_ts1, None),
+        (split_ts2, None),
+        (None, None),
+        (0, split_ts1),  # verify that 0 has the same effect as None for start_timestamp
     ]
-    for start_ts, end_ts, expected_size in test_cases:
-        print(f"Testing model with start_timestamp={start_ts}, end_timestamp={end_ts}")
-        eval_model_resp = evaluate_model(model_id, start_ts, end_ts, evaluator)
-        assert eval_model_resp.evaluation_started
-        assert eval_model_resp.dataset_size == expected_size
-
-        eval_result_resp = wait_for_evaluation(eval_model_resp.evaluation_id, evaluator)
-        validate_eval_result(eval_result_resp)
+    expected_data_sizes = [
+        split1_size,
+        split1_size + split2_size,
+        None,
+        split2_size,
+        split2_size + split3_size,
+        split3_size,
+        split1_size + split2_size + split3_size,
+        split1_size,
+    ]

+    eval_model_resp = evaluate_model(model_id, evaluator, intervals)
+    assert eval_model_resp.evaluation_started
+    assert len(eval_model_resp.interval_responses) == len(intervals)
+    for interval_resp, expected_size in zip(eval_model_resp.interval_responses, expected_data_sizes):
+        if expected_size is None:
+            assert interval_resp.eval_aborted_reason == EvaluationAbortedReason.EMPTY_DATASET
+        else:
+            assert interval_resp.dataset_size == expected_size
+            assert interval_resp.eval_aborted_reason == EvaluationAbortedReason.NOT_ABORTED
+
+    eval_result_resp = wait_for_evaluation(eval_model_resp.evaluation_id, evaluator)
+    assert eval_result_resp.valid
+    expected_interval_ids = [idx for idx, data_size in enumerate(expected_data_sizes) if data_size is not None]
+    assert len(eval_result_resp.evaluation_results) == len(expected_interval_ids)
+    for interval_data, expected_interval_id in zip(eval_result_resp.evaluation_results, expected_interval_ids):
+        assert interval_data.interval_index == expected_interval_id
+        assert len(interval_data.evaluation_data) == 1
+        assert interval_data.evaluation_data[0].metric == "Accuracy"


 if __name__ == "__main__":
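For readers wiring this up elsewhere: aborted intervals (such as the start > end interval, which is reported as EMPTY_DATASET) appear in interval_responses but produce no entry in evaluation_results, so results must be matched back to the requested intervals via interval_index, exactly as the assertions above do. A small sketch of that pairing; the field names come from this PR's diff, and the arguments are assumed to be produced as in the test above:

# Sketch: pair batched results with their requested intervals via `interval_index`.
# `intervals` is the list of (start, end) tuples sent in the request, and
# `eval_result_resp` is the EvaluationResultResponse returned after completion.
def summarize(intervals, eval_result_resp):
    for interval_data in eval_result_resp.evaluation_results:
        start, end = intervals[interval_data.interval_index]
        metrics = [d.metric for d in interval_data.evaluation_data]
        print(f"interval ({start}, {end}) evaluated metrics: {metrics}")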