Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix propagation of 'device' setting from 'models.load_network' to `la… #2

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

egorssed
Copy link

Hey, thanks for your paper and code!

I was trying to run it on my Mac m1, that has cpu only, but setting device with load_network lead to error.

>>>device = 'cpu'
>>>net = models.load_network(args, dim, union_tp,device=device)
>>>qz, hidden = net.encode(context_x, context_y)
AssertionError: Torch not compiled with CUDA enabled

This is because net.to(device) is just not enough. You see, layer.TimeEmbedding explicitly puts data to self.device,
and setting net.to(device) doesn't change the field of class layer.TimeEmbedding.device

tt = tt.to(self.device)

The solution is just propagate the field self.device straight through the models.load_network to layer.TimeEmbedding as it is done in the pull request.

Otherwise just get rid of explicit moving to device in layer.TimeEmbedding if it is not needed anymore.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant