Skip to content

Commit

Permalink
Fix inconsistency in the passed number of samples in batch-then-sample
Browse files Browse the repository at this point in the history
…mode (#515)
  • Loading branch information
XianzheMa authored Jun 20, 2024
1 parent b046892 commit 1c2c4cc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -620,16 +620,6 @@ def test_send_model_state_to_server(dummy_system_config: ModynConfig):
}


def test_send_status_to_server(dummy_system_config: ModynConfig):
    """Check that a training-status message carries the batch count and the current sample count."""
    status_queue = mp.Queue()
    command_queue = mp.Queue()
    # Fresh trainer with zero samples seen so far.
    mock_trainer = get_mock_trainer(dummy_system_config, command_queue, status_queue, False, False, None, 1, "", False)
    mock_trainer.send_status_to_server_training(20)
    status = status_queue.get()
    # The trainer has processed no samples yet, so only the batch number is non-zero.
    assert status["num_batches"] == 20
    assert status["num_samples"] == 0


@patch("modyn.trainer_server.internal.trainer.pytorch_trainer.prepare_dataloaders", mock_get_dataloaders)
@patch.object(PytorchTrainer, "weights_handling", return_value=(False, False))
def test_train_invalid_query_message(test_weight_handling, dummy_system_config: ModynConfig):
Expand Down Expand Up @@ -953,7 +943,9 @@ def mock_forward(data):

trainer.train()

assert trainer._num_samples == expected_bts_size * num_batches
assert trainer._num_samples == batch_size * num_batches
assert trainer._log["num_samples"] == batch_size * num_batches
assert trainer._log["num_samples_trained"] == expected_bts_size * num_batches
assert test_on_batch_begin.call_count == len(trainer._callbacks) * num_batches
assert test_on_batch_end.call_count == len(trainer._callbacks) * num_batches
assert test_downsample_batch.call_count == num_batches
Expand Down
7 changes: 3 additions & 4 deletions modyn/trainer_server/internal/trainer/pytorch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches
if not batch_accumulator.inform_batch(data, sample_ids, target, weights):
stopw.start("FetchBatch", resume=True)
stopw.start("IndivFetchBatch", overwrite=True)
self._num_samples += self._batch_size
continue

data, sample_ids, target, weights = batch_accumulator.get_accumulated_batch()
Expand Down Expand Up @@ -377,6 +378,7 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches

self._info(f"Finished training: {self._num_samples} samples, {batch_number + 1} batches.")
self._log["num_samples"] = self._num_samples
self._log["num_samples_trained"] = trained_batches * self._batch_size
self._log["num_batches"] = batch_number + 1
self._log["total_train"] = total_stopw.measurements.get("TotalTrain", 0)

Expand Down Expand Up @@ -623,9 +625,6 @@ def send_model_state_to_server(self) -> None:
bytes_state = buffer.read()
self._status_response_queue_training.put(bytes_state)

def send_status_to_server_training(self, batch_number: int) -> None:
    """Publish a training-progress status message on the training response queue.

    The message reports the given batch number together with the number of
    samples this trainer has seen so far (``self._num_samples``).
    """
    status = {"num_batches": batch_number, "num_samples": self._num_samples}
    self._status_response_queue_training.put(status)

def get_selection_strategy(self) -> tuple[bool, str, dict]:
req = GetSelectionStrategyRequest(pipeline_id=self.pipeline_id)

Expand Down Expand Up @@ -839,8 +838,8 @@ def _calc_expected_sizes(self, downsampling_enabled: bool) -> None:
num_samples_per_epoch = max((self._downsampler.downsampling_ratio * num_samples_per_epoch) // 100, 1)

self._expected_num_batches = (num_samples_per_epoch // self._batch_size) * self.epochs_per_trigger
# Handle special case of num_samples_to_pass instead of specifying number of epochs
self._expected_num_epochs = self.epochs_per_trigger
# Handle special case of num_samples_to_pass instead of specifying number of epochs
if self.num_samples_to_pass > 0:
self._expected_num_batches = math.ceil(self.num_samples_to_pass / self._batch_size)
self._expected_num_epochs = math.ceil(self._expected_num_batches / batches_per_epoch)
Expand Down

0 comments on commit 1c2c4cc

Please sign in to comment.