I am attempting to use your wavenet implementation to model some climate data, where my condition vector changes with time. The code mentions only global conditioning is currently supported. What exactly does this mean from an architecture perspective?
Hey, yes you are right, local conditioning is not fully supported right now.
Let me explain the limitations:
In training, you would need to check whether the conditioning is global (BxCx1) or local (BxCxT). When it is local, you need to remove the last element so that the input and conditioning tensors have matching lengths, i.e. cond = cond[..., :-1]. Training would also be limited to horizon==1 because of the following point.
In generation, there is currently no notion of local conditions. You need one condition per generated sample. Currently, the sampler function does not account for an optional condition in its return value, i.e. tuple[sample, condition]. Alternatively, one could pass a local conditioning tensor to the generator function, but this would require you to generate N==len(cond) samples. It's currently not clear to me how best to support local conditioning from an API perspective. If you have an idea, let me know!
The second point affects the first point as follows: when horizon>1 in training, the training loop calls a differentiable generator to generate an n-step prediction. Since an API for local conditioning in generation is missing, the training is currently limited to horizon==1.
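For the horizon==1 case, the shape check and trimming described above could look roughly like the sketch below. The helper name `align_condition` is hypothetical and not part of the repository. NumPy arrays are used only to keep the example self-contained; the same `.shape` and `[..., :-1]` indexing works on torch tensors. It also assumes that a time length of 1 always means "global", which is ambiguous for a local condition of length 1.

```python
import numpy as np


def align_condition(cond, input_len):
    """Hypothetical helper: make a condition compatible with horizon==1 inputs.

    A global condition (B, C, 1) is returned unchanged, since it broadcasts
    over time. A local condition (B, C, T) loses its last time step so that
    it lines up with inputs of length T - 1.
    """
    if cond.ndim != 3:
        raise ValueError("expected a condition of shape (B, C, 1) or (B, C, T)")
    if cond.shape[-1] == 1:
        return cond  # global: one condition vector for the whole sequence
    trimmed = cond[..., :-1]  # local: drop the last element
    if trimmed.shape[-1] != input_len:
        raise ValueError(
            f"condition length {trimmed.shape[-1]} != input length {input_len}"
        )
    return trimmed


# Toy data: B=2 sequences, C=3 condition channels, T=8 time steps.
B, C, T = 2, 3, 8
series = np.random.randn(B, 1, T)
inputs, targets = series[..., :-1], series[..., 1:]  # horizon == 1 split

local_cond = np.random.randn(B, C, T)
global_cond = np.random.randn(B, C, 1)

print(align_condition(local_cond, inputs.shape[-1]).shape)   # (2, 3, 7)
print(align_condition(global_cond, inputs.shape[-1]).shape)  # (2, 3, 1)
```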
If you are willing to stick with horizon==1, the required update to the training loop is minimal. For generation, one could write a quick hack for generate and generate_fast that takes an optional local conditioning tensor and passes it one by one when calling model.forward(x_i, c=local_cond[..., i:i+1]).
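A minimal sketch of that generation hack, under stated assumptions: `generate_with_local_cond` and `toy_step` are hypothetical names, and `step_fn(x_i, c)` stands in for calling `model.forward(x_i, c=...)` and sampling from its output. A real generator would also carry the model's receptive-field window or queues, which is omitted here. NumPy is used so the example runs on its own; with torch, `np.concatenate` would become `torch.cat`.

```python
import numpy as np


def generate_with_local_cond(step_fn, seed, local_cond):
    """Generate one sample per time step of local_cond (shape (B, C, N)).

    step_fn(x_i, c) is a stand-in for model.forward(x_i, c=...) plus sampling.
    The number of generated samples is tied to N = local_cond.shape[-1].
    """
    x_i = seed
    samples = []
    for i in range(local_cond.shape[-1]):
        c_i = local_cond[..., i:i + 1]  # slice keeps the time axis: (B, C, 1)
        x_i = step_fn(x_i, c_i)
        samples.append(x_i)
    return np.concatenate(samples, axis=-1)  # (B, 1, N)


def toy_step(x_i, c):
    """Toy 'model': half the previous sample plus the mean condition channel."""
    return 0.5 * x_i + c.mean(axis=1, keepdims=True)


seed = np.zeros((2, 1, 1))        # B=2, one channel, one time step
local_cond = np.ones((2, 3, 5))   # C=3 condition channels, N=5 steps

out = generate_with_local_cond(toy_step, seed, local_cond)
print(out.shape)  # (2, 1, 5)
```

Slicing with `i:i+1` rather than indexing with `i` keeps each per-step condition in the same (B, C, 1) shape as a global condition, so the model's forward pass would not need a separate code path.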