Fix batch number (#533)
Previously, we did not record the number of processed batches correctly:
we used a `batch_number` generated purely by enumerating the `dataloader`,
so the counter resets every epoch and is unrelated to the number of epochs
(it only shows how many batches there are within one epoch).
A similar issue exists in the `StB` iteration when we calculate scores
class by class: the number of batches processed for the previous classes
is not accumulated into the count for the current class.

This PR fixes both issues.
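
A minimal illustrative sketch of the two counting behaviors described above (hypothetical helpers, not the actual Modyn trainer code; `dataloader`/`dataloaders` are assumed iterables of batches):

```python
# Illustrative sketch only -- not the Modyn trainer implementation.
# It contrasts a per-epoch enumeration counter with an accumulated one.

def count_batches_buggy(dataloader, num_epochs: int) -> int:
    last_batch_number = 0
    for _ in range(num_epochs):
        # enumerate() restarts at 0 each epoch, so the counter only ever
        # reflects the position within the current epoch.
        for batch_number, _batch in enumerate(dataloader):
            last_batch_number = batch_number
    return last_batch_number  # len(dataloader) - 1, regardless of num_epochs


def count_batches_fixed(dataloaders, num_epochs: int = 1) -> int:
    # `dataloaders` may be a single-element list (BtS) or one dataloader per
    # class (StB with per-class scoring); the counter carries across both
    # epochs and classes instead of restarting.
    batch_number = 0
    for _ in range(num_epochs):
        for dataloader in dataloaders:
            for _batch in dataloader:
                batch_number += 1
    return batch_number  # num_epochs * sum(len(dl) for dl in dataloaders)
```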
XianzheMa authored Jun 23, 2024
1 parent 57803ea commit 97a2b5f
Showing 4 changed files with 200 additions and 68 deletions.
@@ -43,19 +43,11 @@ def __init__(
self.requires_remote_computation = True
self.maximum_keys_in_memory = maximum_keys_in_memory
self.downsampling_config = downsampling_config
self.status_bar_scale = self._compute_status_bar_scale()

def _compute_status_bar_scale(self) -> int:
"""
This function is used to create the downsampling status bar and handle the training one accordingly.
For BTS, we return 100 since the training status bar sees all the samples
For STB, we return the downsampling_ratio since the training status bar sees only a fraction of points
(while the downsampling status bas sees all the points)
"""
if self.downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
return 100
return self.downsampling_ratio
# the status bar scale is used in conjunction with the total number of samples (after presampling)
# and the number of already trained samples to show current training progress
# No matter it is BtS or StB, the number of trained samples should be compared to the total number of samples
# divided by the downsampling ratio. Therefore, the status bar scale should be the downsampling ratio.
self.status_bar_scale = self.downsampling_ratio

@property
def downsampling_params(self) -> dict:
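
For context on the new comment above the `status_bar_scale` assignment: one way such a scale could plausibly be combined with the sample counts to report progress. This is an illustrative guess with percentage-style ratios (`ratio_max` assumed to default to 100), not the actual Modyn progress-bar code:

```python
def downsampled_progress(samples_trained: int, total_samples: int,
                         status_bar_scale: int, ratio_max: int = 100) -> float:
    """Illustrative only: fraction of the downsampled trigger training set seen so far."""
    # With status_bar_scale == downsampling_ratio, the trainer is expected to
    # train on total_samples * scale / ratio_max samples, for both BtS and StB.
    samples_to_train = total_samples * status_bar_scale / ratio_max
    return samples_trained / samples_to_train


# e.g. downsampling_ratio = 25, ratio_max = 100: after 1_000 of 4_000 samples,
# the downsampled set (1_000 samples) has been fully consumed.
assert downsampled_progress(1_000, 4_000, status_bar_scale=25) == 1.0
```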
@@ -102,7 +102,7 @@ def test_switch_functions():
"ratio_max": 100,
}
assert downs.downsampling_strategy == "RemoteGradNormDownsampling"
assert downs.training_status_bar_scale == 100
assert downs.training_status_bar_scale == 25


def test_wrong_number_threshold():
@@ -158,7 +158,7 @@ def test_double_threshold():
"ratio_max": 100,
}
assert downs.downsampling_strategy == "RemoteGradNormDownsampling"
assert downs.training_status_bar_scale == 100
assert downs.training_status_bar_scale == 25

# above the last threshold
for i in range(15, 25):
@@ -203,7 +203,7 @@ def test_wrong_trigger():
"ratio_max": 100,
}
assert downs.downsampling_strategy == "RemoteGradNormDownsampling"
assert downs.training_status_bar_scale == 100
assert downs.training_status_bar_scale == 25


def test_instantiate_scheduler_just_one():
196 changes: 166 additions & 30 deletions modyn/tests/trainer_server/internal/trainer/test_pytorch_trainer.py
@@ -10,7 +10,7 @@
from collections import OrderedDict
from io import BytesIO
from time import sleep
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch

import grpc
import pytest
@@ -29,6 +29,7 @@
from modyn.trainer_server.internal.metadata_collector.metadata_collector import MetadataCollector
from modyn.trainer_server.internal.trainer.metadata_pytorch_callbacks.base_callback import BaseCallback
from modyn.trainer_server.internal.trainer.pytorch_trainer import PytorchTrainer, train
from modyn.trainer_server.internal.trainer.remote_downsamplers import RemoteGradMatchDownsamplingStrategy
from modyn.trainer_server.internal.utils.trainer_messages import TrainerMessages
from modyn.trainer_server.internal.utils.training_info import TrainingInfo
from modyn.utils import DownsamplingMode
@@ -117,6 +118,28 @@ def get_mock_label_transformer():
)


class MockDataloader:
def __init__(self, batch_size, num_batches):
self.batch_size = batch_size
self.num_batches = num_batches
self.dataset = MagicMock()

def __iter__(self):
return iter(
[
(
("1",) * self.batch_size,
torch.ones(self.batch_size, 10, requires_grad=True),
torch.ones(self.batch_size, dtype=torch.uint8),
)
for _ in range(self.num_batches)
]
)

def __len__(self):
return self.num_batches


def mock_get_dataloaders(
pipeline_id,
trigger_id,
@@ -135,12 +158,7 @@ def mock_get_dataloaders(
log_path,
num_batches: int = 100,
):
mock_train_dataloader = iter(
[
(("1",) * batch_size, torch.ones(batch_size, 10, requires_grad=True), torch.ones(batch_size, dtype=int))
for _ in range(num_batches)
]
)
mock_train_dataloader = MockDataloader(batch_size, num_batches)
return mock_train_dataloader, None


@@ -257,6 +275,7 @@ def get_training_info(

@patch.object(StorageStub, "__init__", noop_constructor_mock)
@patch.object(SelectorStub, "__init__", noop_constructor_mock)
@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders)
@patch("modyn.trainer_server.internal.dataset.online_dataset.grpc_connection_established", return_value=True)
@patch(
"modyn.trainer_server.internal.dataset.key_sources.selector_key_source.grpc_connection_established",
@@ -266,13 +285,13 @@
@patch("modyn.trainer_server.internal.utils.training_info.dynamic_module_import")
@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.dynamic_module_import")
@patch.object(PytorchTrainer, "connect_to_selector", return_value=None)
@patch.object(PytorchTrainer, "get_selection_strategy", return_value=(False, "", {}))
@patch.object(PytorchTrainer, "get_selection_strategy")
@patch.object(PytorchTrainer, "get_num_samples_in_trigger")
@patch.object(SelectorKeySource, "uses_weights", return_value=False)
def get_mock_trainer(
modyn_config: ModynConfig,
query_queue: mp.Queue,
response_queue: mp.Queue,
query_queue_training: mp.Queue,
response_queue_training: mp.Queue,
use_pretrained: bool,
load_optimizer_state: bool,
pretrained_model_path: pathlib.Path,
@@ -289,22 +308,13 @@
test_grpc_connection_established_selector: MagicMock,
test_grpc_connection_established: MagicMock,
batch_size: int = 32,
downsampling_mode: DownsamplingMode = DownsamplingMode.DISABLED,
downsampling_ratio: int = 25,
ratio_max: int = 100,
selection_strategy: tuple[bool, str, dict] = (False, "", {}),
):
model_dynamic_module_patch.return_value = MockModule(num_optimizers)
lr_scheduler_dynamic_module_patch.return_value = MockLRSchedulerModule()
mock_get_num_samples.return_value = batch_size * 100

if downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE:
mock_selection_strategy.return_value = (
True,
"RemoteGradNormDownsampling",
{"downsampling_ratio": downsampling_ratio, "ratio_max": ratio_max, "sample_then_batch": False},
)
elif downsampling_mode == DownsamplingMode.SAMPLE_THEN_BATCH:
raise NotImplementedError()
mock_selection_strategy.return_value = selection_strategy

training_info = get_training_info(
0,
@@ -323,8 +333,8 @@
modyn_config.model_dump(by_alias=True),
training_info,
"cpu",
query_queue,
response_queue,
query_queue_training,
response_queue_training,
mp.Queue(),
mp.Queue(),
logging.getLogger(__name__),
@@ -621,7 +631,6 @@ def test_send_model_state_to_server(dummy_system_config: ModynConfig):
}


@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders)
@patch.object(PytorchTrainer, "weights_handling", return_value=(False, False))
def test_train_invalid_query_message(test_weight_handling, dummy_system_config: ModynConfig):
query_status_queue = mp.Queue()
@@ -652,7 +661,6 @@ def test_train_invalid_query_message(test_weight_handling, dummy_system_config:
# # pylint: disable=too-many-locals


@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders)
@patch.object(BaseCallback, "on_train_begin", return_value=None)
@patch.object(BaseCallback, "on_train_end", return_value=None)
@patch.object(BaseCallback, "on_batch_begin", return_value=None)
@@ -870,7 +878,6 @@ def test_create_trainer_with_exception(


@pytest.mark.parametrize("downsampling_ratio, ratio_max", [(25, 100), (50, 100), (250, 1000), (125, 1000)])
@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders)
@patch.object(BaseCallback, "on_train_begin", return_value=None)
@patch.object(BaseCallback, "on_train_end", return_value=None)
@patch.object(BaseCallback, "on_batch_begin", return_value=None)
@@ -914,9 +921,11 @@ def test_train_batch_then_sample_accumulation(
"custom",
False,
batch_size=batch_size,
downsampling_mode=DownsamplingMode.BATCH_THEN_SAMPLE,
downsampling_ratio=downsampling_ratio,
ratio_max=ratio_max,
selection_strategy=(
True,
"RemoteGradNormDownsampling",
{"downsampling_ratio": downsampling_ratio, "sample_then_batch": False, "ratio_max": ratio_max},
),
)
assert trainer._downsampling_mode == DownsamplingMode.BATCH_THEN_SAMPLE

@@ -949,6 +958,7 @@ def mock_forward(data):

assert trainer._num_samples == batch_size * num_batches
assert trainer._log["num_samples"] == batch_size * num_batches
assert trainer._log["num_batches"] == num_batches
# We only train on whole batches, hence we have to scale by batch size
assert trainer._log["num_samples_trained"] == ((expected_bts_size * num_batches) // batch_size) * batch_size
assert test_on_batch_begin.call_count == len(trainer._callbacks) * num_batches
@@ -970,7 +980,6 @@ def mock_forward(data):
assert torch.allclose(data, expected_data)


@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders)
@patch.object(MetadataCollector, "send_metadata", return_value=None)
@patch.object(MetadataCollector, "cleanup", return_value=None)
@patch.object(CustomLRScheduler, "step", return_value=None)
@@ -1003,3 +1012,130 @@ def test_lr_scheduler_init(
)

assert trainer._lr_scheduler.T_max == 100


@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.SelectorKeySource")
@patch.object(PytorchTrainer, "get_available_labels_from_selector")
@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_per_class_dataloader_from_online_dataset")
@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.LocalDatasetWriter")
@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.LocalKeySource")
@patch.object(PytorchTrainer, "start_embedding_recording_if_needed")
@patch.object(PytorchTrainer, "end_embedding_recorder_if_needed")
@patch.object(PytorchTrainer, "get_embeddings_if_recorded")
@patch.object(RemoteGradMatchDownsamplingStrategy, "inform_samples")
@patch.object(RemoteGradMatchDownsamplingStrategy, "inform_end_of_current_label")
@patch.object(PytorchTrainer, "update_queue")
def test_downsample_trigger_training_set_label_by_label(
test_update_queue,
test_inform_end_of_current_label,
test_inform_samples,
test_get_embeddings,
test_end_embedding_recording,
test_start_embedding_recording,
test_local_key_source,
test_local_dataset_writer,
test_prepare_per_class_dataloader,
test_get_available_labels,
test_selector_key_source,
dummy_system_config: ModynConfig,
):
batch_size = 4
available_labels = [0, 1, 2, 3, 4, 5]
test_prepare_per_class_dataloader.return_value = MockDataloader(batch_size, 100)
test_get_available_labels.return_value = available_labels
num_batches = 100 # hardcoded into mock dataloader
query_status_queue_training = mp.Queue()
status_queue_training = mp.Queue()
trainer = get_mock_trainer(
dummy_system_config,
query_status_queue_training,
status_queue_training,
False,
False,
None,
2,
"custom",
False,
batch_size=batch_size,
selection_strategy=(
True,
"RemoteGradMatchDownsamplingStrategy",
{
"downsampling_ratio": 25,
"downsampling_period": 1,
"sample_then_batch": True,
"balance": True,
"ratio_max": 100,
},
),
)
assert trainer._downsampling_mode == DownsamplingMode.SAMPLE_THEN_BATCH
assert trainer._downsampler.requires_data_label_by_label
trainer.downsample_trigger_training_set()
assert test_prepare_per_class_dataloader.call_count == 1
assert test_update_queue.call_count == len(available_labels) * num_batches + 1
# check the args of the last call
last_call_args = test_update_queue.call_args_list[-1]
expected_batch_number = len(available_labels) * num_batches
expected_num_samples = expected_batch_number * batch_size
assert last_call_args == call("DOWNSAMPLING", expected_batch_number, expected_num_samples, training_active=True)
assert test_inform_end_of_current_label.call_count == len(available_labels)


@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.SelectorKeySource")
@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.LocalDatasetWriter")
@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.LocalKeySource")
@patch.object(PytorchTrainer, "start_embedding_recording_if_needed")
@patch.object(PytorchTrainer, "end_embedding_recorder_if_needed")
@patch.object(PytorchTrainer, "get_embeddings_if_recorded")
@patch.object(RemoteGradMatchDownsamplingStrategy, "inform_samples")
@patch.object(RemoteGradMatchDownsamplingStrategy, "select_points", return_value=([1, 2], torch.ones(2)))
@patch.object(PytorchTrainer, "update_queue")
def test_downsample_trigger_training_set(
test_update_queue,
test_select_points,
test_inform_samples,
test_get_embeddings,
test_end_embedding_recording,
test_start_embedding_recording,
test_local_key_source,
test_local_dataset_writer,
test_selector_key_source,
dummy_system_config: ModynConfig,
):
batch_size = 4
num_batches = 100 # hardcoded into mock dataloader
query_status_queue_training = mp.Queue()
status_queue_training = mp.Queue()
trainer = get_mock_trainer(
dummy_system_config,
query_status_queue_training,
status_queue_training,
False,
False,
None,
2,
"custom",
False,
batch_size=batch_size,
selection_strategy=(
True,
"RemoteGradMatchDownsamplingStrategy",
{
"downsampling_ratio": 25,
"downsampling_period": 1,
"sample_then_batch": True,
"balance": False,
"ratio_max": 100,
},
),
)
assert trainer._downsampling_mode == DownsamplingMode.SAMPLE_THEN_BATCH
assert not trainer._downsampler.requires_data_label_by_label
trainer.downsample_trigger_training_set()
assert test_update_queue.call_count == num_batches + 1
# check the args of the last call
last_call_args = test_update_queue.call_args_list[-1]
expected_batch_number = num_batches
expected_num_samples = expected_batch_number * batch_size
assert last_call_args == call("DOWNSAMPLING", expected_batch_number, expected_num_samples, training_active=True)