**Describe the bug**

Our data sets currently accept a `stage` argument, which is `"training"` by default. Using `stage="validation"` together with a stateful `standardize` transform results in an error at `if self.adapter is not None: batch = self.adapter(batch, stage=self.stage)`, since the running means and standard deviations have never been computed outside the training stage.
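To make the failure mode concrete, here is a minimal, hypothetical sketch of a stateful standardize transform (not the library's actual implementation; the class name and error message are illustrative). It only updates its statistics during training, so calling it first with a validation stage fails:

```python
import numpy as np

class Standardize:
    """Hypothetical stateful standardize transform: statistics are
    only accumulated when called with stage="training"."""

    def __init__(self):
        self.mean = None
        self.std = None

    def __call__(self, batch, stage="training"):
        if stage == "training":
            # update statistics from the current batch
            self.mean = batch.mean(axis=0)
            self.std = batch.std(axis=0) + 1e-8
        if self.mean is None:
            # validation/inference before any training batch was seen
            raise RuntimeError("standardize called before statistics were computed")
        return (batch - self.mean) / self.std

adapter = Standardize()
batch = np.random.default_rng(0).normal(size=(4, 3))
try:
    adapter(batch, stage="validation")  # fails: no running statistics yet
except RuntimeError as e:
    print(e)
```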
**Expected behavior**

I have come to the realization that `BatchNorm` layers should be part of the approximators and applied to all `inference_conditions`, `summary_variables`, and `inference_variables`. This has the advantage that adapters remain stateless and users do not have to deal with standardization explicitly. Still, we should keep the `standardize` transform with static means and stds for special cases.
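The proposed behavior could be sketched as a BatchNorm-style standardization layer owned by the approximator. This is a hypothetical NumPy sketch (class name, `momentum`, and `eps` are illustrative assumptions): it updates exponential moving moments during training and falls back to the frozen statistics in any other stage, so validation never errors:

```python
import numpy as np

class MovingStandardization:
    """Hypothetical BatchNorm-style layer: updates exponential moving
    moments during training, reuses frozen statistics otherwise."""

    def __init__(self, dim, momentum=0.99, eps=1e-6):
        self.moving_mean = np.zeros(dim)
        self.moving_var = np.ones(dim)
        self.momentum = momentum
        self.eps = eps

    def __call__(self, x, training=False):
        if training:
            m, v = x.mean(axis=0), x.var(axis=0)
            # blend batch moments into the running statistics
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * m
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * v
            return (x - m) / np.sqrt(v + self.eps)
        # validation/inference: frozen statistics, no stage-dependent failure
        return (x - self.moving_mean) / np.sqrt(self.moving_var + self.eps)
```

With such a layer applied inside the approximator to all inference conditions, summary variables, and inference variables, the adapter itself stays stateless.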
Let me know what you think and I will provide an implementation.