Description
In the `BasicWorkflow` class, the validation set is constructed via `validation_data = OfflineDataset(data=validation_data, batch_size=dataset.batch_size, adapter=self.adapter)` in the `_fit` method. This treats the validation data as a standard `OfflineDataset` and implies shuffling on epoch end. As a result, stochasticity is introduced by changing the composition of the evaluated batches, which can lead to substantial differences between successive validation loss computations during training, even for frozen networks.
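To make the effect concrete, here is a minimal, self-contained toy sketch (not taken from the BayesFlow source; it assumes that per-batch mean losses are aggregated with equal weight and that the final batch may be smaller than `batch_size`). Under these assumptions, reshuffling a fixed set of per-sample losses already changes the reported aggregate, even though the model never changes:

```python
import numpy as np

rng = np.random.default_rng(0)

# Fixed per-sample losses of a frozen network (heavy-tailed, so that a few
# samples dominate the overall loss).
per_sample_loss = rng.lognormal(mean=1.0, sigma=2.0, size=1000)
batch_size = 32

for epoch in range(5):
    # Shuffling on epoch end changes which samples land in which batch.
    order = rng.permutation(len(per_sample_loss))
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    # Equal-weight averaging of per-batch means (an assumption about the
    # logging) over-weights the smaller final batch, so the reported loss
    # fluctuates from epoch to epoch for the very same per-sample losses.
    epoch_loss = np.mean([per_sample_loss[b].mean() for b in batches])
    print(f"epoch {epoch}: validation loss {epoch_loss:.3f}")
```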
E.g., adapting `Linear_Regression_Starter.ipynb` to use offline training with a learning rate of 0 (and no standardization in the adapter),
```python
validation_data = workflow.simulate(1000)

history = workflow.fit_offline(
    data=training_data,
    epochs=5,
    batch_size=32,
    validation_data=validation_data,
)
```
leads to highly variable validation losses:
```
[19.465839385986328,
 10.662302017211914,
 16.612285614013672,
 7.247644901275635,
 17.282154083251953]
```
Setting the number of validation sets equal to the batch size, either via `validation_data = workflow.simulate(32)` or via `batch_size=1000`, enforces identical batch composition despite shuffling, resulting in stable validation losses:
```
[12.364879608154297,
 12.364879608154297,
 12.364879608154297,
 12.364879608154297,
 12.364879608154297]
```
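For reference, the first variant applied to the snippet above looks as follows (this merely demonstrates the batching effect and is not meant as an actual fix):

```python
validation_data = workflow.simulate(32)  # exactly one validation batch of size 32

history = workflow.fit_offline(
    data=training_data,
    epochs=5,
    batch_size=32,
    validation_data=validation_data,
)
```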
A pull request with a simple fix will follow.