
What is meant by global conditioning only. #25

Open · btickell opened this issue Dec 20, 2021 · 1 comment
I am attempting to use your wavenet implementation to model some climate data, where my condition vector changes with time. The code mentions only global conditioning is currently supported. What exactly does this mean from an architecture perspective?

cheind (Owner) commented Dec 20, 2021

Hey, yes you are right, local conditioning is not fully supported right now.

Let me explain the limitations:

  • In training, you would need to check whether the conditioning is global (BxCx1) or local (BxCxT). When it is local, you need to drop the last element so that the lengths of the input and conditioning tensors match, i.e. cond = cond[..., :-1] (see the sketch after this list). Also, training would be limited to horizon==1 because of the following point.
  • In generation, there is currently no notion of local conditions. You would need one condition per generated sample, but the sampler function does not account for an optional condition return value tuple[sample, condition]. Alternatively, one could pass a local conditioning tensor to the generator function, but this would force you to generate exactly N==len(cond) samples. It is currently not clear to me how to best support local conditioning from an API perspective. If you have an idea, let me know!
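
To make the first point concrete, here is a minimal sketch of the shape check. trim_local_condition is a hypothetical helper, not part of this repo; only the BxCx1 vs. BxCxT convention and the c= keyword are taken from the points above:

```python
import torch

def trim_local_condition(cond: torch.Tensor) -> torch.Tensor:
    """Align a conditioning tensor with the one-step-shifted training input.

    Global conditioning has shape (B, C, 1) and broadcasts over time;
    local conditioning has shape (B, C, T) and needs its last element
    dropped so its length matches the shifted input.
    """
    if cond.shape[-1] == 1:
        # Global condition: broadcasts over the time axis, nothing to do.
        return cond
    # Local condition: drop the last step, i.e. cond = cond[..., :-1]
    return cond[..., :-1]

# Hypothetical usage inside a training step, with inputs x of shape (B, Q, T):
# logits = model(x[..., :-1], c=trim_local_condition(cond))
```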

The second point affects the first as follows: when horizon>1 in training, the training loop calls a differentiable generator to produce an n-step prediction. Since an API for local conditioning in generation is missing, training is currently limited to horizon==1.

If you are willing to stick with horizon==1, the required update to the training loop is minimal. For generation, one could write a quick hack for generate and generate_fast that takes an optional local conditioning tensor and passes it one slice at a time when calling model.forward(x_i, c=local_cond[..., i:i+1]).
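
For illustration, a sketch of what that hack could look like. step_fn is a hypothetical wrapper around one model step (e.g. the body of generate or generate_fast, including any queue handling), since those internals are implementation-specific; only the per-step slicing c=local_cond[..., i:i+1] comes from the paragraph above:

```python
import torch

@torch.no_grad()
def generate_with_local_cond(step_fn, x0, local_cond):
    """Hypothetical sampling loop that feeds one condition slice per step.

    step_fn(x_i, c) should perform one autoregressive step, e.g. call
    model.forward(x_i, c=c) and sample the next value. local_cond has
    shape (B, C, N), so exactly N samples are generated, mirroring the
    N == len(cond) constraint mentioned above.
    """
    x_i, out = x0, []
    for i in range(local_cond.shape[-1]):
        # The slice keeps a trailing time dimension of 1, so for this
        # single step it looks just like a global (B, C, 1) condition.
        x_i = step_fn(x_i, local_cond[..., i : i + 1])
        out.append(x_i)
    return torch.cat(out, dim=-1)
```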
