
BasicWorkflow: Unstable val_losses due to validation data shuffling #481

Closed
@elseml

Description


In the BasicWorkflow class, the _fit method constructs the validation set via validation_data = OfflineDataset(data=validation_data, batch_size=dataset.batch_size, adapter=self.adapter). This treats the validation data as a standard OfflineDataset, which shuffles on epoch end. The composition of the evaluated batches therefore changes between epochs, introducing stochasticity that can cause substantial differences between successive validation loss computations during training, even for frozen networks.
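To illustrate the effect in isolation, here is a minimal NumPy sketch (not BayesFlow code) under the assumption that per-sample losses are fixed, batches are formed from a reshuffled index order each epoch, and the reported val_loss is the mean of per-batch mean losses. Since 1000 is not divisible by 32, the last batch is smaller, and which samples end up in it changes with every reshuffle:

# Toy illustration (NumPy only) of how reshuffling can change an epoch-averaged
# validation loss for a frozen model. Assumptions (not taken from OfflineDataset's
# actual implementation): per-sample losses are fixed, batches are cut from a
# shuffled index order, and the epoch "val_loss" is the mean of per-batch means.
import numpy as np

rng = np.random.default_rng(0)
per_sample_loss = rng.lognormal(mean=1.0, sigma=1.5, size=1000)  # fixed losses
batch_size = 32

def epoch_val_loss(order):
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    return np.mean([per_sample_loss[b].mean() for b in batches])

for epoch in range(5):
    order = rng.permutation(1000)  # shuffling on epoch end
    print(epoch, epoch_val_loss(order))
# The printed values differ between "epochs" even though per_sample_loss never changes.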

For example, adapting Linear_Regression_Starter.ipynb to use offline training with a learning rate of 0 (and no standardization in the adapter),

validation_data = workflow.simulate(1000)
history = workflow.fit_offline(
    data=training_data,
    epochs=5,
    batch_size=32,
    validation_data=validation_data,
)

leads to highly variable validation losses:

[19.465839385986328,
 10.662302017211914,
 16.612285614013672,
 7.247644901275635,
 17.282154083251953]

Setting the number of validation samples equal to the batch size, either via validation_data = workflow.simulate(32) or via batch_size=1000, enforces identical batch composition despite shuffling (both variants are spelled out below) and results in stable validation losses:

[12.364879608154297,
 12.364879608154297,
 12.364879608154297,
 12.364879608154297,
 12.364879608154297]
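
For completeness, the two workarounds correspond to the following calls (a sketch based on the reproduction snippet above; either way the validation set forms a single batch whose composition cannot change):

# Variant 1: make the validation set exactly one batch of the training batch size.
validation_data = workflow.simulate(32)
history = workflow.fit_offline(
    data=training_data,
    epochs=5,
    batch_size=32,
    validation_data=validation_data,
)

# Variant 2: keep 1000 validation samples but use a batch size that covers them all.
validation_data = workflow.simulate(1000)
history = workflow.fit_offline(
    data=training_data,
    epochs=5,
    batch_size=1000,
    validation_data=validation_data,
)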

A pull request with a simple fix will follow.
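
Independent of that pull request, one possible direction is sketched below, purely as an illustration: it assumes the shuffling is triggered from a Keras-style on_epoch_end() hook, and the class name and import path are placeholders, not the actual fix.

# Hypothetical sketch: keep validation batches fixed by disabling the epoch-end shuffle.
from bayesflow.datasets import OfflineDataset  # import path may differ

class NonShufflingOfflineDataset(OfflineDataset):
    """Validation-only dataset that keeps its batch composition fixed across epochs."""

    def on_epoch_end(self):
        pass  # no reshuffling, so every epoch evaluates identical batches

# In BasicWorkflow._fit, the validation set could then be constructed as:
# validation_data = NonShufflingOfflineDataset(
#     data=validation_data, batch_size=dataset.batch_size, adapter=self.adapter
# )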
