
[BUG] Data sets have a dangerous stage argument #484

Closed
@stefanradev93

Description


Describe the bug
Our data sets currently accept a stage argument, which defaults to "training". Combining a validation data set with a stateful standardize transform results in an error at:

if self.adapter is not None:
    batch = self.adapter(batch, stage=self.stage)

since the running means and standard deviations have never been computed.
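A minimal sketch of this failure mode, assuming a standardize transform that updates its running moments only when stage == "training" and otherwise reuses whatever it has stored. RunningStandardize and Dataset are hypothetical stand-ins written for illustration, not BayesFlow's actual classes; only the adapter call mirrors the line quoted above.

```python
import math


class RunningStandardize:
    """Hypothetical stateful standardize transform (not BayesFlow's API).

    Running moments are updated only when stage == "training"; any other
    stage reuses the stored moments and fails if none exist yet.
    """

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations (Welford's algorithm)

    def __call__(self, values, stage="training"):
        if stage == "training":
            for v in values:
                self.count += 1
                delta = v - self.mean
                self.mean += delta / self.count
                self.m2 += delta * (v - self.mean)
        if self.count == 0:
            raise RuntimeError("running mean/std have never been computed")
        std = math.sqrt(self.m2 / self.count)
        std = std if std > 0 else 1.0  # guard against constant inputs
        return [(v - self.mean) / std for v in values]


class Dataset:
    """Hypothetical data set that forwards its stage to the adapter."""

    def __init__(self, batches, adapter=None, stage="training"):
        self.batches = batches
        self.adapter = adapter
        self.stage = stage

    def __getitem__(self, i):
        batch = self.batches[i]
        if self.adapter is not None:
            batch = self.adapter(batch, stage=self.stage)
        return batch


# A validation set whose transform has never seen a training batch fails.
validation = Dataset([[1.0, 2.0, 3.0]], adapter=RunningStandardize(), stage="validation")
try:
    validation[0]
except RuntimeError as err:
    print("validation batch failed:", err)

# Leaving the default stage="training" avoids the error, but in this sketch
# the running moments are then fitted to the validation data itself.
leaky = Dataset([[1.0, 2.0, 3.0]], adapter=RunningStandardize())
print(leaky[0])
```

In this sketch both choices of stage are problematic for a validation set: a non-training stage hits the uncomputed moments, while the "training" default quietly updates the statistics from validation batches.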

Expected behavior
I now think that BatchNorm layers should be part of the approximators and be applied to all inference_conditions, summary_variables, and inference_variables. This has two advantages: adapters remain stateless, and users no longer have to standardize their data explicitly. We should still keep the standardize transform, with static means and standard deviations, for special cases.
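A minimal sketch of the proposed split, assuming the approximator owns one normalization layer per input key while the adapter stays a pure function. RunningNorm, Approximator, adapter and static_standardize are hypothetical names, not BayesFlow's implementation. RunningNorm mimics batch-norm mechanics (batch statistics during training, moving averages at inference, moving statistics initialized to mean 0 and variance 1 with momentum 0.99 and epsilon 1e-3, as in keras.layers.BatchNormalization's defaults) but, as a simplification, handles one scalar feature per key and omits the learnable scale and offset.

```python
import math

KEYS = ("inference_conditions", "summary_variables", "inference_variables")


class RunningNorm:
    """Simplified batch-norm-style layer for a single scalar feature."""

    def __init__(self, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        self.moving_mean = 0.0  # starts at 0, so inference works before training
        self.moving_var = 1.0   # starts at 1 for the same reason

    def __call__(self, values, training=False):
        if training:
            # Normalize with batch statistics and update the moving averages.
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_var = m * self.moving_var + (1 - m) * var
        else:
            # Inference: reuse the moving averages, never the batch itself.
            mean, var = self.moving_mean, self.moving_var
        return [(v - mean) / math.sqrt(var + self.epsilon) for v in values]


class Approximator:
    """Hypothetical approximator holding one norm layer per input key."""

    def __init__(self):
        self.norms = {key: RunningNorm() for key in KEYS}

    def normalize(self, data, training=False):
        out = dict(data)
        for key in KEYS:
            if key in out:
                out[key] = self.norms[key](out[key], training=training)
        return out


def static_standardize(values, mean, std):
    """Stateless standardize with fixed, user-supplied moments."""
    return [(v - mean) / std for v in values]


def adapter(batch):
    # Pure function of its input: no stage argument and no hidden state.
    out = dict(batch)
    if "inference_conditions" in out:
        out["inference_conditions"] = static_standardize(
            out["inference_conditions"], mean=10.0, std=2.0
        )
    return out


approximator = Approximator()
batch = adapter({"inference_conditions": [8.0, 10.0, 12.0], "inference_variables": [0.5, 1.5]})
# Validation before any training step is fine: moving statistics start at 0 / 1.
print(approximator.normalize(batch, training=False))
# A training step normalizes with batch statistics and updates the moving averages.
print(approximator.normalize(batch, training=True))
```

Because all running state lives inside the approximator, the same adapter can be shared by training and validation data sets without any stage flag, and the static standardize case remains available as an ordinary stateless transform.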

Let me know what you think, and I will provide an implementation.

Metadata

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests