We want to be able to perform some generic things on parameters, such as weight norm, weight dropout or L2 loss (see #59), in a unified and straightforward way.
When we have modules where the parameters are hidden inside the RETURNN layer (e.g. Linear), any such logic could be quite counter-intuitive, complicated and potentially even buggy. I expect that when we can directly see all parameters in the returnn-common code, this becomes much easier (see e.g. the code behind torch.nn.utils.weight_norm, which is quite simple, but would be tricky if the parameters were hidden inside RETURNN layers).
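For illustration, here is a minimal sketch of such weight-norm logic in PyTorch (simplified, only handling the norm over dim 0; this is not the actual torch.nn.utils.weight_norm implementation, and the helper name apply_weight_norm is just made up here):

```python
import torch


def apply_weight_norm(module: torch.nn.Module, name: str = "weight") -> torch.nn.Module:
    """Reparametrize module.<name> as w = g * v / ||v|| (norm over dim 0)."""
    weight = getattr(module, name)
    # Per-output-unit norm, shaped so it broadcasts against the weight.
    norm = weight.reshape(weight.shape[0], -1).norm(dim=1)
    norm = norm.reshape(-1, *([1] * (weight.dim() - 1)))
    g = torch.nn.Parameter(norm.detach().clone())
    v = torch.nn.Parameter(weight.detach().clone())
    # Replace the original parameter by g and v (same trick as torch.nn.utils.weight_norm).
    del module._parameters[name]
    module.register_parameter(name + "_g", g)
    module.register_parameter(name + "_v", v)

    def recompute(mod, inputs):
        # Recompute the effective weight before every forward pass.
        v_norm = v.reshape(v.shape[0], -1).norm(dim=1).reshape_as(g)
        setattr(mod, name, g * v / v_norm)

    recompute(module, None)  # make module.<name> available right away
    module.register_forward_pre_hook(recompute)
    return module


layer = apply_weight_norm(torch.nn.Linear(8, 16))
out = layer(torch.randn(2, 8))
```

This only works so nicely because the weight is a plain, directly visible attribute of the module; the same few lines would be awkward to write against a parameter that only exists inside a RETURNN layer.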
There are actually not many such modules:
- Linear
- Conv
- TransposedConv
- BatchNorm
- RelativePositionalEncoding
We also need to have a functional variant of the RecLayer (rwth-i6/returnn#817).
That's all. And they are all very simple to reimplement using pure functional modules, e.g. dot etc.
Specifically:
- Linear: Use dot (see the sketch after this list)
- Conv: Use the functional variant of ConvLayer
- TransposedConv: Use the functional variant of TransposedConvLayer
- BatchNorm: Reimplement, maybe even more efficiently by more directly wrapping the fused TF ops
- RelativePositionalEncoding: Reimplement anyway, see the discussion in #55 (Transformer Modules)
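As a rough sketch of the Linear case (stand-in classes only; the real returnn-common version would use its own Parameter/dot API and dim tags, numpy just stands in for the actual computation):

```python
import numpy


class Parameter:
    """Stand-in for the Variable/Parameter module wrapping a single tf.Variable."""
    def __init__(self, shape):
        self.data = numpy.zeros(shape, dtype="float32")


def dot(a, b):
    """Stand-in for the functional dot primitive."""
    return numpy.matmul(a, b)


class Linear:
    """Linear reimplemented such that its parameters are plain, visible attributes."""
    def __init__(self, in_dim: int, out_dim: int):
        self.weight = Parameter((in_dim, out_dim))
        self.bias = Parameter((out_dim,))

    def __call__(self, x):
        # Weight norm, weight dropout, L2 loss etc. can now operate directly
        # on self.weight / self.bias before this call happens.
        return dot(x, self.weight.data) + self.bias.data


out = Linear(5, 3)(numpy.ones((2, 5), dtype="float32"))
print(out.shape)  # (2, 3)
```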
So then the only module which directly wraps a tf.Variable is the Variable module (or maybe rename it to Parameter, to be more consistent with PyTorch). We can also easily implement functions like parameters() and named_parameters() for modules, and then follow very similar simple logic for things like weight norm etc. as in PyTorch.
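A minimal sketch of parameters() / named_parameters() in plain Python (names just mirroring PyTorch), once every parameter is an explicit attribute of some module:

```python
from typing import Iterator, Tuple


class Parameter:
    """Stand-in for the module wrapping a single tf.Variable."""
    def __init__(self, shape):
        self.shape = shape


class Module:
    def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
        # Recursively walk the attributes: yield Parameters directly,
        # and recurse into sub-modules with an extended name prefix.
        for name, value in vars(self).items():
            if isinstance(value, Parameter):
                yield prefix + name, value
            elif isinstance(value, Module):
                yield from value.named_parameters(prefix=prefix + name + ".")

    def parameters(self) -> Iterator[Parameter]:
        for _, param in self.named_parameters():
            yield param


class Linear(Module):
    """Example module with explicit weight/bias parameters."""
    def __init__(self, in_dim: int, out_dim: int):
        self.weight = Parameter((in_dim, out_dim))
        self.bias = Parameter((out_dim,))


class Model(Module):
    def __init__(self):
        self.encoder = Linear(5, 7)
        self.output = Linear(7, 3)


print([name for name, _ in Model().named_parameters()])
# ['encoder.weight', 'encoder.bias', 'output.weight', 'output.bias']
```

With such an iterator, the generic things from above (L2 loss over all parameters, weight dropout, weight norm) reduce to a simple loop over named_parameters(), the same way it works in PyTorch.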